From 434fa2000a211c9958d570b6369df3b41d93a97a Mon Sep 17 00:00:00 2001 From: Hao Wu Date: Thu, 29 Aug 2024 21:00:22 -0400 Subject: [PATCH] Fix rewrite bug (#32) --- triton_viz/interpreter.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/triton_viz/interpreter.py b/triton_viz/interpreter.py index f447bbf..a572e1e 100644 --- a/triton_viz/interpreter.py +++ b/triton_viz/interpreter.py @@ -343,8 +343,6 @@ def wrapper(input, axis=None, keep_dims=False): def patch(): old_grid_executor_call = GridExecutor.__call__ old_jit_function_call = JITFunction.__call__ - # XXX(Keren): Temporarily disable rewriting of AST - old_rewrite_ast = InterpretedFunction._rewrite_ast old_create_make_range = interpreter_builder.create_make_range old_create_masked_load = interpreter_builder.create_masked_load old_create_expand_dims = interpreter_builder.create_expand_dims @@ -373,7 +371,6 @@ def patch(): finally: GridExecutor.__call__ = old_grid_executor_call JITFunction.__call__ = old_jit_function_call - InterpretedFunction._rewrite_ast = old_rewrite_ast interpreter_builder.create_make_range = old_create_make_range interpreter_builder.create_masked_load = old_create_masked_load interpreter_builder.create_expand_dims = old_create_expand_dims