Skip to content

Commit

Permalink
Revert TRT4 support
Browse files Browse the repository at this point in the history
This reverts commit dc693c74b179800f19d376ff79857f95e3b637b0.
  • Loading branch information
MattConley authored and 2sin18 committed Jun 24, 2021
1 parent 9cd2834 commit 8bf4b13
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 60 deletions.
11 changes: 0 additions & 11 deletions tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,6 @@ Status TrtNodeValidator::ConvertConstToWeights(
return status;
}

#if IS_TRT_VERSION_GE(5, 0, 0, 0)
static void InitializeTrtPlugins() {
static mutex plugin_mutex(LINKER_INITIALIZED);
static bool plugin_initialized = false;
Expand Down Expand Up @@ -1209,16 +1208,13 @@ static void InitializeTrtPlugins() {
}
}
}
#endif

Converter::Converter(nvinfer1::INetworkDefinition* trt_network,
TrtPrecisionMode precision_mode, bool use_calibration)
: trt_network_(trt_network),
precision_mode_(precision_mode),
use_calibration_(use_calibration) {
#if IS_TRT_VERSION_GE(5, 0, 0, 0)
InitializeTrtPlugins();
#endif
this->RegisterOpConverters();
}

Expand Down Expand Up @@ -4700,9 +4696,6 @@ Status ConvertMatMulHelper(OpConverterParams* params,
!transpose_a && input_a.is_tensor() && input_b.is_weights();
const bool should_use_fc = can_use_fc && input_a.GetTrtDims().nbDims >= 3 &&
input_b.GetTrtDims().nbDims == 2;
// If TRT < 5, a fully connected layer must be used as the matmul op is
// unsupported
#if IS_TRT_VERSION_GE(5, 0, 0, 0)
// If int8 is specified, FC must be used unless it is not compatible, as MM
// does not support int8 at this time.
if (should_use_fc || (can_use_fc && params->converter->precision_mode() ==
Expand Down Expand Up @@ -4753,10 +4746,6 @@ Status ConvertMatMulHelper(OpConverterParams* params,
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
return Status::OK();
#else
return ConvertFullyConnectedHelper(
params, input_a.tensor(), input_b.weights(), transpose_b, node_name);
#endif
}

// inputs are both two dimensional (ops::MatMul)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import numpy as np
import six

from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
Expand All @@ -56,7 +55,6 @@
from tensorflow.python.training.tracking import tracking
from tensorflow.python.util import nest


TfTrtIntegrationTestParams = namedtuple(
"TfTrtIntegrationTestParams",
[
Expand Down Expand Up @@ -91,11 +89,7 @@
FP32 = "FP32"
FP16 = "FP16"
INT8 = "INT8"
TRT_VERSION = get_linked_tensorrt_version()
if TRT_VERSION >= (5, 0, 0):
PRECISION_MODES = [FP32, FP16, INT8]
else:
PRECISION_MODES = [FP32, FP16]
PRECISION_MODES = [FP32, FP16, INT8]


def IsQuantizationMode(mode):
Expand Down Expand Up @@ -143,16 +137,13 @@ def OptimizerDisabledRewriterConfig():
class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration."""

TRT_VERSION = get_linked_tensorrt_version()
@property
def trt_incompatible_op(self):
return math_ops.erf

@property
def precision_modes(self):
if TRT_VERSION >= (5, 0, 0):
return ["FP32", "FP16", "INT8"]
return ["FP32", "FP16"]
return ["FP32", "FP16", "INT8"]

# str is bytes in py2, but unicode in py3.
def _ToUnicode(self, s):
Expand Down Expand Up @@ -791,17 +782,12 @@ def _GetTestConfigsV1():
# whether to run specific ones with ShouldRunTest().
#
# Note: INT8 without calibration behaves like FP32/FP16.
if TRT_VERSION >= (5, 0, 0):
opts = list(
opts = list(
itertools.product([FP32, FP16, INT8], [convert_online, convert_offline],
[dynamic_engine, static_engine], [no_calibration]))
# We always run calibration with offline tool.
# TODO(aaroey): static calibration engine is not supported yet.
opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
else:
opts = list(
itertools.product([FP32, FP16], [convert_online, convert_offline],
[dynamic_engine, static_engine], [no_calibration]))
# We always run calibration with offline tool.
# TODO(aaroey): static calibration engine is not supported yet.
opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
return opts


Expand All @@ -822,17 +808,12 @@ def _GetTestConfigsV2():
# - For simplicity we don't test online conversion which requires setting the
# Grappler config in default eager context.
# - INT8 without calibration behaves like FP32/FP16.
if TRT_VERSION >= (5, 0, 0):
opts = list(
opts = list(
itertools.product([FP32, FP16, INT8], [convert_offline], [dynamic_engine],
[no_calibration]))
# We always run calibration with offline tool.
# TODO(aaroey): INT8+calibration is not supported yet in V2.
# opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
else:
opts = list(
itertools.product([FP32, FP16], [convert_offline], [dynamic_engine],
[no_calibration]))
# We always run calibration with offline tool.
# TODO(aaroey): INT8+calibration is not supported yet in V2.
# opts.append((INT8, convert_offline, dynamic_engine, use_calibration))
return opts


Expand Down
23 changes: 3 additions & 20 deletions tensorflow/python/compiler/tensorrt/trt_convert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import numpy as np

from tensorflow.compiler.tf2tensorrt.wrap_py_utils import get_linked_tensorrt_version
from tensorflow.compiler.tf2tensorrt.wrap_py_utils import is_tensorrt_enabled
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
Expand Down Expand Up @@ -77,22 +76,13 @@ def testGetTensorrtRewriterConfig(self):
"""Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
if not is_tensorrt_enabled():
return
TRT_VERSION = get_linked_tensorrt_version()
if TRT_VERSION >= (5, 0, 0):
conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
max_batch_size=128,
max_workspace_size_bytes=1234,
precision_mode="INT8",
minimum_segment_size=10,
is_dynamic_op=True,
maximum_cached_engines=2)
else:
conversion_params = trt_convert.DEFAULT_TRT_CONVERSION_PARAMS._replace(
max_batch_size=128,
max_workspace_size_bytes=1234,
minimum_segment_size=10,
is_dynamic_op=True,
maximum_cached_engines=2)
rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
conversion_params=conversion_params)
self.assertEqual(["constfold", "layout", "constfold"],
Expand All @@ -115,8 +105,7 @@ def testGetTensorrtRewriterConfig(self):
self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
self.assertEqual(1234,
trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
if TRT_VERSION >= (5, 0, 0):
self.assertEqual(
self.assertEqual(
trt_convert._to_bytes("INT8"),
trt_optimizer.parameter_map["precision_mode"].s)
self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
Expand Down Expand Up @@ -311,13 +300,7 @@ def testTrtGraphConverter_BasicConversion(self):
input_saved_model_dir = self.mkdtemp()
self._WriteInputSavedModel(input_saved_model_dir)

TRT_VERSION = get_linked_tensorrt_version()
if TRT_VERSION >= (5, 0, 0):
calibration_opts = [False, True]
else:
calibration_opts = [False]

for need_calibration in calibration_opts:
for need_calibration in [False, True]:
# Use GraphDef as input.
self._TestTrtGraphConverter()

Expand Down

0 comments on commit 8bf4b13

Please sign in to comment.