@@ -496,45 +496,35 @@ def f(x: Tensor):
496496
497497
498498def test_class_irmodule ():
499- # FIXME(@altanh): Python class method decorators are executed eagerly before the class
500- # decorator, which means each function is parsed in isolation. This means we cannot resolve
501- # global variables at parsing time (or indeed any undefined identifier), so we either need to
502- # 1. defer parsing in the function decorators (so that the ir_module decorator can populate
503- # global variables first), although this means non-IRModule uses of the function decorators
504- # will no longer return Function/PrimFunc but some kind of wrapper type. This could cause
505- # problems if we pass them directly to things that expect Function/PrimFuncs.
506- # 2. parse every undefined identifier to a placeholder node (e.g. "UndefinedVar"), and run an
507- # IRModule -> IRModule pass that tries to resolve identifiers.
508- src = """@tvm.script.ir_module
509- class MyModule:
510- @T.prim_func
511- def my_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
512- A = T.match_buffer(a, (128, 128))
513- B = T.match_buffer(b, (128, 128))
514- C = T.match_buffer(c, (128, 128))
515-
516- for i, j, k in T.grid(128, 128, 128):
517- with T.block():
518- vi, vj, vk = T.axis.remap("SSR", [i, j, k])
519- with T.init():
520- C[vi, vj] = 0.0
521- C[vi, vj] += A[vi, vk] * B[vj, vk]
499+ @tvm .script .ir_module
500+ class MyModule :
501+ @T .prim_func
502+ def my_matmul (a : T .handle , b : T .handle , c : T .handle ) -> None :
503+ A = T .match_buffer (a , (128 , 128 ))
504+ B = T .match_buffer (b , (128 , 128 ))
505+ C = T .match_buffer (c , (128 , 128 ))
522506
523- @R.function
524- def f(x: Tensor[(n, n), _]) -> Tensor:
525- return g(x)
507+ for i , j , k in T .grid (128 , 128 , 128 ):
508+ with T .block ():
509+ vi , vj , vk = T .axis .remap ("SSR" , [i , j , k ])
510+ with T .init ():
511+ C [vi , vj ] = 0.0
512+ C [vi , vj ] += A [vi , vk ] * B [vj , vk ]
526513
527- @R.function
528- def g(y : Tensor[(n, n), _]) -> Tensor:
529- return relax.call_dps((n, n), my_matmul, (y, y) )
514+ @R .function
515+ def f ( x : Tensor [(n , n ), _ ]) -> Tensor :
516+ return g ( x )
530517
531- @R.function
532- def h(x, y, z):
533- _ = my_matmul(x, y, z)
534- return z
535- """
518+ @R .function
519+ def g (y : Tensor [(n , n ), _ ]) -> Tensor :
520+ return relax .call_dps ((n , n ), my_matmul , (y , y ))
521+
522+ @R .function
523+ def h (x , y , z ):
524+ _ = my_matmul (x , y , z )
525+ return z
536526
537- my_module = tvm . script . relax . parser . from_source ( src )
527+ my_module = MyModule
538528 assert isinstance (my_module , tvm .IRModule )
539529
540530 var_f = my_module .get_global_var ("f" )
0 commit comments