Skip to content

Commit 156aa59

Browse files
jwfromm and jroesch
authored and committed
[Relay][Frontend][ONNX] New Operators and Opsets to Support BERT (#4197)
* Added slice v10 * Added constantofshape operation and small refactor. * Finished one_hot implementation. * Reshape working across all bert layers. * Fixed constantofshape and removed code duplication. * onnx model fully ingested. * Working on improving onnx tests. * Changed onnx testing to use onnxruntime instead of caffe2, also formatted. * Add arbitrary output nodes to onnx frontend. * Added v6 tiling for bert squad 8 support. * Small syntax fixes * Reduced code duplication in split opset versions. * Added batch matmul test * Added unstack split testing. * Added onehot test, needs a little cleanup probably. * Replaced deprecated constant fill with constantofshape and updated tests accordingly. * Added tests for new opset version of slice and tile. * lint clean up * Lint fixes * Changed onnx dependency * Went back to caffe2 runtime for CI integration. * Rebase and small typo/syntax changes. * Added hard casting of onehot attributes to int.
1 parent 71f39be commit 156aa59

File tree

4 files changed

+744
-381
lines changed

4 files changed

+744
-381
lines changed

python/tvm/relay/frontend/common.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@
1919
import logging
2020

2121
import tvm
22+
import numpy as np
2223
from topi.util import get_const_tuple
2324
from .. import expr as _expr
2425
from .. import module as _module
2526
from .. import transform as _transform
2627
from .. import op as _op
28+
from .. import analysis
2729

2830

2931
class RequiredAttr(object):
@@ -474,6 +476,50 @@ def infer_channels(inputs, transpose=False):
474476
return channels
475477

476478

479+
def infer_value(input_val, params):
    """A hack for getting the value of an expression by evaluating a
    portion of the relay graph. This is often needed for functions
    whose output shape depends on the value of a tensor.

    Parameters
    ----------
    input_val : tvm.relay.Expr
        The expression whose concrete value is needed.
    params : dict of str to tvm.nd.NDArray
        Known values; must cover every free variable of ``input_val``.

    Returns
    -------
    tvm.nd.NDArray
        The evaluated value of ``input_val``.
    """
    # Imported lazily to avoid a hard dependency at module import time.
    from tvm.contrib import graph_runtime
    # Compute the free variables once instead of twice.
    free_variables = analysis.free_vars(input_val)
    # Check that all free variables have associated parameters.
    assert all(var.name_hint in params.keys() for var in free_variables), \
        "All inputs to infer must be available in params."
    func = _expr.Function(free_variables, input_val)
    # opt_level=0 keeps the build cheap; we only need one evaluation.
    with tvm.relay.build_config(opt_level=0):
        graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
    ctx = tvm.cpu(0)
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    m.run()
    return m.get_output(0)
496+
497+
498+
def infer_value_simulated(input_val, params):
    """Extension to infer_value that can be used when some input
    values are missing. This function creates dummy inputs with the same
    shape and random values then calls infer_value. This is helpful when
    implementing certain onnx operators where we need to evaluate the graph
    to determine a static shape.

    Parameters
    ----------
    input_val : tvm.relay.Expr
        The expression whose (shape-driving) value is needed.
    params : dict of str to tvm.nd.NDArray
        Known values; free variables not present here are filled with
        random data of the annotated shape/dtype. The dict is not
        modified.

    Returns
    -------
    tvm.nd.NDArray
        The evaluated value of ``input_val`` under the simulated inputs.
    """
    # Work on a copy so the caller's dict is never left polluted with
    # fake entries — the original mutated `params` in place and only
    # cleaned up on the success path, leaking fakes if infer_value raised.
    sim_params = dict(params)
    # Add a fake copy of all missing params.
    for free_param in analysis.free_vars(input_val):
        if free_param.name_hint not in sim_params:
            # Shape/dtype come from the variable's type annotation; this
            # assumes fully static annotated shapes.
            fp_dtype = free_param.type_annotation.dtype
            fp_shape = [s.value for s in free_param.type_annotation.shape]
            sim_params[free_param.name_hint] = tvm.nd.array(
                np.random.rand(*fp_shape).astype(fp_dtype)
            )
    # Now infer the value using the augmented copy.
    return infer_value(input_val, sim_params)
521+
522+
477523
def new_var(name_hint,
478524
type_annotation=None,
479525
shape=None,

0 commit comments

Comments
 (0)