-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Closed
Labels
Description
Compiled a TFLite toy model (conv + softmax) using TVM 0.7 to generate a softmax output.
Expected behavior
The distribution of the uint8 output should look something like the following (below is the TFLite runtime output):
Actual behavior
After I change axis to 3 here, I get output that is similar to the expected output.
Environment
OS: Ubuntu 20.04
TVM: 0.7
Steps to reproduce
deps: tvm 0.7, numpy, and matplotlib to generate the plot
The zip contains the TFLite model. Feel free to use a random input.
tflite_model.tflite.zip
Here is the input numpy array
import tvm
import numpy as np
import tflite
import tflite_runtime.interpreter as tflite_interpreter
from tvm import relay, transform
from tvm.contrib import graph_runtime as runtime
import matplotlib.pyplot as plt
def generate_hist(x, title, filename):
    """Save a bar-chart histogram of the values in ``x`` to ``filename``.

    ``x`` is flattened and counted into 256 unit-wide bins spanning
    [0, 256). The bars are drawn on the current pyplot figure with
    ``title`` as the heading, the figure is written to ``filename``, and
    the figure is then cleared.
    """
    counts, edges = np.histogram(x.flatten(), bins=256, range=(0, 256))

    plt.title(title)
    plt.xlabel("uint8 value")
    plt.xlim((-1, 256))
    plt.ylabel("frequency")
    # ``edges`` holds 257 boundaries; dropping the last one leaves the left
    # edge of each bin, which is used as the bar position.
    plt.bar(edges[:-1], counts)
    plt.savefig(filename)
    plt.clf()
# Reference result: run the model through the TFLite runtime interpreter.
model_path = "tflite_model.tflite"
interpreter = tflite_interpreter.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Load the saved NumPy array, scale it by 255 and cast it to uint8.
with open("conv_softmax_input.bin", "rb") as f:
    input_arr = np.load(f) * 255
input_data = input_arr.astype(np.uint8)

# Tensor indices of the first input and first output of the model.
input_index = input_details[0]["index"]
output_index = output_details[0]["index"]

interpreter.set_tensor(input_index, input_data)
interpreter.invoke()
tflite_output1 = interpreter.get_tensor(output_index)
# TVM compilation: import the same TFLite flatbuffer into Relay.
# The model file is read inside a ``with`` block so the handle is closed
# as soon as the bytes are loaded instead of being left open.
with open(model_path, "rb") as model_file:
    tflite_model_buf = model_file.read()
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

# Shape and dtype declared for the graph input named "input".
input_dict = {"input": (1, 100, 160, 256)}
input_dtype = {"input": "uint8"}
mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict=input_dict, dtype_dict=input_dtype
)

# Build the Relay module with the LLVM target at optimization level 2.
target = "llvm"
with transform.PassContext(opt_level=2):
    lib = relay.build(mod, target, params=params)

# Run the compiled module on the CPU with the same uint8 input array.
module = runtime.GraphModule(lib["default"](tvm.cpu()))
module.set_input("input", tvm.nd.array(input_data))
module.run()
tvm_output1 = module.get_output(0).asnumpy()
# Save one histogram per backend so the two output distributions can be
# compared side by side.
for hist_output, hist_title, hist_filename in (
    (tflite_output1, "tflite quant(uint8) hist", "tflite_hist.png"),
    (tvm_output1, "tvm quant(uint8) hist", "tvm_hist.png"),
):
    generate_hist(hist_output, hist_title, hist_filename)