Skip to content

Commit

Permalink
Merge pull request huggingface#164 from AmosLewis/move
Browse files Browse the repository at this point in the history
Rewrite&Move tflite examples to up tank dir
  • Loading branch information
AmosLewis authored Jun 30, 2022
2 parents 8199ea1 + c1cde2e commit babd3d0
Show file tree
Hide file tree
Showing 44 changed files with 1,566 additions and 584 deletions.
47 changes: 28 additions & 19 deletions shark/shark_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ def __init__(
print("Error. No tank_url, No model name,Please input either one.")
return

self.workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
os.makedirs(self.workdir, exist_ok=True)
print(f"TMP_MODEL_DIR = {self.workdir}")
# use model name get dir.
self.model_name_dir = os.path.join(self.workdir, str(self.model_name))
if not os.path.exists(self.model_name_dir):
print(
"Model has not been download."
"shark_downloader will automatically download by "
"tank_url if provided. You can also manually to "
"download the model from shark_tank by yourself."
)
os.makedirs(self.model_name_dir, exist_ok=True)
print(f"TMP_MODELNAME_DIR = {self.model_name_dir}")

# read inputs from json file
self.load_json_input()
# get milr model file
Expand All @@ -66,42 +81,36 @@ def get_inputs(self):
def load_json_input(self):
print("load json inputs")
if self.model_type in ["tflite-tosa"]:
input_url = self.tank_url + "/" + str(self.model_name) + "/" + "input.json"
input_file = "/".join([self.model_name_dir, str(self.input_json)])
if os.path.exists(input_file):
print("Input has been downloaded before.", input_file)
else:
print("Download input", input_url)
urllib.request.urlretrieve(input_url, input_file)

args = []
with open(self.input_json, "r") as f:
with open(input_file, "r") as f:
args = json.load(f)
self.inputs = [np.asarray(arg, dtype=self.input_type) for arg in args]
else:
print("No json input required for current model type. You could call setup_inputs(YOU_INPUTS).")
print("No json input required for current model type. " "You could call setup_inputs(YOU_INPUTS).")
return self.inputs

def load_mlir_model(self):
workdir = os.path.join(os.path.dirname(__file__), self.local_tank_dir)
os.makedirs(workdir, exist_ok=True)
print(f"TMP_MODEL_DIR = {workdir}")
# use model name get dir.
model_name_dir = os.path.join(workdir, str(self.model_name))
if not os.path.exists(model_name_dir):
print(
"Model has not been download."
"shark_downloader will automatically download by tank_url if provided."
" You can also manually to download the model from shark_tank by yourself."
)
os.makedirs(model_name_dir, exist_ok=True)
print(f"TMP_MODELNAME_DIR = {model_name_dir}")

if self.model_type in ["tflite-tosa"]:
self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tflite.mlir"
self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tfite.mlir"])
self.mlir_file = "/".join([self.model_name_dir, str(self.model_name) + "_tfite.mlir"])
elif self.model_type in ["tensorflow"]:
self.mlir_url = self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_tf.mlir"
self.mlir_file = "/".join([model_name_dir, str(self.model_name) + "_tf.mlir"])
self.mlir_file = "/".join([self.model_name_dir, str(self.model_name) + "_tf.mlir"])
elif self.model_type in ["torch", "jax", "mhlo", "tosa"]:
self.mlir_url = (
self.tank_url + "/" + str(self.model_name) + "/" + str(self.model_name) + "_" + str(self.model_type) + ".mlir"
)
self.mlir_file = "/".join(
[
model_name_dir,
self.model_name_dir,
str(self.model_name) + "_" + str(self.model_type) + ".mlir",
]
)
Expand Down
2 changes: 2 additions & 0 deletions shark/tests/test_shark_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from shark.parser import shark_args
from shark.shark_inference import SharkInference
from shark.tflite_utils import TFLitePreprocessor
import sys

# model_path = "https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1?lite-format=tflite"

Expand Down Expand Up @@ -121,6 +122,7 @@ def create_and_check_module(self):


@pytest_param
@pytest.mark.xfail(sys.platform == "darwin", reason="known macos tflite install issue")
def test_albert(dynamic, device):
module_tester = AlbertTfliteModuleTester(dynamic=dynamic, device=device)
module_tester.create_and_check_module()
Expand Down
1 change: 0 additions & 1 deletion shark/tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import csv
import urllib.request
from shark.iree_utils._common import IREE_TARGET_MAP
import json


Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

from shark.shark_downloader import SharkDownloader
from shark.shark_inference import SharkInference
import pytest
Expand All @@ -21,26 +23,25 @@ def __init__(
def create_and_check_module(self):
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
self.shark_downloader = SharkDownloader(
shark_downloader = SharkDownloader(
model_name="albert_lite_base",
tank_url="https://storage.googleapis.com/shark_tank",
local_tank_dir="./../gen_shark_tank",
model_type="tflite-tosa",
input_json="input.json",
input_type="int32",
)
tflite_tosa_model = self.shark_downloader.get_mlir_file()
inputs = self.shark_downloader.get_inputs()
self.shark_module = SharkInference(
tflite_tosa_model,
inputs,
tflite_tosa_model = shark_downloader.get_mlir_file()
inputs = shark_downloader.get_inputs()

shark_module = SharkInference(
mlir_module=tflite_tosa_model,
function_name="main",
device=self.device,
dynamic=self.dynamic,
jit_trace=True,
mlir_dialect="tflite",
)
self.shark_module.set_frontend("tflite-tosa")
self.shark_module.compile()
self.shark_module.forward(inputs)
shark_module.compile()
shark_module.forward(inputs)
# print(shark_results)


Expand All @@ -54,6 +55,9 @@ def setUp(self):
self.module_tester = AlbertTfliteModuleTester(self)
self.module_tester.save_mlir = self.save_mlir

import sys

@pytest.mark.xfail(sys.platform == "darwin", reason="known macos tflite install issue")
def test_module_static_cpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "cpu"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def compare_results(mlir_results, tflite_results, details):
tflite_result = tflite_results[i]
mlir_result = mlir_result.astype(np.single)
tflite_result = tflite_result.astype(np.single)
print("mlir_result.shape", mlir_result.shape)
print("tflite_result.shape", tflite_result.shape)
assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
max_error = np.max(np.abs(mlir_result - tflite_result))
print("Max error (%d): %f", i, max_error)
Expand All @@ -58,12 +60,14 @@ def __init__(
def create_and_check_module(self):
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
tflite_preprocessor = TFLitePreprocessor(model_name="albert_lite_base")

# Preprocess to get SharkImporter input args
tflite_preprocessor = TFLitePreprocessor(model_name="albert_lite_base")
raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()

# Use SharkImporter to get SharkInference input args
my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
Expand All @@ -72,6 +76,7 @@ def create_and_check_module(self):
)
mlir_model, func_name = my_shark_importer.import_mlir()

# Use SharkInference to get inference result
shark_module = SharkInference(
mlir_module=mlir_model,
function_name=func_name,
Expand Down Expand Up @@ -119,6 +124,9 @@ def setUp(self):
self.module_tester = AlbertTfliteModuleTester(self)
self.module_tester.save_mlir = self.save_mlir

import sys

@pytest.mark.xfail(sys.platform == "darwin", reason="known macos tflite install issue")
def test_module_static_cpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "cpu"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import numpy as np
from shark.shark_importer import SharkImporter
from shark.shark_inference import SharkInference
import pytest
import unittest
from shark.parser import shark_args
from shark.tflite_utils import TFLitePreprocessor


# model_path = "https://tfhub.dev/google/lite-model/magenta/arbitrary-image-stylization-v1-256/int8/prediction/1?lite-format=tflite"


def compare_results(mlir_results, tflite_results, details):
print("Compare mlir_results VS tflite_results: ")
assert len(mlir_results) == len(tflite_results), "Number of results do not match"
for i in range(len(details)):
mlir_result = mlir_results[i]
tflite_result = tflite_results[i]
mlir_result = mlir_result.astype(np.single)
tflite_result = tflite_result.astype(np.single)
mlir_result = np.expand_dims(mlir_result, axis=0)
print("mlir_result.shape", mlir_result.shape)
print("tflite_result.shape", tflite_result.shape)
assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
max_error = np.max(np.abs(mlir_result - tflite_result))
print("Max error (%d): %f", i, max_error)


class ArbitraryImageStylizationV1TfliteModuleTester:
def __init__(
self,
dynamic=False,
device="cpu",
save_mlir=False,
save_vmfb=False,
):
self.dynamic = dynamic
self.device = device
self.save_mlir = save_mlir
self.save_vmfb = save_vmfb

def create_and_check_module(self):
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb

tflite_preprocessor = TFLitePreprocessor(model_name="arbitrary-image-stylization-v1-256")

raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()

my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
frontend="tflite",
raw_model_file=raw_model_file_path,
)
mlir_model, func_name = my_shark_importer.import_mlir()

shark_module = SharkInference(
mlir_module=mlir_model,
function_name=func_name,
device=self.device,
mlir_dialect="tflite",
)
# Case1: Use shark_importer default generate inputs
shark_module.compile()
mlir_results = shark_module.forward(inputs)
## post process results for compare
input_details, output_details = tflite_preprocessor.get_model_details()
mlir_results = list(mlir_results)
for i in range(len(output_details)):
dtype = output_details[i]["dtype"]
mlir_results[i] = mlir_results[i].astype(dtype)
tflite_results = tflite_preprocessor.get_raw_model_output()
compare_results(mlir_results, tflite_results, output_details)
# print(mlir_results)


class ArbitraryImageStylizationV1TfliteModuleTest(unittest.TestCase):
@pytest.fixture(autouse=True)
def configure(self, pytestconfig):
self.save_mlir = pytestconfig.getoption("save_mlir")
self.save_vmfb = pytestconfig.getoption("save_vmfb")

def setUp(self):
self.module_tester = ArbitraryImageStylizationV1TfliteModuleTester(self)
self.module_tester.save_mlir = self.save_mlir

import sys

@pytest.mark.xfail(sys.platform == "darwin", reason="known macos tflite install issue")
def test_module_static_cpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "cpu"
self.module_tester.create_and_check_module()


if __name__ == "__main__":
# module_tester = ArbitraryImageStylizationV1TfliteModuleTester()
# module_tester.save_mlir = True
# module_tester.save_vmfb = True
# module_tester.create_and_check_module()

unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import urllib.request
from PIL import Image
from shark.tflite_utils import TFLitePreprocessor


# model_path = "https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/3?lite-format=tflite"
Expand Down Expand Up @@ -36,6 +37,9 @@ def compare_results(mlir_results, tflite_results, details):
tflite_result = tflite_results[i]
mlir_result = mlir_result.astype(np.single)
tflite_result = tflite_result.astype(np.single)
mlir_result = np.expand_dims(mlir_result, axis=0)
print("mlir_result.shape", mlir_result.shape)
print("tflite_result.shape", tflite_result.shape)
assert mlir_result.shape == tflite_result.shape, "shape doesnot match"
max_error = np.max(np.abs(mlir_result - tflite_result))
print("Max error (%d): %f", i, max_error)
Expand All @@ -57,33 +61,52 @@ def __init__(
def create_and_check_module(self):
shark_args.save_mlir = self.save_mlir
shark_args.save_vmfb = self.save_vmfb
my_shark_importer = SharkImporter(model_name="birds_V1", model_type="tflite")

mlir_model = my_shark_importer.get_mlir_model()
inputs = my_shark_importer.get_inputs()
shark_module = SharkInference(mlir_model, inputs, device=self.device, dynamic=self.dynamic)
shark_module.set_frontend("tflite-tosa")
tflite_preprocessor = TFLitePreprocessor(model_name="birds_V1")

raw_model_file_path = tflite_preprocessor.get_raw_model_file()
inputs = tflite_preprocessor.get_inputs()
tflite_interpreter = tflite_preprocessor.get_interpreter()

my_shark_importer = SharkImporter(
module=tflite_interpreter,
inputs=inputs,
frontend="tflite",
raw_model_file=raw_model_file_path,
)
mlir_model, func_name = my_shark_importer.import_mlir()

shark_module = SharkInference(
mlir_module=mlir_model,
function_name=func_name,
device=self.device,
mlir_dialect="tflite",
)

# Case1: Use shark_importer default generate inputs
shark_module.compile()
mlir_results = shark_module.forward(inputs)
## post process results for compare
input_details, output_details = my_shark_importer.get_model_details()
input_details, output_details = tflite_preprocessor.get_model_details()
mlir_results = list(mlir_results)
for i in range(len(output_details)):
dtype = output_details[i]["dtype"]
mlir_results[i] = mlir_results[i].astype(dtype)
tflite_results = my_shark_importer.get_raw_model_output()
tflite_results = tflite_preprocessor.get_raw_model_output()
compare_results(mlir_results, tflite_results, output_details)

# Case2: Use manually set inputs
input_details, output_details = my_shark_importer.get_model_details()
input_details, output_details = tflite_preprocessor.get_model_details()
inputs = generate_inputs(input_details) # device_inputs
shark_module = SharkInference(mlir_model, inputs, device=self.device, dynamic=self.dynamic)
shark_module.set_frontend("tflite-tosa")
shark_module = SharkInference(
mlir_module=mlir_model,
function_name=func_name,
device=self.device,
mlir_dialect="tflite",
)
shark_module.compile()
mlir_results = shark_module.forward(inputs)
tflite_results = my_shark_importer.get_raw_model_output()
tflite_results = tflite_preprocessor.get_raw_model_output()
compare_results(mlir_results, tflite_results, output_details)
# print(mlir_results)

Expand All @@ -98,6 +121,9 @@ def setUp(self):
self.module_tester = BirdsV1TfliteModuleTester(self)
self.module_tester.save_mlir = self.save_mlir

import sys

@pytest.mark.xfail(sys.platform == "darwin", reason="known macos tflite install issue")
def test_module_static_cpu(self):
self.module_tester.dynamic = False
self.module_tester.device = "cpu"
Expand Down
Loading

0 comments on commit babd3d0

Please sign in to comment.