Skip to content

Commit 110fd07

Browse files
mcr229facebook-github-bot
authored andcommitted
Dynamic Shapes (#2442)
Summary: Pull Request resolved: #2442 Only need to look at tester.py file for the tester changes. Change is from `.run_method().compare_outputs() ` to `.run_method_and_compare_outputs()` now if Tester is initialized with dynamic inputs, we will generate random dynamic inputs (according to the specification of the dynamic shapes) to run on the model. This allows us to test that the inputs fed into the model can be dynamic. We ad a num_runs to run_method_and_compare_outputs so that we can choose to run a number of different dynamic inputs with dynamic shapes. Reviewed By: digantdesai, kirklandsign Differential Revision: D54650121
1 parent 62a4dd3 commit 110fd07

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+298
-319
lines changed

backends/xnnpack/test/models/deeplab_v3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,5 @@ def test_fp32_dl3(self):
3636
.partition()
3737
.to_executorch()
3838
.serialize()
39-
.run_method()
40-
.compare_outputs()
39+
.run_method_and_compare_outputs()
4140
)

backends/xnnpack/test/models/edsr.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@ def test_fp32_edsr(self):
2525
.partition()
2626
.to_executorch()
2727
.serialize()
28-
.run_method()
29-
.compare_outputs()
28+
.run_method_and_compare_outputs()
3029
)
3130

3231
def test_qs8_edsr(self):
@@ -38,6 +37,5 @@ def test_qs8_edsr(self):
3837
.partition()
3938
.to_executorch()
4039
.serialize()
41-
.run_method()
42-
.compare_outputs()
40+
.run_method_and_compare_outputs()
4341
)

backends/xnnpack/test/models/emformer_rnnt.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ def __init__(self):
2121
self.rnnt = decoder.model
2222

2323
class Joiner(EmformerRnnt):
24-
def forward(self, predict_inputs):
25-
return self.rnnt.join(*predict_inputs)
24+
def forward(self, a, b, c, d):
25+
return self.rnnt.join(a, b, c, d)
2626

2727
def get_example_inputs(self):
2828
join_inputs = (
@@ -31,7 +31,7 @@ def get_example_inputs(self):
3131
torch.rand([1, 128, 1024]),
3232
torch.tensor([128]),
3333
)
34-
return (join_inputs,)
34+
return join_inputs
3535

3636
def test_fp32_emformer_joiner(self):
3737
joiner = self.Joiner()
@@ -43,21 +43,19 @@ def test_fp32_emformer_joiner(self):
4343
.check(["torch.ops.higher_order.executorch_call_delegate"])
4444
.to_executorch()
4545
.serialize()
46-
.run_method()
47-
.compare_outputs()
46+
.run_method_and_compare_outputs()
4847
)
4948

5049
class Predictor(EmformerRnnt):
51-
def forward(self, predict_inputs):
52-
return self.rnnt.predict(*predict_inputs)
50+
def forward(self, a, b):
51+
return self.rnnt.predict(a, b, None)
5352

5453
def get_example_inputs(self):
5554
predict_inputs = (
5655
torch.zeros([1, 128], dtype=int),
5756
torch.tensor([128], dtype=int),
58-
None,
5957
)
60-
return (predict_inputs,)
58+
return predict_inputs
6159

6260
@unittest.skip("T183426271")
6361
def test_fp32_emformer_predictor(self):
@@ -70,20 +68,19 @@ def test_fp32_emformer_predictor(self):
7068
.check(["torch.ops.higher_order.executorch_call_delegate"])
7169
.to_executorch()
7270
.serialize()
73-
.run_method()
74-
.compare_outputs()
71+
.run_method_and_compare_outputs()
7572
)
7673

7774
class Transcriber(EmformerRnnt):
78-
def forward(self, predict_inputs):
79-
return self.rnnt.transcribe(*predict_inputs)
75+
def forward(self, a, b):
76+
return self.rnnt.transcribe(a, b)
8077

8178
def get_example_inputs(self):
8279
transcribe_inputs = (
8380
torch.randn(1, 128, 80),
8481
torch.tensor([128]),
8582
)
86-
return (transcribe_inputs,)
83+
return transcribe_inputs
8784

