Skip to content

Dynamic ViT #2476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions backends/xnnpack/test/models/deeplab_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,5 @@ def test_fp32_dl3(self):
.partition()
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)
6 changes: 2 additions & 4 deletions backends/xnnpack/test/models/edsr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ def test_fp32_edsr(self):
.partition()
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)

def test_qs8_edsr(self):
Expand All @@ -38,6 +37,5 @@ def test_qs8_edsr(self):
.partition()
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)
28 changes: 12 additions & 16 deletions backends/xnnpack/test/models/emformer_rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def __init__(self):
self.rnnt = decoder.model

class Joiner(EmformerRnnt):
def forward(self, predict_inputs):
return self.rnnt.join(*predict_inputs)
def forward(self, a, b, c, d):
return self.rnnt.join(a, b, c, d)

def get_example_inputs(self):
join_inputs = (
Expand All @@ -31,7 +31,7 @@ def get_example_inputs(self):
torch.rand([1, 128, 1024]),
torch.tensor([128]),
)
return (join_inputs,)
return join_inputs

def test_fp32_emformer_joiner(self):
joiner = self.Joiner()
Expand All @@ -43,21 +43,19 @@ def test_fp32_emformer_joiner(self):
.check(["torch.ops.higher_order.executorch_call_delegate"])
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)

class Predictor(EmformerRnnt):
def forward(self, predict_inputs):
return self.rnnt.predict(*predict_inputs)
def forward(self, a, b):
return self.rnnt.predict(a, b, None)

def get_example_inputs(self):
predict_inputs = (
torch.zeros([1, 128], dtype=int),
torch.tensor([128], dtype=int),
None,
)
return (predict_inputs,)
return predict_inputs

@unittest.skip("T183426271")
def test_fp32_emformer_predictor(self):
Expand All @@ -70,20 +68,19 @@ def test_fp32_emformer_predictor(self):
.check(["torch.ops.higher_order.executorch_call_delegate"])
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)

class Transcriber(EmformerRnnt):
def forward(self, predict_inputs):
return self.rnnt.transcribe(*predict_inputs)
def forward(self, a, b):
return self.rnnt.transcribe(a, b)

def get_example_inputs(self):
transcribe_inputs = (
torch.randn(1, 128, 80),
torch.tensor([128]),
)
return (transcribe_inputs,)
return transcribe_inputs

def test_fp32_emformer_transcriber(self):
transcriber = self.Transcriber()
Expand All @@ -95,6 +92,5 @@ def test_fp32_emformer_transcriber(self):
.check(["torch.ops.higher_order.executorch_call_delegate"])
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)
6 changes: 2 additions & 4 deletions backends/xnnpack/test/models/inception_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def test_fp32_ic3(self):
.check_not(list(self.all_operators))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)

def test_qs8_ic3(self):
Expand All @@ -63,6 +62,5 @@ def test_qs8_ic3(self):
.check_not(list(ops_after_quantization))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)
6 changes: 2 additions & 4 deletions backends/xnnpack/test/models/inception_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ def test_fp32_ic4(self):
.check_not(list(self.all_operators))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)

def test_qs8_ic4(self):
Expand All @@ -60,6 +59,5 @@ def test_qs8_ic4(self):
.check_not(list(ops_after_quantization))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)
3 changes: 1 addition & 2 deletions backends/xnnpack/test/models/llama2_et_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,5 @@ def _test(self, dtype: torch.dtype = torch.float):
.dump_artifact()
.to_executorch()
.serialize()
.run_method()
.compare_outputs(atol=5e-2)
.run_method_and_compare_outputs(atol=5e-2)
)
3 changes: 1 addition & 2 deletions backends/xnnpack/test/models/mobilebert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,5 @@ def test_fp32_mobilebert(self):
.check_not(list(self.supported_ops))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)
23 changes: 17 additions & 6 deletions backends/xnnpack/test/models/mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,15 @@ class TestMobileNetV2(unittest.TestCase):
}

def test_fp32_mv2(self):
dynamic_shapes = (
{
2: torch.export.Dim("height", min=224, max=455),
3: torch.export.Dim("width", min=224, max=455),
},
)

