
[Bug] Softmax in TFLite converter is channel first instead of channel last #9078

@shashwat14

Description

Compiled a toy TFLite model (conv + softmax) with TVM 0.7 to produce a quantized (uint8) softmax output.

Expected behavior

The distribution of the uint8 output should look roughly like the following histogram, which comes from the TFLite runtime:

[Figure: tflite_hist.png, histogram of the uint8 output from the TFLite runtime]

Actual behavior

[Figure: tvm_hist.png, histogram of the uint8 output from TVM]

After I change the softmax axis to 3 (the channel axis in NHWC layout) in the TFLite converter, I get output similar to the expected output.
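
TFLite tensors are NHWC, so the channel dimension is axis 3; a softmax over axis 1 (the channel axis in NCHW) normalizes over height instead. The minimal numpy sketch below, using a made-up toy tensor rather than the model's real activations, shows that only the axis-3 result sums to 1 across channels.

```python
import numpy as np


def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / np.sum(e, axis=axis, keepdims=True)


# Toy NHWC tensor (batch=1, height=2, width=3, channels=4); shapes are illustrative.
rng = np.random.default_rng(0)
x = rng.standard_normal((1, 2, 3, 4))

channel_last = softmax(x, axis=3)   # normalizes over C (axis 3 in NHWC)
channel_first = softmax(x, axis=1)  # normalizes over H, which is axis 1 in NHWC

print(np.allclose(channel_last.sum(axis=3), 1.0))   # True
print(np.allclose(channel_first.sum(axis=3), 1.0))  # False
print(np.allclose(channel_first.sum(axis=1), 1.0))  # True
```

The two results differ everywhere in general, which is consistent with the very different histograms above.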

Environment

OS: Ubuntu 20.04
TVM: 0.7

Steps to reproduce

Dependencies: TVM 0.7, numpy, the tflite and tflite_runtime Python packages, and matplotlib (to generate the plots).
The zip contains the TFLite model. Feel free to use a random input instead of the attached array.

tflite_model.tflite.zip
Here is the input numpy array
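
If you use a random input instead, something like the following sketch writes a compatible file. The shape comes from input_dict in the script; the float32 dtype, the fixed seed, and values in [0, 1) are assumptions (the script multiplies by 255 before casting to uint8), so the attached array may differ.

```python
import numpy as np

# Shape taken from input_dict in the script (NHWC: 1 x 100 x 160 x 256).
INPUT_SHAPE = (1, 100, 160, 256)

# Values in [0, 1): the script scales by 255 and casts to uint8 afterwards.
rng = np.random.default_rng(seed=0)
random_input = rng.random(INPUT_SHAPE, dtype=np.float32)

# np.save writes .npy format; passing a file object keeps the .bin name
# the script expects instead of appending a ".npy" extension.
with open("conv_softmax_input.bin", "wb") as f:
    np.save(f, random_input)
```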

import tvm
import numpy as np
import tflite
import tflite_runtime.interpreter as tflite_interpreter

from tvm import relay, transform
from tvm.contrib import graph_runtime as runtime

import matplotlib.pyplot as plt


def generate_hist(x, title, filename):
    plt.title(title)
    plt.xlabel('uint8 value')
    plt.xlim([-1,256])
    plt.ylabel('frequency')
    H, bins = np.histogram(x.flatten(), bins=256, range=(0,256))
    plt.bar(bins[:-1], H)
    plt.savefig(filename)
    plt.clf()

# Reference output from the TFLite interpreter
model_path = 'tflite_model.tflite'
interpreter = tflite_interpreter.Interpreter(model_path=model_path)

interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

with open('conv_softmax_input.bin', 'rb') as f:
    input_arr = np.load(f) * 255
input_data = input_arr.astype(np.uint8)

interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()

tflite_output1 = interpreter.get_tensor(output_details[0]['index'])


# TVM Compilation
with open(model_path, "rb") as f:
    tflite_model_buf = f.read()
tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)

# "input" must match the input tensor name in the TFLite model (NHWC layout)
input_dict = {"input": (1, 100, 160, 256)}
input_dtype = {"input": "uint8"}
mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict=input_dict, dtype_dict=input_dtype
)

target = "llvm"
with transform.PassContext(opt_level=2):
    lib = relay.build(mod, target, params=params)

# Run the compiled module with the TVM graph runtime
module = runtime.GraphModule(lib["default"](tvm.cpu()))
module.set_input("input", tvm.nd.array(input_data))
module.run()
tvm_output1 = module.get_output(0).asnumpy()

generate_hist(tflite_output1, 'tflite quant(uint8) hist', 'tflite_hist.png')
generate_hist(tvm_output1, 'tvm quant(uint8) hist', 'tvm_hist.png')
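
Besides the histograms, a numeric check can indicate which axis an output was normalized over: dequantize it, sum along each axis, and see which sum stays closest to 1. The sketch below assumes a uint8 output scale of 1/256 and zero point 0 (the usual parameters for a TFLite uint8 softmax output); the values in output_details[0]["quantization"] are the safer source. The function names and the synthetic data are hypothetical, for illustration only.

```python
import numpy as np


def sum_deviation_per_axis(q, scale=1.0 / 256, zero_point=0):
    """Mean |sum - 1| of the dequantized tensor when summed along each axis.

    A softmax output sums to ~1 along the axis it was normalized over, so the
    axis with the smallest deviation is the likely softmax axis.
    scale/zero_point defaults are an assumption; prefer the model's values.
    """
    p = (q.astype(np.float64) - zero_point) * scale
    return {axis: float(np.mean(np.abs(p.sum(axis=axis) - 1.0)))
            for axis in range(p.ndim)}


def quantize_uint8(p, scale=1.0 / 256, zero_point=0):
    """Affine uint8 quantization, used here only to build synthetic test data."""
    return np.clip(np.round(p / scale) + zero_point, 0, 255).astype(np.uint8)


# Synthetic NHWC logits; shapes are illustrative, not the model's.
rng = np.random.default_rng(0)
logits = rng.standard_normal((1, 4, 5, 8))
e = np.exp(logits)

q_last = quantize_uint8(e / e.sum(axis=3, keepdims=True))   # channel-last softmax
q_first = quantize_uint8(e / e.sum(axis=1, keepdims=True))  # softmax over axis 1

for name, q in [("axis 3", q_last), ("axis 1", q_first)]:
    dev = sum_deviation_per_axis(q)
    print(name, "->", min(dev, key=dev.get))
# axis 3 -> 3
# axis 1 -> 1
```

Calling sum_deviation_per_axis(tflite_output1) and sum_deviation_per_axis(tvm_output1) at the end of the script would compare the two real outputs; if the diagnosis in the title holds, the TFLite output should be closest to 1 along axis 3 and the TVM output along axis 1.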
