Skip to content

relay.build fails with NHWC not supported using OpenCL on intel #4808

Closed
@tom-gall

Description

@tom-gall

Hi!

I'm still new to TVM, so pointers to the fine manual are certainly much appreciated.

I'm trying to use OpenCL on Intel (x86_64 Linux).

The failure reported is:
File "./testmobilenet-opencl.py", line 43, in <module>
graph, lib, params = relay.build(mod, target, target_host, params=params)

File "/home/tgall/tvm/tvm/python/tvm/relay/build_module.py", line 244, in build
graph_json, mod, params = bld_mod.build(func, target, target_host, params)

File "/home/tgall/tvm/tvm/python/tvm/relay/build_module.py", line 109, in build
self._build(func, target, target_host)

File "/home/tgall/tvm/tvm/python/tvm/_ffi/_ctypes/function.py", line 207, in __call__
raise get_last_ffi_error()

ValueError: Traceback (most recent call last):
[bt] (8) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr(tvm::RelayExpr const&)+0x9e) [0x7f17a28aa4ae]
[bt] (7) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x82) [0x7f17a28a8662]
[bt] (6) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>)#6}::FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>)+0x2c) [0x7f17a289acdc]
[bt] (5) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x154) [0x7f17a28a5864]
[bt] (4) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr(tvm::RelayExpr const&)+0x9e) [0x7f17a28aa4ae]
[bt] (3) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)+0x82) [0x7f17a28a8662]
[bt] (2) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>::InitVTable()::{lambda(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>)#6}::FUN(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::Array<tvm::Tensor, void> (tvm::RelayExpr const&)>)+0x2c) [0x7f17a289acdc]
[bt] (1) /home/tgall/tvm/tvm/build/libtvm.so(tvm::relay::ScheduleGetter::VisitExpr_(tvm::relay::CallNode const*)+0x6ef) [0x7f17a28a5dff]
[bt] (0) /home/tgall/tvm/tvm/build/libtvm.so(+0x38ace3) [0x7f17a20d5ce3]
File "/home/tgall/tvm/tvm/python/tvm/_ffi/_ctypes/function.py", line 72, in cfun
rv = local_pyfunc(*pyargs)
File "/home/tgall/tvm/tvm/python/tvm/relay/op/nn/_nn.py", line 216, in compute_conv2d
dilation, layout, out_dtype)
File "</home/tgall/.local/lib/python3.7/site-packages/decorator.py:decorator-gen-35>", line 2, in conv2d
File "/home/tgall/tvm/tvm/python/tvm/target.py", line 382, in dispatch_func
return dispatch_dict[k](*args, **kwargs)
File "</home/tgall/.local/lib/python3.7/site-packages/decorator.py:decorator-gen-178>", line 2, in config_dispatcher
File "/home/tgall/tvm/tvm/python/tvm/autotvm/task/dispatcher.py", line 216, in dispatch_func
return dispatch_dict['direct'](cfg, *args, **kwargs)
File "/home/tgall/tvm/tvm/python/tvm/autotvm/task/topi_integration.py", line 400, in template_call
node = f(cfg, *args, **kwargs)
File "/home/tgall/tvm/tvm/topi/python/topi/cuda/conv2d.py", line 126, in conv2d_cuda
raise ValueError("not support this layout {} yet".format(layout))
ValueError: not support this layout NHWC yet

^^^^^^^^^^^

The code that reproduces this is (to me) a pretty boring test of MobileNetV1. FWIW, I have an LLVM version of the same script that works.

Python:

# Reproduction script: import a TFLite MobileNet v1 model into Relay, build it
# for the "opencl" target (with an "llvm" host), run one image through the
# graph runtime and print the top prediction.
import os
import flatbuffers

import tvm
from tvm import relay

import tflite.Model

from PIL import Image

from matplotlib import pyplot as plt

import numpy as np
from tvm.contrib import graph_runtime as runtime

# os.path.dirname(".") is the empty string, so the files below are resolved
# relative to the current working directory.
model_dir = os.path.dirname(".")
tflite_model_file = os.path.join(model_dir, "mobilenet_v1_1.0_224.tflite")
# Read the serialized model inside a context manager so the file handle is
# closed as soon as the bytes are in memory.
with open(tflite_model_file, "rb") as model_file:
    tflite_model_buf = model_file.read()

tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)

resized_image = Image.open('/home/tgall/tvm/opencl-exp/mobilenet/cat.png').resize((224, 224))

image_data = np.asarray(resized_image).astype("float32")
# Add a leading batch axis so the array is indexed as [batch, row, col, channel].
image_data = np.expand_dims(image_data, axis=0)

# Rescale the first three channels from [0, 255] to [-1, 1].
image_data[:, :, :, 0] = 2.0 / 255.0 * image_data[:, :, :, 0] - 1
image_data[:, :, :, 1] = 2.0 / 255.0 * image_data[:, :, :, 1] - 1
image_data[:, :, :, 2] = 2.0 / 255.0 * image_data[:, :, :, 2] - 1
print('input', image_data.shape)

input_tensor = "input"
input_shape = (1, 224, 224, 3)
input_dtype = "float32"

mod, params = relay.frontend.from_tflite(
    tflite_model,
    shape_dict={input_tensor: input_shape},
    dtype_dict={input_tensor: input_dtype},
)

target = "opencl"
target_host = "llvm"

# This relay.build call is the one named in the reported traceback
# ("not support this layout NHWC yet").
with relay.build_config(opt_level=3):
    graph, lib, params = relay.build(mod, target, target_host, params=params)

module = runtime.create(graph, lib, tvm.cl(0))
module.set_input(input_tensor, tvm.nd.array(image_data))

module.set_input(**params)

module.run()

tvm_output = module.get_output(0).asnumpy()

label_file = "labels_mobilenet_quant_v1_224.txt"
label_path = os.path.join(model_dir, label_file)

with open(label_path) as f:
    labels = f.readlines()

predictions = np.squeeze(tvm_output)
prediction = np.argmax(predictions)

print("The image prediction result is: id " + str(prediction) + " name: " + labels[prediction])

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions