From b066ef2282d848c3dc407d1e33433d7a14691b42 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 6 Oct 2024 12:00:55 +0300 Subject: [PATCH] small changes from the viz_rewrite branch [pr] (#6907) * simpler replace * dont show shapetracker consts * changed_nodes shouldn't exist for the first sink --- viz/index.html | 2 +- viz/serve.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/viz/index.html b/viz/index.html index ca2e45d37574..88a301730088 100644 --- a/viz/index.html +++ b/viz/index.html @@ -269,7 +269,7 @@ ret = await (await fetch(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`)).json(); cache[cacheKey] = ret; } - renderGraph(ret.graphs[currentRewrite], ret.changed_nodes[currentRewrite]); + renderGraph(ret.graphs[currentRewrite], currentRewrite == 0 ? [] : ret.changed_nodes[currentRewrite-1]); // ***** RHS metadata const metadata = document.querySelector(".container.metadata"); metadata.innerHTML = ""; diff --git a/viz/serve.py b/viz/serve.py index 616d5e34339a..b325959a2dd0 100755 --- a/viz/serve.py +++ b/viz/serve.py @@ -13,11 +13,11 @@ def reconstruct_graph(sink:UOp, rewrites:List[Tuple[UOp, UOp, UPat]]) -> Tuple[List[UOp], List[List[str]], List[List[int]]]: uops: List[UOp] = [sink] diffs: List[List[str]] = [] - changed_nodes: List[List[int]] = [[]] - seen_replaces: Dict[bytes, UOp] = {} + changed_nodes: List[List[int]] = [] + seen_replaces: Dict[UOp, UOp] = {} for i, (first, rewritten, _) in enumerate(rewrites): # first, rewrite this UOp with the current rewrite + all the seen rewrites before this - seen_replaces[first.key] = rewritten + seen_replaces[first] = rewritten new_sink = replace_uop(uops[-1], {**seen_replaces}) # sanity check assert new_sink is not uops[-1], f"rewritten sink wasn't rewritten! {i}\n{new_sink}\n{uops[-1]}" @@ -41,15 +41,15 @@ def uop_to_json(x:UOp) -> Dict[int, Tuple[str, str, List[int], str, str]]: graph[id(u)] = (label, str(u.dtype), [id(x) for x in u.src if x.op is not UOps.CONST], str(u.arg), uops_colors.get(u.op, "#ffffff")) return graph -def replace_uop(base:UOp, replaces:Dict[bytes, UOp]) -> UOp: - if (found:=replaces.get(base.key)) is not None: return found - new_srcs = tuple(replace_uop(x, replaces) for x in base.src) - replaces[base.key] = ret = UOp(base.op, base.dtype, new_srcs, base.arg) if new_srcs != base.src else base +def replace_uop(base:UOp, replaces:Dict[UOp, UOp]) -> UOp: + if (found:=replaces.get(base)) is not None: return found + replaces[base] = ret = UOp(base.op, base.dtype, tuple(replace_uop(x, replaces) for x in base.src), base.arg) return ret def load_kernels(contexts) -> DefaultDict[str, List[Tuple[GraphRewriteMetadata, TrackedRewriteContext]]]: kernels = defaultdict(list) for ctx in contexts: + if ctx.sink.op is UOps.CONST: continue name = to_function_name(ctx.kernel.name) if ctx.kernel is not None else None upats = [(upat.location, upat.printable()) for _,_,upat in ctx.rewrites] kernels[name].append((GraphRewriteMetadata(ctx.loc, lines(ctx.loc[0])[ctx.loc[1]-1].strip(), name, upats), ctx))