Add more torch hg model tests (nod-ai#238)
AmosLewis authored Aug 4, 2022
1 parent 934f15e commit 90fddc6
Showing 2 changed files with 110 additions and 0 deletions.
109 changes: 109 additions & 0 deletions tank/bert-base-cased_torch/bert-base-cased_torch_test.py
@@ -0,0 +1,109 @@
from shark.shark_inference import SharkInference
from shark.iree_utils._common import check_device_drivers, device_driver_info
from tank.model_utils import compare_tensors
from shark.shark_downloader import download_torch_model

import torch
import unittest
import numpy as np
import pytest


class BertBaseCasedModuleTester:
    def __init__(
        self,
        save_mlir=False,
        save_vmfb=False,
        benchmark=False,
    ):
        self.save_mlir = save_mlir
        self.save_vmfb = save_vmfb
        self.benchmark = benchmark

    def create_and_check_module(self, dynamic, device):
        # Fetch the pre-imported MLIR module, a sample input, and the
        # reference output for bert-base-cased from the SHARK tank.
        model_mlir, func_name, input, act_out = download_torch_model(
            "bert-base-cased", dynamic
        )

        # from shark.shark_importer import SharkImporter
        # mlir_importer = SharkImporter(
        #     model,
        #     (input,),
        #     frontend="torch",
        # )
        # minilm_mlir, func_name = mlir_importer.import_mlir(
        #     is_dynamic=dynamic, tracing_required=True
        # )

        # Compile the linalg-dialect module for the target device and check
        # the SHARK result against the reference activation.
        shark_module = SharkInference(
            model_mlir,
            func_name,
            device=device,
            mlir_dialect="linalg",
            is_benchmark=self.benchmark,
        )
        shark_module.compile()
        results = shark_module.forward(input)
        assert compare_tensors(act_out, results)

        if self.benchmark:
            shark_module.shark_runner.benchmark_all_csv(
                (input),
                "bert-base-cased",
                dynamic,
                device,
                "torch",
            )


class BertBaseCasedModuleTest(unittest.TestCase):
    @pytest.fixture(autouse=True)
    def configure(self, pytestconfig):
        self.module_tester = BertBaseCasedModuleTester()
        self.module_tester.benchmark = pytestconfig.getoption("benchmark")

    def test_module_static_cpu(self):
        dynamic = False
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device)

    def test_module_dynamic_cpu(self):
        dynamic = True
        device = "cpu"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_static_gpu(self):
        dynamic = False
        device = "gpu"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("gpu"), reason=device_driver_info("gpu")
    )
    def test_module_dynamic_gpu(self):
        dynamic = True
        device = "gpu"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_static_vulkan(self):
        dynamic = False
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device)

    @pytest.mark.skipif(
        check_device_drivers("vulkan"), reason=device_driver_info("vulkan")
    )
    def test_module_dynamic_vulkan(self):
        dynamic = True
        device = "vulkan"
        self.module_tester.create_and_check_module(dynamic, device)


if __name__ == "__main__":
    unittest.main()
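
For context, the compile-and-run flow these tests exercise can also be driven outside pytest with the same SHARK calls imported in this file. The snippet below is an illustrative sketch (static shapes on CPU, mirroring test_module_static_cpu), not part of the commit.

from shark.shark_inference import SharkInference
from shark.shark_downloader import download_torch_model
from tank.model_utils import compare_tensors

# Illustrative only: fetch the pre-imported bert-base-cased module from the
# SHARK tank and run it once on CPU, then compare against the reference output.
model_mlir, func_name, inp, act_out = download_torch_model("bert-base-cased", False)
shark_module = SharkInference(
    model_mlir, func_name, device="cpu", mlir_dialect="linalg", is_benchmark=False
)
shark_module.compile()
results = shark_module.forward(inp)
print("matches reference:", compare_tensors(act_out, results))

The benchmark path in the tests is toggled through the pytest option read by pytestconfig.getoption("benchmark"); the exact command-line flag is registered in the project's conftest and is not shown in this diff.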
1 change: 1 addition & 0 deletions tank/pytorch/torch_model_list.csv
@@ -2,6 +2,7 @@ model_name, use_tracing, model_type
microsoft/MiniLM-L12-H384-uncased,True,hf
albert-base-v2,True,hf
bert-base-uncased,True,hf
bert-base-cased,True,hf
google/mobilebert-uncased,True,hf
alexnet,False,vision
resnet18,False,vision
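The new CSV row registers bert-base-cased as a traced HuggingFace model alongside the existing entries. As a small illustration of the layout (model_name, use_tracing, model_type), the helper below parses the list with the standard csv module; the function name and filtering are hypothetical and not part of SHARK.

import csv

def load_hf_models(path="tank/pytorch/torch_model_list.csv"):
    # Hypothetical helper: return (model_name, use_tracing) for every
    # HuggingFace ("hf") entry, including the newly added bert-base-cased row.
    with open(path, newline="") as f:
        reader = csv.DictReader(f, skipinitialspace=True)
        return [
            (row["model_name"], row["use_tracing"] == "True")
            for row in reader
            if row["model_type"] == "hf"
        ]

print(load_hf_models())  # e.g. [..., ("bert-base-uncased", True), ("bert-base-cased", True), ...]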
