Skip to content

Commit b941ab2

Browse files
authored
[test] Removed warning in test modules (#411)
1 parent 021953e commit b941ab2

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

test/modules/op/bmm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def get_example_inputs(self):
3535
class SimpleSingleBatchLhsConstBmm(TestModuleBase):
3636
def __init__(self):
3737
super().__init__()
38-
self.const_lhs = torch.randn(1, 4, 5)
38+
const_lhs = torch.randn(1, 4, 5)
39+
self.register_buffer("const_lhs", const_lhs)
3940

4041
def forward(self, rhs):
4142
z = torch.bmm(self.const_lhs, rhs)

test/modules/op/mm.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def get_example_inputs(self):
3535
class SimpleMatmulConstRhs(TestModuleBase):
3636
def __init__(self):
3737
super().__init__()
38-
self.weight = torch.randn(4, 5)
38+
weight = torch.randn(4, 5)
39+
self.register_buffer("weight", weight)
3940

4041
def forward(self, lhs):
4142
out = torch.mm(lhs, self.weight)
@@ -49,7 +50,8 @@ def get_example_inputs(self):
4950
class SimpleMatmulConstRhsOnert(TestModuleBase):
5051
def __init__(self):
5152
super().__init__()
52-
self.weight = torch.randn(4, 5)
53+
weight = torch.randn(4, 5)
54+
self.register_buffer("weight", weight)
5355

5456
def forward(self, lhs):
5557
out = torch.mm(lhs, self.weight)
@@ -80,7 +82,8 @@ def get_example_inputs(self):
8082
class SimpleMatmulConstLhsOnertWithLinearConversion(TestModuleBase):
8183
def __init__(self):
8284
super().__init__()
83-
self.weight = torch.randn(3, 4)
85+
weight = torch.randn(3, 4)
86+
self.register_buffer("weight", weight)
8487

8588
def forward(self, rhs):
8689
out = torch.mm(self.weight, rhs)

0 commit comments

Comments
 (0)