[Relay/TRT] Support clip for TRT 4 using relu + eltwise (#83)
* Support clip for TRT 4 using relu + eltwise

* Re-enable consistency check

* Invoke convertlayout properly
trevor-m authored Feb 4, 2020
1 parent 3188a68 commit 7475796
Showing 5 changed files with 65 additions and 6 deletions.
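
For context, a minimal NumPy sketch of the decomposition this commit implements (illustration only, not part of the diff; np.maximum/np.minimum stand in for TRT's relu and elementwise layers):

import numpy as np

def clip_via_relu_eltwise(x, a_min, a_max):
    if a_min == 0.0:
        # relu(x) instead of max(x, 0), since TRT can fuse relu.
        lower = np.maximum(x, 0.0)
    else:
        # Elementwise max against a broadcast scalar.
        lower = np.maximum(x, a_min)
    # Elementwise min against a broadcast scalar.
    return np.minimum(lower, a_max)

x = np.array([-2.0, 0.5, 3.0], dtype=np.float32)
assert np.allclose(clip_via_relu_eltwise(x, 0.0, 1.0), np.clip(x, 0.0, 1.0))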
7 changes: 4 additions & 3 deletions python/tvm/relay/tensorrt.py
@@ -182,9 +182,10 @@ def EnableTrt(mod, params=None, trt_version=None):
     assert len(trt_version) == 3
 
     # Apply passes required for TRT
-    mod = relay.transform.RemoveUnusedFunctions()(mod)
-    mod = relay.transform.InferType()(mod)
-    mod = relay.transform.ConvertLayout('NCHW')(mod)
+    seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+                                      relay.transform.ConvertLayout('NCHW')])
+    with relay.transform.PassContext(opt_level=3):
+        mod = seq(mod)
     mod = PreprocessForTrt(mod)
     if params:
         # Bind params so that we can use FoldConstant.
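
The "Invoke convertlayout properly" part of the commit is the change above: ConvertLayout is an opt_level-3 pass, and running it through a Sequential inside a PassContext lets the pass infrastructure resolve its prerequisites (notably InferType, which is why the explicit InferType call could be dropped). A hedged, self-contained sketch of the same pattern (the NHWC conv2d graph is hypothetical; on older TVM revisions tvm.IRModule may be relay.Module):

import tvm
from tvm import relay

# Hypothetical NHWC graph; ConvertLayout should rewrite it to NCHW.
x = relay.var("x", shape=(1, 224, 224, 3))
w = relay.var("w", shape=(3, 3, 3, 16))
y = relay.nn.conv2d(x, w, kernel_size=(3, 3), channels=16,
                    data_layout="NHWC", kernel_layout="HWIO")
mod = tvm.IRModule.from_expr(relay.Function([x, w], y))

seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                  relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
    mod = seq(mod)  # conv2d now computes in NCHW, with layout_transforms inserted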
2 changes: 1 addition & 1 deletion src/relay/backend/contrib/tensorrt/enable_tensorrt.cc
@@ -392,8 +392,8 @@ static const std::unordered_map<std::string, IsCompatibleFn>
       {"mean", ReduceOpChecker},
       {"contrib.adaptive_max_pool2d", AdapativePool2DOpChecker},
       {"contrib.adaptive_avg_pool2d", AdapativePool2DOpChecker},
+      {"clip", AlwaysChecker},
       // Ops which require TRT 5.1.5+
-      {"clip", TrtVersionChecker<5, 1, 5>},
       {"nn.leaky_relu", TrtVersionChecker<5, 1, 5>},
       {"sin", TrtVersionChecker<5, 1, 5>},
       {"cos", TrtVersionChecker<5, 1, 5>},
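
A hedged sketch of what this table change means (checker names mirrored from the diff; the Python signatures are hypothetical stand-ins for the real C++ IsCompatibleFn callables):

def AlwaysChecker(call, trt_version):
    # Op is supported on every TRT version.
    return True

def TrtVersionChecker(major, minor, patch):
    def checker(call, trt_version):
        # Op is supported only from the given TRT version onward.
        return trt_version >= (major, minor, patch)
    return checker

is_compatible = {
    "clip": AlwaysChecker,                        # was TrtVersionChecker(5, 1, 5)
    "nn.leaky_relu": TrtVersionChecker(5, 1, 5),  # still gated on TRT 5.1.5+
}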
2 changes: 2 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_builder.cc
@@ -92,6 +92,8 @@ GetOpConverters() {
   map->emplace("ceil", std::make_shared<UnaryOpConverter>());
   map->emplace("floor", std::make_shared<UnaryOpConverter>());
   map->emplace("strided_slice", std::make_shared<StridedSliceOpConverter>());
+#else
+  map->emplace("clip", std::make_shared<ClipLegacyOpConverter>());
 #endif
 #if TRT_VERSION_GE(6, 0, 1)
   map->emplace("image.resize", std::make_shared<ResizeOpConverter>());
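
The new #else pairs with the surrounding TRT-version preprocessor guard (presumably TRT_VERSION_GE(5, 1, 5), which gates the ops just above it): newer TRT builds keep the existing clip converter, while older builds register the relu + eltwise fallback. A hedged Python rendering of that build-time dispatch (all names illustrative):

TRT_VERSION = (4, 0, 0)  # hypothetical build-time TRT version

def get_op_converters():
    converters = {}
    if TRT_VERSION >= (5, 1, 5):
        # Newer TRT: clip handled by the converter registered in the #if branch.
        converters["clip"] = "existing clip converter"
    else:
        # TRT 4: fall back to the new relu + eltwise implementation.
        converters["clip"] = "ClipLegacyOpConverter"
    return converters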
56 changes: 56 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_ops.h
@@ -147,6 +147,28 @@ class TrtOpConverter {
     // Subtract 1 for implicit batch dim.
     return axis - 1;
   }
+
+  // Create constant that is broadcastable against input.
+  /*!
+   * \brief Create constant that is broadcastable.
+   * \param params Parameters for this op.
+   * \param value Value of scalar.
+   * \param broadcast_to_dims Dims that scalar should be broadcastable against.
+   * \return Constant tensor.
+   */
+  nvinfer1::ITensor* CreateScalar(
+      AddTrtLayerParams* params, float value,
+      const nvinfer1::Dims& broadcast_to_dims) const {
+    nvinfer1::Dims dims;
+    dims.nbDims = broadcast_to_dims.nbDims;
+    std::fill_n(dims.d, dims.nbDims, 1);
+    float* values = new float[1];
+    values[0] = value;
+    nvinfer1::Weights weights{nvinfer1::DataType::kFLOAT,
+                              static_cast<void*>(values), 1};
+    params->trt_weights->push_back(weights);
+    return params->network->addConstant(dims, weights)->getOutput(0);
+  }
 };
 
 class ActivationOpConverter : public TrtOpConverter {
@@ -185,6 +207,40 @@ class ActivationOpConverter : public TrtOpConverter {
   }
 };
 
+class ClipLegacyOpConverter : public TrtOpConverter {
+ public:
+  ClipLegacyOpConverter() : TrtOpConverter({kTensor}) {}
+
+  void Convert(AddTrtLayerParams* params) const {
+    const auto* attrs = params->call->attrs.as<ClipAttrs>();
+    CHECK_EQ(params->inputs.size(), 1) << "Clip op expects 1 input.";
+    auto input = params->inputs.at(0).tensor;
+    nvinfer1::ITensor* output = nullptr;
+    if (attrs->a_min == 0.0f) {
+      // Use relu instead of max(x, 0) because relu can be fused.
+      nvinfer1::IActivationLayer* relu_layer = params->network->addActivation(
+          *input, nvinfer1::ActivationType::kRELU);
+      CHECK(relu_layer != nullptr);
+      output = relu_layer->getOutput(0);
+    } else {
+      // max(x, a_min)
+      nvinfer1::ITensor* a_min =
+          CreateScalar(params, attrs->a_min, input->getDimensions());
+      nvinfer1::IElementWiseLayer* max_layer = params->network->addElementWise(
+          *input, *a_min, nvinfer1::ElementWiseOperation::kMAX);
+      CHECK(max_layer != nullptr);
+      output = max_layer->getOutput(0);
+    }
+    // min(max(x, a_min), a_max)
+    nvinfer1::ITensor* a_max =
+        CreateScalar(params, attrs->a_max, input->getDimensions());
+    nvinfer1::IElementWiseLayer* min_layer = params->network->addElementWise(
+        *output, *a_max, nvinfer1::ElementWiseOperation::kMIN);
+    CHECK(min_layer != nullptr);
+    params->outputs.push_back(min_layer->getOutput(0));
+  }
+};
+
 class ElementWiseBinaryOpConverter : public TrtOpConverter {
  public:
   ElementWiseBinaryOpConverter() : TrtOpConverter({kTensor, kTensor}) {}
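
CreateScalar above relies on TensorRT's elementwise broadcasting: a constant of shape (1, ..., 1) with the same rank as the input broadcasts against any input dims. A NumPy analogy (illustration only; names hypothetical):

import numpy as np

def create_scalar(value, broadcast_to_shape):
    # Rank-matched all-ones shape, as CreateScalar builds via nvinfer1::Dims.
    return np.full((1,) * len(broadcast_to_shape), value, dtype=np.float32)

x = np.random.rand(8, 3, 32, 32).astype(np.float32)
a_max = create_scalar(6.0, x.shape)  # shape (1, 1, 1, 1)
out = np.minimum(x, a_max)           # broadcasts like TRT's kMIN elementwise layer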
4 changes: 2 additions & 2 deletions tests/python/relay/test_tensorrt.py
@@ -494,8 +494,8 @@ def check_trt_used(graph):
     i_data = np.random.uniform(0, 1, input_shape).astype(dtype)
     for model in models:
         latency[model], res = test_model(model, i_data, input_shape, dtype, use_trt=True)
-        # _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1)
-        # tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3)
+        _, ref_res = test_model(model, i_data, input_shape, dtype, use_trt=False, num_iteration=1)
+        tvm.testing.assert_allclose(res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-3)
 
     for model in models:
         print(model, latency[model])
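
Finally, a hedged end-to-end sketch of exercising the new path (EnableTrt and its three-element trt_version tuple come from this fork's tensorrt.py above; the graph and version are hypothetical):

import tvm
from tvm import relay
from tvm.relay.tensorrt import EnableTrt

x = relay.var("x", shape=(1, 3, 224, 224))
f = relay.Function([x], relay.clip(x, a_min=0.0, a_max=6.0))
mod = tvm.IRModule.from_expr(f)
# With a TRT 4 version tuple, clip is offloaded via the relu + eltwise converter.
mod = EnableTrt(mod, trt_version=(4, 0, 0))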
