Skip to content

Commit

Permalink
small changes from the viz_rewrite branch [pr] (tinygrad#6907)
Browse files Browse the repository at this point in the history
* simpler replace

* dont show shapetracker consts

* changed_nodes shouldn't exist for the first sink
  • Loading branch information
Qazalin authored Oct 6, 2024
1 parent 16c1fa4 commit b066ef2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion viz/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@
ret = await (await fetch(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`)).json();
cache[cacheKey] = ret;
}
renderGraph(ret.graphs[currentRewrite], ret.changed_nodes[currentRewrite]);
renderGraph(ret.graphs[currentRewrite], currentRewrite == 0 ? [] : ret.changed_nodes[currentRewrite-1]);
// ***** RHS metadata
const metadata = document.querySelector(".container.metadata");
metadata.innerHTML = "";
Expand Down
14 changes: 7 additions & 7 deletions viz/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
def reconstruct_graph(sink:UOp, rewrites:List[Tuple[UOp, UOp, UPat]]) -> Tuple[List[UOp], List[List[str]], List[List[int]]]:
uops: List[UOp] = [sink]
diffs: List[List[str]] = []
changed_nodes: List[List[int]] = [[]]
seen_replaces: Dict[bytes, UOp] = {}
changed_nodes: List[List[int]] = []
seen_replaces: Dict[UOp, UOp] = {}
for i, (first, rewritten, _) in enumerate(rewrites):
# first, rewrite this UOp with the current rewrite + all the seen rewrites before this
seen_replaces[first.key] = rewritten
seen_replaces[first] = rewritten
new_sink = replace_uop(uops[-1], {**seen_replaces})
# sanity check
assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}"
Expand All @@ -41,15 +41,15 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]:
graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff"))
return graph

def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp:
if (found:=replaces.get(base.key)) is not None: return found
new_srcs = tuple(replace_uop(x, replaces) for x in base.src)
replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base
def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp:
if (found:=replaces.get(base)) is not None: return found
replaces[base] = ret = UOp(base.op, base.dtype, tuple(replace_uop(x, replaces) for x in base.src), base.arg)
return ret

def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]:
kernels = defaultdict(list)
for ctx in contexts:
if ctx.sink.op is UOps.CONST: continue
name = to_function_name(ctx.kernel.name) if ctx.kernel is not None else None
upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites]
kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx))
Expand Down

0 comments on commit b066ef2

Please sign in to comment.