8885
def test_fp32_emformer_transcriber(self):
8986
transcriber = self.Transcriber()
@@ -95,6 +92,5 @@ def test_fp32_emformer_transcriber(self):
9592
.check(["torch.ops.higher_order.executorch_call_delegate"])
9693
.to_executorch()
9794
.serialize()
98-
.run_method()
99-
.compare_outputs()
95+
.run_method_and_compare_outputs()
10096
)

backends/xnnpack/test/models/inception_v3.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def test_fp32_ic3(self):
4242
.check_not(list(self.all_operators))
4343
.to_executorch()
4444
.serialize()
45-
.run_method()
46-
.compare_outputs()
45+
.run_method_and_compare_outputs()
4746
)
4847

4948
def test_qs8_ic3(self):
@@ -63,6 +62,5 @@ def test_qs8_ic3(self):
6362
.check_not(list(ops_after_quantization))
6463
.to_executorch()
6564
.serialize()
66-
.run_method()
67-
.compare_outputs()
65+
.run_method_and_compare_outputs()
6866
)

backends/xnnpack/test/models/inception_v4.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def test_fp32_ic4(self):
3939
.check_not(list(self.all_operators))
4040
.to_executorch()
4141
.serialize()
42-
.run_method()
43-
.compare_outputs()
42+
.run_method_and_compare_outputs()
4443
)
4544

4645
def test_qs8_ic4(self):
@@ -60,6 +59,5 @@ def test_qs8_ic4(self):
6059
.check_not(list(ops_after_quantization))
6160
.to_executorch()
6261
.serialize()
63-
.run_method()
64-
.compare_outputs()
62+
.run_method_and_compare_outputs()
6563
)

backends/xnnpack/test/models/llama2_et_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,5 @@ def _test(self, dtype: torch.dtype = torch.float):
4545
.dump_artifact()
4646
.to_executorch()
4747
.serialize()
48-
.run_method()
49-
.compare_outputs(atol=5e-2)
48+
.run_method_and_compare_outputs(atol=5e-2)
5049
)

backends/xnnpack/test/models/mobilebert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,5 @@ def test_fp32_mobilebert(self):
3838
.check_not(list(self.supported_ops))
3939
.to_executorch()
4040
.serialize()
41-
.run_method()
42-
.compare_outputs()
41+
.run_method_and_compare_outputs()
4342
)

backends/xnnpack/test/models/mobilenet_v2.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def test_fp32_mv2(self):
4040
.check_not(list(self.all_operators))
4141
.to_executorch()
4242
.serialize()
43-
.run_method()
44-
.compare_outputs()
43+
.run_method_and_compare_outputs()
4544
)
4645

4746
def test_qs8_mv2(self):
@@ -61,6 +60,5 @@ def test_qs8_mv2(self):
6160
.check_not(list(ops_after_quantization))
6261
.to_executorch()
6362
.serialize()
64-
.run_method()
65-
.compare_outputs()
63+
.run_method_and_compare_outputs()
6664
)

backends/xnnpack/test/models/mobilenet_v3.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def test_fp32_mv3(self):
4242
.check_not(list(self.all_operators))
4343
.to_executorch()
4444
.serialize()
45-
.run_method()
46-
.compare_outputs()
45+
.run_method_and_compare_outputs()
4746
)
4847

4948
def test_qs8_mv3(self):
@@ -63,6 +62,5 @@ def test_qs8_mv3(self):
6362
.check_not(list(ops_after_lowering))
6463
.to_executorch()
6564
.serialize()
66-
.run_method()
67-
.compare_outputs()
65+
.run_method_and_compare_outputs()
6866
)

backends/xnnpack/test/models/resnet.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ def test_fp32_resnet18(self):
2323
.partition()
2424
.to_executorch()
2525
.serialize()
26-
.run_method()
27-
.compare_outputs()
26+
.run_method_and_compare_outputs()
2827
)
2928

3029
def test_qs8_resnet18(self):
@@ -37,6 +36,5 @@ def test_qs8_resnet18(self):
3736
.partition()
3837
.to_executorch()
3938
.serialize()
40-
.run_method()
41-
.compare_outputs()
39+
.run_method_and_compare_outputs()
4240
)

0 commit comments

Comments
 (0)