Skip to content

Commit

Permalink
[TVMC] Add test for quantized pytorch model
Browse files Browse the repository at this point in the history
As a follow up to apache#9417 and now that apache#9362 is resolved, this PR adds a
test to check quantized pytorch mobilenetv2 is converted correctly.

Change-Id: Iaf2d38ce71c008e0141a4a2536bd54c2c9f3fe3d
  • Loading branch information
lhutton1 committed Nov 8, 2021
1 parent 811312c commit 843e738
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/python/driver/tvmc/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,23 @@ def pytorch_resnet18(tmpdir_factory):
return model_file_name


@pytest.fixture(scope="session")
def pytorch_mobilenetv2_quantized(tmpdir_factory):
try:
import torch
import torchvision.models as models
except ImportError:
# Not all environments provide Pytorch, so skip if that's the case.
return ""
model = models.quantization.mobilenet_v2(quantize=True)
model_file_name = "{}/{}".format(tmpdir_factory.mktemp("data"), "mobilenet_v2_quantized.pth")
# Trace model into torchscript.
traced_cpu = torch.jit.trace(model, torch.randn(1, 3, 224, 224))
torch.jit.save(traced_cpu, model_file_name)

return model_file_name


@pytest.fixture(scope="session")
def onnx_resnet50():
base_url = "https://github.com/onnx/models/raw/master/vision/classification/resnet/model"
Expand Down
15 changes: 15 additions & 0 deletions tests/python/driver/tvmc/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,21 @@ def test_load_model__pth(pytorch_resnet18):
assert "layer1.0.conv1.weight" in tvmc_model.params.keys()


def test_load_quantized_model__pth(pytorch_mobilenetv2_quantized):
# some CI environments wont offer torch, so skip in case it is not present
pytest.importorskip("torch")
pytest.importorskip("torchvision")

tvmc_model = tvmc.load(pytorch_mobilenetv2_quantized, shape_dict={"input": [1, 3, 224, 224]})
assert type(tvmc_model) is TVMCModel
assert type(tvmc_model.mod) is IRModule
assert type(tvmc_model.params) is dict

# checking weights remain quantized and are not float32
for p in tvmc_model.params.values():
assert p.dtype in ["int8", "uint8", "int32"] # int32 for bias


def test_load_model___wrong_language__to_pytorch(tflite_mobilenet_v1_1_quant):
# some CI environments wont offer pytorch, so skip in case it is not present
pytest.importorskip("torch")
Expand Down

0 comments on commit 843e738

Please sign in to comment.