File tree Expand file tree Collapse file tree 2 files changed +8
-4
lines changed
Expand file tree Collapse file tree 2 files changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -35,7 +35,8 @@ def get_example_inputs(self):
3535class SimpleSingleBatchLhsConstBmm (TestModuleBase ):
3636 def __init__ (self ):
3737 super ().__init__ ()
38- self .const_lhs = torch .randn (1 , 4 , 5 )
38+ const_lhs = torch .randn (1 , 4 , 5 )
39+ self .register_buffer ("const_lhs" , const_lhs )
3940
4041 def forward (self , rhs ):
4142 z = torch .bmm (self .const_lhs , rhs )
Original file line number Diff line number Diff line change @@ -35,7 +35,8 @@ def get_example_inputs(self):
3535class SimpleMatmulConstRhs (TestModuleBase ):
3636 def __init__ (self ):
3737 super ().__init__ ()
38- self .weight = torch .randn (4 , 5 )
38+ weight = torch .randn (4 , 5 )
39+ self .register_buffer ("weight" , weight )
3940
4041 def forward (self , lhs ):
4142 out = torch .mm (lhs , self .weight )
@@ -49,7 +50,8 @@ def get_example_inputs(self):
4950class SimpleMatmulConstRhsOnert (TestModuleBase ):
5051 def __init__ (self ):
5152 super ().__init__ ()
52- self .weight = torch .randn (4 , 5 )
53+ weight = torch .randn (4 , 5 )
54+ self .register_buffer ("weight" , weight )
5355
5456 def forward (self , lhs ):
5557 out = torch .mm (lhs , self .weight )
@@ -80,7 +82,8 @@ def get_example_inputs(self):
8082class SimpleMatmulConstLhsOnertWithLinearConversion (TestModuleBase ):
8183 def __init__ (self ):
8284 super ().__init__ ()
83- self .weight = torch .randn (3 , 4 )
85+ weight = torch .randn (3 , 4 )
86+ self .register_buffer ("weight" , weight )
8487
8588 def forward (self , rhs ):
8689 out = torch .mm (self .weight , rhs )
You can’t perform that action at this time.
0 commit comments