re-enable gpt2 fusion tests (#8566)
Re-enable the tests that were disabled in PR 8530.
Update the imports in test_optimizer.py so that the tests can run from the source directory.
Add a parameter to disable symbolic shape inference in fp16 conversion, since it throws an exception for some models.
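A minimal usage sketch of the new option (the model file name below is hypothetical; the `convert_model_float32_to_float16` signature comes from this commit):

```python
# Hypothetical usage: skip symbolic shape inference when it raises an exception
# for a particular model during the float32 -> float16 conversion.
from onnx import load_model
from onnxruntime.transformers.onnx_model import OnnxModel

model = OnnxModel(load_model("gpt2_past.onnx"))  # hypothetical model file
model.convert_model_float32_to_float16(cast_input_output=False,
                                       use_symbolic_shape_infer=False)
```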
tianleiwu authored Aug 6, 2021
1 parent 1b902d0 commit 44ff80e
Showing 11 changed files with 116 additions and 80 deletions.
3 changes: 1 addition & 2 deletions onnxruntime/python/tools/transformers/affinity_helper.py
@@ -26,8 +26,7 @@ def set_affinity(self):
if self.is_os_supported:
current_affinity = os.sched_getaffinity(self.pid)
if (self.affinity != current_affinity):
logger.warning("Replacing affinity setting %s with %s", str(current_affinity),
str(self.affinity))
logger.warning("Replacing affinity setting %s with %s", str(current_affinity), str(self.affinity))
os.sched_setaffinity(self.pid, self.affinity)


2 changes: 1 addition & 1 deletion onnxruntime/python/tools/transformers/fusion_attention.py
@@ -358,7 +358,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
logger.debug("fuse_attention: failed to match v path")
return
(_, _, add_v, matmul_v) = v_nodes

is_distill = False
is_distill_add = False
qk_paths = {
10 changes: 7 additions & 3 deletions onnxruntime/python/tools/transformers/fusion_gpt_attention.py
@@ -262,9 +262,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
if i == 1:
add_qk = qk_nodes[1]
_, input_mask_nodes, _ = self.model.match_parent_paths(
add_qk, [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]),
(['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0])],
output_name_to_node)
add_qk,
[
(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]),
(['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0]),
(['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0]), # useless cast and reshape are removed.
],
output_name_to_node) # yapf: disable
if input_mask_nodes is None:
logger.debug("fuse_attention: failed to match input attention mask path")
return
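The pattern tuples above pair a chain of op types with the input index to follow at each upward step; the new, shorter pattern covers graphs whose redundant Cast and Reshape were already removed. A rough, standalone illustration of how such a pattern is read (a simplified stand-in, not the actual `OnnxModel.match_parent_paths` implementation):

```python
# Simplified stand-in for parent-path matching (illustration only). Each step
# matches the producer of one input of the current node against an expected op
# type, walking upward; an index of None means "try every input at that step".
def match_parent_path(node, op_types, input_indices, producer_of):
    matched, current = [], node
    for op_type, index in zip(op_types, input_indices):
        indices = range(len(current["inputs"])) if index is None else [index]
        parent = None
        for i in indices:
            name = current["inputs"][i]
            if name in producer_of and producer_of[name]["op_type"] == op_type:
                parent = producer_of[name]
                break
        if parent is None:
            return None  # this candidate pattern does not apply
        matched.append(parent)
        current = parent
    return matched

# Toy mask path: Unsqueeze -> Unsqueeze -> Sub -> Mul feeding the Add after QxK.
nodes = {
    "u1":     {"op_type": "Unsqueeze", "inputs": ["mask"],             "output": "u1_out"},
    "u2":     {"op_type": "Unsqueeze", "inputs": ["u1_out"],           "output": "u2_out"},
    "sub":    {"op_type": "Sub",       "inputs": ["one", "u2_out"],    "output": "sub_out"},
    "mul":    {"op_type": "Mul",       "inputs": ["sub_out", "scale"], "output": "mul_out"},
    "add_qk": {"op_type": "Add",       "inputs": ["qk", "mul_out"],    "output": "masked"},
}
producer_of = {n["output"]: n for n in nodes.values()}

path = match_parent_path(nodes["add_qk"], ['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze'],
                         [None, 0, 1, 0], producer_of)
print([n["op_type"] for n in path])  # ['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze']
```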
(file path not shown)
@@ -39,7 +39,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):

if self.shape_infer_helper is not None:
if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
logger.debug(f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same")
logger.debug(
f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same")
return
else:
# shape_infer_helper can not handle subgraphs. Current work around is to disable skiplayernorm fusion
12 changes: 8 additions & 4 deletions onnxruntime/python/tools/transformers/fusion_utils.py
@@ -76,7 +76,8 @@ def check_node_attribute(node, attribute_name: str, expected_value, default_valu
value = helper.get_attribute_value(attr)

if isinstance(expected_value, list):
return isinstance(value, ndarray) and array_equal(expected_value, value, equal_nan=False)
return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal(
expected_value, value, equal_nan=False)
else:
return value == expected_value

@@ -96,12 +97,13 @@ def check_node_input_value(self, node, input_index: int, expected_value):
value = self.model.get_constant_value(node.input[input_index])

if isinstance(expected_value, list):
return isinstance(value, ndarray) and array_equal(expected_value, value, equal_nan=False)
return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal(
expected_value, value, equal_nan=False)
else:
return value == expected_value

@staticmethod
def remove_useless_reshape_nodes(model:OnnxModel):
def remove_useless_reshape_nodes(model: OnnxModel):
"""Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape
"""
shape_infer = model.infer_runtime_shape(update=True)
@@ -114,7 +116,8 @@ def remove_useless_reshape_nodes(model:OnnxModel):
input_shape = shape_infer.get_edge_shape(node.input[0])
output_shape = shape_infer.get_edge_shape(node.output[0])
if input_shape and output_shape and input_shape == output_shape:
logger.info(f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}")
logger.info(
f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}")
nodes_to_remove.append(node)

if nodes_to_remove:
Expand All @@ -123,6 +126,7 @@ def remove_useless_reshape_nodes(model:OnnxModel):
model.remove_node(node)
model.prune_graph()


class NumpyHelper:
@staticmethod
def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
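The widened check in `check_node_attribute` and `check_node_input_value` matters because the compared value may come back as a plain Python list rather than a numpy array; `numpy.array_equal` accepts both. A small standalone sketch (the values are made up):

```python
# Standalone sketch with made-up values: the updated check accepts both a plain
# Python list (e.g. an ints attribute) and a numpy ndarray (e.g. a constant
# initializer converted with numpy_helper).
import numpy as np
from numpy import ndarray, array_equal

expected_value = [0, 2, 1, 3]  # e.g. the 'perm' a Transpose is expected to have
for value in ([0, 2, 1, 3], np.array([0, 2, 1, 3])):
    matches = (isinstance(value, ndarray) or isinstance(value, list)) and \
        array_equal(expected_value, value, equal_nan=False)
    print(type(value).__name__, matches)  # list True, then ndarray True
```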
1 change: 1 addition & 0 deletions onnxruntime/python/tools/transformers/onnx_exporter.py
@@ -15,6 +15,7 @@
from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState
from quantize_helper import QuantizeHelper
from huggingface_models import MODEL_CLASSES

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

logger = logging.getLogger(__name__)
15 changes: 9 additions & 6 deletions onnxruntime/python/tools/transformers/onnx_model.py
@@ -486,21 +486,24 @@ def change_input_output_float32_to_float16(self):
# restore opset version
self.model.opset_import[0].version = original_opset_version

def convert_model_float32_to_float16(self, cast_input_output=True):
def convert_model_float32_to_float16(self, cast_input_output=True, use_symbolic_shape_infer=True):
"""Convert a graph to FLOAT16. By default, we will keep data types of inputs and outputs.
For decoder model with past_key_values, it is recommended to set cast_input_output=False for better performance.
Args:
cast_input_output (bool, optional): keep data type of inputs and outputs, and add Cast nodes to convert float32 inputs to float16, and float16 to float32 for outputs. Defaults to True.
use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference.
"""
from packaging.version import Version
import onnxconverter_common as oc
if Version(oc.__version__) > Version("1.7.0"):
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) are not recognized by onnx shape inference.
shape_infer_helper = SymbolicShapeInferenceHelper(self.model)
model_with_shape = shape_infer_helper.infer_shapes(self.model, auto_merge=True, guess_output_rank=False)
self.model = oc.float16.convert_float_to_float16(model_with_shape,
model = self.model
if use_symbolic_shape_infer:
# Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc) are not recognized by onnx shape inference.
shape_infer_helper = SymbolicShapeInferenceHelper(model)
model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False)
self.model = oc.float16.convert_float_to_float16(model,
keep_io_types=cast_input_output,
disable_shape_infer=True)
disable_shape_infer=use_symbolic_shape_infer)
return

graph = self.model.graph
(file path not shown)
@@ -11,6 +11,7 @@
from collections import deque
from onnx import ModelProto, TensorProto, numpy_helper
from onnx_model_bert_tf import BertOnnxModelTF

logger = logging.getLogger(__name__)


39 changes: 20 additions & 19 deletions onnxruntime/test/python/transformers/test_attention_fusion.py
@@ -96,25 +96,26 @@ def test_3d_attention_fusion_tf2onnx_model(self):
expected = onnx.load(expected_model_path)
self.assertEqual(str(optimized_model.model.graph), str(expected.graph))

# def test_gpt2_attention_fusion(self):
# hidden_size = 64
# num_heads = 4
# for add_order in [False, True]:
# model = create_gpt2_attention(hidden_size=hidden_size, num_heads=num_heads, switch_add_inputs=add_order)
# dir = '.'
# model_path = os.path.join(dir, "gpt2_attention.onnx")
# onnx.save(model, model_path)
# optimized_model = optimize_model(model_path,
# model_type='gpt2',
# num_heads=num_heads,
# hidden_size=hidden_size,
# disable_onnxruntime=True)
# os.remove(model_path)

# model_name = "gpt2_attention_{}.onnx".format("add_opt" if add_order else "opt")
# expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'models', model_name)
# expected = onnx.load(expected_model_path)
# self.assertEqual(str(optimized_model.model.graph), str(expected.graph))
def test_gpt2_attention_fusion(self):
hidden_size = 64
num_heads = 4
for add_order in [False, True]:
model = create_gpt2_attention(hidden_size=hidden_size, num_heads=num_heads, switch_add_inputs=add_order)
dir = '.'
model_path = os.path.join(dir, "gpt2_attention.onnx")
onnx.save(model, model_path)
optimized_model = optimize_model(model_path,
model_type='gpt2',
num_heads=num_heads,
hidden_size=hidden_size,
disable_onnxruntime=True)
optimized_model.topological_sort()
os.remove(model_path)

model_name = "gpt2_attention_{}.onnx".format("add_opt" if add_order else "opt")
expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'models', model_name)
expected = onnx.load(expected_model_path)
self.assertEqual(str(optimized_model.model.graph), str(expected.graph))


if __name__ == '__main__':
108 changes: 64 additions & 44 deletions onnxruntime/test/python/transformers/test_optimizer.py
@@ -10,22 +10,31 @@

import unittest
import os
import onnx
import onnxruntime
import pytest
from onnx import helper, TensorProto, ModelProto, load_model
from onnx.helper import make_node, make_tensor_value_info
import numpy as np
from onnx import numpy_helper
from onnx import TensorProto, load_model
import sys

from onnxruntime.transformers.optimizer import optimize_model, optimize_by_onnxruntime
from onnxruntime.transformers.onnx_model import OnnxModel
# Try import optimizer from source directory so that we need not build and install package after making change.
source_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'python', 'tools', 'transformers')
if (os.path.exists(source_dir) and source_dir not in sys.path):
sys.path.append(source_dir)
from optimizer import optimize_model
from onnx_model import OnnxModel
from onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt
from huggingface_models import MODELS
from benchmark_helper import Precision
else:
from onnxruntime.transformers.optimizer import optimize_model
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt
from onnxruntime.transformers.huggingface_models import MODELS
from onnxruntime.transformers.benchmark_helper import Precision


BERT_TEST_MODELS = {
"bert_keras_0": ('models', 'TFBertForSequenceClassification_1.onnx'), # bert_mrpc_tensorflow2.1_opset10
"bert_keras_squad": ('models', 'TFBertForQuestionAnswering.onnx'), # bert_squad_tensorflow2.1_keras2onnx_opset11
"gpt2_past": ('models', 'gpt2_past.onnx'), # gpt2_pytorch1.5_opset11
"bert_keras_0": ('models', 'TFBertForSequenceClassification_1.onnx'), # bert_mrpc_tensorflow2.1_opset10
"bert_keras_squad": ('models', 'TFBertForQuestionAnswering.onnx'), # bert_squad_tensorflow2.1_keras2onnx_opset11
"gpt2_past": ('models', 'gpt2_past.onnx'), # gpt2_pytorch1.5_opset11
"gpt2_past_mask": ('FUSION', 'gpt2_past_mask_one_layer.onnx'),
"multiple_embed": ('FUSION', 'embed_layer_norm_multiple.onnx'),
"bert_tf2onnx_0": ('models', 'bert_tf2onnx_0.onnx')
Expand All @@ -35,10 +44,15 @@
def _get_test_model_path(name):
sub_dir, file = BERT_TEST_MODELS[name]
if sub_dir == "FUSION":
#return os.path.join('..', '..', '..', '..', 'test', 'testdata', 'transform', 'fusion', file)
return os.path.join('./', 'testdata', 'transform', 'fusion', file)
relative_path = os.path.join(os.path.dirname(__file__), '..', '..', 'testdata', 'transform', 'fusion', file)
if (os.path.exists(relative_path)):
return relative_path
return os.path.join('.', 'testdata', 'transform', 'fusion', file)
else:
return os.path.join('./', 'transformers', 'test_data', sub_dir, file)
relative_path = os.path.join(os.path.dirname(__file__), 'test_data', sub_dir, file)
if (os.path.exists(relative_path)):
return relative_path
return os.path.join('.', 'transformers', 'test_data', sub_dir, file)


class TestBertOptimization(unittest.TestCase):
@@ -63,9 +77,6 @@ def _test_optimizer_on_huggingface_model(self,
# expect fusion result list have the following keys
# EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization
model_fusion_statistics = {}
from onnx_exporter import export_onnx_model_from_pt
from huggingface_models import MODELS
from benchmark_helper import Precision

input_names = MODELS[model_name][0]

@@ -94,9 +105,6 @@ def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, i
# expect fusion result list have the following keys
# EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization
model_fusion_statistics = {}
from onnx_exporter import export_onnx_model_from_tf
from huggingface_models import MODELS
from benchmark_helper import Precision
print("testing mode ", model_name)
print("testing input number = ", inputs_count)
input_names = MODELS[model_name][0]
@@ -157,28 +165,28 @@ def test_gpt2_past(self):
}
self.verify_node_count(model, expected_node_count, 'test_gpt2_past')

# def test_gpt2_past_fp16(self):
# input_model_path = _get_test_model_path('gpt2_past')
# model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True))
# model.convert_model_float32_to_float16(cast_input_output=False)
# for input in model.graph().input[1:]:
# self.assertEqual(input.type.tensor_type.elem_type, TensorProto.FLOAT16)
# for output in model.graph().output:
# self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16)

# def test_gpt2_past_mask(self):
# input = _get_test_model_path('gpt2_past_mask')
# model = optimize_model(input, 'gpt2', num_heads=2, hidden_size=4)
# expected_node_count = {
# 'EmbedLayerNormalization': 0,
# 'Attention': 1,
# 'Gelu': 0,
# 'FastGelu': 1,
# 'BiasGelu': 0,
# 'LayerNormalization': 2,
# 'SkipLayerNormalization': 0
# }
# self.verify_node_count(model, expected_node_count, 'test_gpt2_past_mask')
def test_gpt2_past_fp16(self):
input_model_path = _get_test_model_path('gpt2_past')
model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True))
model.convert_model_float32_to_float16(cast_input_output=False, use_symbolic_shape_infer=False)
for input in model.graph().input[1:]:
self.assertEqual(input.type.tensor_type.elem_type, TensorProto.FLOAT16)
for output in model.graph().output:
self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16)

def test_gpt2_past_mask(self):
input = _get_test_model_path('gpt2_past_mask')
model = optimize_model(input, 'gpt2', num_heads=2, hidden_size=4)
expected_node_count = {
'EmbedLayerNormalization': 0,
'Attention': 1,
'Gelu': 0,
'FastGelu': 1,
'BiasGelu': 0,
'LayerNormalization': 2,
'SkipLayerNormalization': 0
}
self.verify_node_count(model, expected_node_count, 'test_gpt2_past_mask')

def test_multiple_embed(self):
input_model_path = _get_test_model_path('multiple_embed')
@@ -209,9 +217,15 @@ def test_multiple_embed(self):
# self.verify_node_count(model, expected_node_count, 'test_bert_tf2onnx_0')

@pytest.mark.slow
def test_huggingface_bert_fusion(self):
def test_huggingface_bert_fusion_1(self):
self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=1)

@pytest.mark.slow
def test_huggingface_bert_fusion_2(self):
self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=2)

@pytest.mark.slow
def test_huggingface_bert_fusion_3(self):
self._test_optimizer_on_huggingface_model("bert-base-uncased", [1, 12, 0, 0, 12, 0, 24], inputs_count=3)

@pytest.mark.slow
@@ -269,9 +283,15 @@ def test_huggingface_bart_fusion(self):
self._test_optimizer_on_huggingface_model("facebook/bart-base", [0, 0, 0, 0, 12, 2, 30])

@pytest.mark.slow
def test_huggingface_bert_base_cased_from_tf2onnx(self):
def test_huggingface_bert_base_cased_from_tf2onnx_1(self):
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 1)

@pytest.mark.slow
def test_huggingface_bert_base_cased_from_tf2onnx_2(self):
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 2)

@pytest.mark.slow
def test_huggingface_bert_base_cased_from_tf2onnx_3(self):
self._test_optimizer_on_tf_model("bert-base-cased", [0, 12, 0, 0, 0, 0, 25], 3)

@pytest.mark.slow
2 changes: 2 additions & 0 deletions onnxruntime/test/python/transformers/test_profiler.py
@@ -24,11 +24,13 @@ def run_profile(self, arguments: str):
results = run(args)
self.assertTrue(len(results) > 1)

@pytest.mark.slow
def test_profiler_gpu(self):
input_model_path = _get_test_model_path('bert_keras_squad')
if 'CUDAExecutionProvider' in onnxruntime.get_available_providers():
self.run_profile(f'--model {input_model_path} --batch_size 1 --sequence_length 7 --use_gpu')

@pytest.mark.slow
def test_profiler_cpu(self):
input_model_path = _get_test_model_path('bert_keras_squad')
self.run_profile(f'--model {input_model_path} --batch_size 1 --sequence_length 7 --dummy_inputs default')
