diff --git a/python/triton/code_gen.py b/python/triton/code_gen.py index c0743c268b27..ae69f876c3bd 100644 --- a/python/triton/code_gen.py +++ b/python/triton/code_gen.py @@ -659,9 +659,12 @@ def parse(self): def __call__(self, *args, generator: CodeGenerator, **meta): try: + gscope = generator.gscope.copy() lscope = generator.lscope.copy() values = generator.module.get_values().copy() + generator.gscope = sys.modules[self.fn.__module__].__dict__ ret = generator.visit_FunctionDef(self.parse().body[0], inline=True, arg_values=args) + generator.gscope = gscope generator.lscope = lscope generator.module.set_values(values) return ret