(
Tester(self.mv2, self.model_inputs)
Tester(self.mv2, self.model_inputs, dynamic_shapes=dynamic_shapes)
.export()
.to_edge()
.check(list(self.all_operators))
Expand All @@ -40,8 +46,7 @@ def test_fp32_mv2(self):
.check_not(list(self.all_operators))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs(num_runs=10)
)

def test_qs8_mv2(self):
Expand All @@ -50,8 +55,15 @@ def test_qs8_mv2(self):
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
}

dynamic_shapes = (
{
2: torch.export.Dim("height", min=224, max=455),
3: torch.export.Dim("width", min=224, max=455),
},
)

(
Tester(self.mv2, self.model_inputs)
Tester(self.mv2, self.model_inputs, dynamic_shapes=dynamic_shapes)
.quantize(Quantize(calibrate=False))
.export()
.to_edge()
Expand All @@ -61,6 +73,5 @@ def test_qs8_mv2(self):
.check_not(list(ops_after_quantization))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs(num_runs=10)
)
16 changes: 10 additions & 6 deletions backends/xnnpack/test/models/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ class TestMobileNetV3(unittest.TestCase):
mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True)
mv3 = mv3.eval()
model_inputs = (torch.ones(1, 3, 224, 224),)
dynamic_shapes = (
{
2: torch.export.Dim("height", min=224, max=455),
3: torch.export.Dim("width", min=224, max=455),
},
)

all_operators = {
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
Expand All @@ -33,7 +39,7 @@ class TestMobileNetV3(unittest.TestCase):

def test_fp32_mv3(self):
(
Tester(self.mv3, self.model_inputs)
Tester(self.mv3, self.model_inputs, dynamic_shapes=self.dynamic_shapes)
.export()
.to_edge()
.check(list(self.all_operators))
Expand All @@ -42,8 +48,7 @@ def test_fp32_mv3(self):
.check_not(list(self.all_operators))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs(num_runs=5)
)

def test_qs8_mv3(self):
Expand All @@ -53,7 +58,7 @@ def test_qs8_mv3(self):
ops_after_lowering = self.all_operators

(
Tester(self.mv3, self.model_inputs)
Tester(self.mv3, self.model_inputs, dynamic_shapes=self.dynamic_shapes)
.quantize(Quantize(calibrate=False))
.export()
.to_edge()
Expand All @@ -63,6 +68,5 @@ def test_qs8_mv3(self):
.check_not(list(ops_after_lowering))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs(num_runs=5)
)
68 changes: 51 additions & 17 deletions backends/xnnpack/test/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,63 @@


class TestResNet18(unittest.TestCase):
def test_fp32_resnet18(self):
inputs = (torch.ones(1, 3, 224, 224),)
inputs = (torch.ones(1, 3, 224, 224),)
dynamic_shapes = (
{
2: torch.export.Dim("height", min=224, max=455),
3: torch.export.Dim("width", min=224, max=455),
},
)

class DynamicResNet(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = torchvision.models.resnet18()

def forward(self, x):
x = torch.nn.functional.interpolate(
x,
size=(224, 224),
mode="bilinear",
align_corners=True,
antialias=False,
)
return self.model(x)

def _test_exported_resnet(self, tester):
(
Tester(torchvision.models.resnet18(), inputs)
.export()
tester.export()
.to_edge()
.partition()
.check_not(
[
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten_mean_dim",
]
)
.check(["torch.ops.higher_order.executorch_call_delegate"])
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
.run_method_and_compare_outputs()
)

def test_fp32_resnet18(self):
self._test_exported_resnet(Tester(torchvision.models.resnet18(), self.inputs))

def test_qs8_resnet18(self):
inputs = (torch.ones(1, 3, 224, 224),)
(
Tester(torchvision.models.resnet18(), inputs)
.quantize(Quantize(calibrate=False))
.export()
.to_edge()
.partition()
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize(
Quantize(calibrate=False)
)
self._test_exported_resnet(quantized_tester)

def test_fp32_resnet18_dynamic(self):
self._test_exported_resnet(
Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes)
)

def test_qs8_resnet18_dynamic(self):
self._test_exported_resnet(
Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize(
Quantize(calibrate=False)
)
)
Loading