[ONNX] Add imports for BERT contrib operators #10949

Merged (17 commits, Apr 13, 2022)
Commit: SkipLayerNormalization
altanh committed Apr 8, 2022
commit 1d3064eb257e87fdf0ab0f1671ab3782e96d70dd
33 changes: 30 additions & 3 deletions — python/tvm/relay/frontend/onnx.py

@@ -850,7 +850,7 @@ def _impl_v1(cls, inputs, attr, params):
        mask = inputs[7]
        pos_ids = inputs[8]

-        eps = attr["epsilon"] if "epsilon" in attr else 0.0
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12

        (batch_size, seq_len) = infer_shape(input_ids)

@@ -877,16 +877,43 @@ def _impl_v1(cls, inputs, attr, params):
        )
        ln = _op.multiply(ln, gamma) + beta

-        # TODO: actually calculate this
        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)

        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
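For intuition, with a 0/1 attention mask the mask_index output is just the per-sentence count of real (non-padding) tokens. A minimal NumPy illustration (hypothetical example, not part of this diff):

import numpy as np

# Hypothetical example, not part of this diff: a 0/1 attention mask for two
# sentences of lengths 3 and 2, padded to a sequence length of 4.
mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype="int64")

# Summing over the sequence axis yields the number of real tokens per
# sentence, which is exactly what mask_index reports.
mask_index = mask.sum(axis=1)  # -> array([3, 2])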


class SkipLayerNormalization(OnnxOpConverter):
    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        skip = inputs[1]
        gamma = inputs[2]
        beta = inputs[3]
        bias = inputs[4]

        eps = attr["epsilon"] if "epsilon" in attr else 1e-12

        x = _op.add(data, skip)
        if bias is not None:
            x = _op.add(x, bias)

        eps_dtype = infer_type(x).checked_type.dtype

        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
        output = _op.divide(
            _op.subtract(x, u),
            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
        )
        output = _op.multiply(output, gamma)
Contributor: nit: this is basically the same normalization calculation as 877-886 above, right? If it's easy, can we pull it out into a common helper function?
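A minimal sketch of the shared helper the reviewer suggests, assuming it would live alongside these converters in onnx.py (the name layer_norm and its exact signature are hypothetical, not part of this commit):

def layer_norm(x, eps, gamma, beta):
    # Hypothetical helper, not part of this commit: normalize x over its
    # last axis, then apply the scale and optional shift, mirroring the
    # calculation used by both EmbedLayerNormalization and
    # SkipLayerNormalization.
    eps_dtype = infer_type(x).checked_type.dtype

    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
    output = _op.divide(
        _op.subtract(x, u),
        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
    )
    output = _op.multiply(output, gamma)
    if beta is not None:
        output = _op.add(output, beta)
    return output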

        if beta:
            output = _op.add(output, beta)

        placeholder = _op.const(0, dtype="float32")
Contributor: What is this placeholder for? The optional outputs are the mean and inverse standard variance, right?

Contributor (author): That's true according to the documentation; however, neither the CPU nor the CUDA onnxruntime implementation of the kernel ever calculates or returns values for these outputs:

https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc


        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
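For reference, the computation this converter emits corresponds to the following NumPy sketch (illustrative only; skip_layer_norm_ref is a hypothetical name, not part of the diff):

import numpy as np

def skip_layer_norm_ref(data, skip, gamma, beta, bias=None, eps=1e-12):
    # Illustrative reference, not part of this diff: layer-normalize
    # (data + skip [+ bias]) over the last axis, then scale and shift.
    x = data + skip
    if bias is not None:
        x = x + bias
    u = x.mean(axis=-1, keepdims=True)
    s = x.var(axis=-1, keepdims=True)
    return (x - u) / np.sqrt(s + eps) * gamma + beta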


class Attention(OnnxOpConverter):
199 changes: 199 additions & 0 deletions — tests/python/frontend/onnx/test_forward.py

@@ -5433,6 +5433,205 @@ def verify_biasgelu(x, bias):
    verify_biasgelu(x, bias)


@tvm.testing.parametrize_targets
def test_embedlayernormalization(target, dev):
    def verify_embedlayernormalization(
        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta,
    ):
        node = onnx.helper.make_node(
            "EmbedLayerNormalization",
            inputs=[
                "input_ids",
                "segment_ids",
                "word_embedding",
                "position_embedding",
                "segment_embedding",
                "gamma",
                "beta",
            ],
            outputs=["output", "mask_index", "embedding_sum"],
            domain="com.microsoft",
        )

        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))

        graph = helper.make_graph(
            [node],
            "embedlayernormalization_test",
            inputs=[
                helper.make_tensor_value_info(
                    "input_ids", TensorProto.INT32, list(input_ids.shape)
                ),
                helper.make_tensor_value_info(
                    "segment_ids", TensorProto.INT32, list(segment_ids.shape)
                ),
                helper.make_tensor_value_info(
                    "word_embedding", TensorProto.FLOAT, list(word_embedding.shape)
                ),
                helper.make_tensor_value_info(
                    "position_embedding", TensorProto.FLOAT, list(position_embedding.shape)
                ),
                helper.make_tensor_value_info(
                    "segment_embedding", TensorProto.FLOAT, list(segment_embedding.shape)
                ),
                helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)),
                helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)),
            ],
            outputs=[
                helper.make_tensor_value_info(
                    "output", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size))
                ),
                helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size]),
                helper.make_tensor_value_info(
                    "embedding_sum",
                    TensorProto.FLOAT,
                    list((batch_size, sequence_length, hidden_size)),
                ),
            ],
        )

        model = helper.make_model(graph, producer_name="embedlayernormalization_test")
        verify_with_ort_with_inputs(
            model,
            [
                input_ids,
                segment_ids,
                word_embedding,
                position_embedding,
                segment_embedding,
                gamma,
                beta,
            ],
            [
                (batch_size, sequence_length, hidden_size),
                batch_size,
                (batch_size, sequence_length, hidden_size),
            ],
            target=target,
            dev=dev,
            rtol=1e-4,
            atol=1e-4,
        )

    hidden_size = 384
    batch_size = 4
    sequence_length = 4
    vocab_size = 5

    input_ids = np.full((batch_size, sequence_length), 3).astype("int32")
    segment_ids = np.zeros((batch_size, sequence_length)).astype("int32")
    word_embedding = np.full((vocab_size, hidden_size), 1).astype("float32")
    position_embedding = np.full((sequence_length, hidden_size), 2).astype("float32")
    segment_embedding = np.full((vocab_size, hidden_size), 3).astype("float32")

    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype("float32")
    beta = np.random.randn(hidden_size).astype("float32") * 0.1

    verify_embedlayernormalization(
        input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta
    )


@tvm.testing.parametrize_targets
def test_attention(target, dev):
    def verify_attention(input, weight, bias, mask_index, num_heads):
        node = onnx.helper.make_node(
            "Attention",
            inputs=["input", "weight", "bias", "mask_index"],
            outputs=["output", "present"],
            domain="com.microsoft",
            num_heads=num_heads,
        )

        present_output_shape = (2, batch_size, num_heads, sequence_length, head_size)

        graph = helper.make_graph(
            [node],
            "attention_test",
            inputs=[
                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)),
                helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(weight.shape)),
                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
                helper.make_tensor_value_info(
                    "mask_index", TensorProto.INT32, list(mask_index.shape)
                ),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)),
                helper.make_tensor_value_info(
                    "present", TensorProto.FLOAT, list(present_output_shape)
                ),
            ],
        )

        model = helper.make_model(graph, producer_name="attention_test")

        # "present" output should be nullptr when the "past" input isn't included,
        # but ort requires an output shape to be specified?
        verify_with_ort_with_inputs(
            model,
            [input, weight, bias, mask_index],
            [input.shape, present_output_shape],
            target=target,
            dev=dev,
            rtol=1e-4,
            atol=1e-4,
        )

    hidden_size = 384
    batch_size = 4
    sequence_length = 4
    num_heads = 12
    head_size = 32

    dtype = "float32"
    input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
    weight = np.random.normal(size=(hidden_size, 3 * hidden_size)).astype(dtype) * 0.1
    bias = np.random.randn(3 * hidden_size).astype(dtype)
    mask_index = np.full((batch_size, sequence_length), 1).astype("int32")

    verify_attention(input, weight, bias, mask_index, num_heads)


@tvm.testing.parametrize_targets
def test_skiplayernormalization(target, dev):
    def verify_skiplayernormalization(input, skip, gamma, beta, bias):
        node = onnx.helper.make_node(
            "SkipLayerNormalization",
            inputs=["input", "skip", "gamma", "beta", "bias"],
            outputs=["output"],
            domain="com.microsoft",
        )

        node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4))

        graph = helper.make_graph(
            [node],
            "skiplayernormalization_test",
            inputs=[
                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)),
                helper.make_tensor_value_info("skip", TensorProto.FLOAT, list(skip.shape)),
                helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)),
                helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)),
                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
            ],
            outputs=[
                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)),
            ],
        )

        model = helper.make_model(graph, producer_name="skiplayernormalization_test")
        verify_with_ort_with_inputs(
            model, [input, skip, gamma, beta, bias], [input.shape], target=target, dev=dev
        )

    hidden_size = 384
    batch_size = 4
    sequence_length = 4

    dtype = "float32"
    input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
    skip = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
    gamma = np.random.uniform(0.5, 0.7, hidden_size).astype(dtype)
    beta = np.random.randn(hidden_size).astype(dtype) * 0.1
    bias = np.random.randn(hidden_size).astype(dtype)

    verify_skiplayernormalization(input, skip, gamma, beta, bias)


@tvm.testing.known_failing_targets("cuda")
@tvm.testing.parametrize_targets
def test_qlinearconv(target, dev):