
Commit 0f4e188

codeislife99 and Ubuntu authored and committed
Dynamic Batch Support for TRT (apache#6955)
* add_annotate_fn
* Reshape_ann_fn
* Prune Subgraph
* Dynamic Shape
* Make PT Mask RCNN Work
* Cleanup
* Remove comments
* Remove COmments
* GetBatchSizeFix
* Fix Remove Droupout
* Fix Remove Droupout
* TRT Runtime
* Add MaskrCNN R50
* New Testing code
* Fix black
* Test Maskrcnn r50 done
* Test MR50
* Space typo
* Change Log to Dlog
* Move test to tensorrt.py
* Remove imports
* Remove function
* Add it to trt
* import error
* Imports
* Add torch to CI
* trt_test
* Check test
* Revert Pytorch install
* Fix
* test dynamic batch
* TRT
* Resolve PR comments
* Zero batch size add

Co-authored-by: Ubuntu <ubuntu@ip-172-31-27-149.us-east-2.compute.internal>
1 parent 9c4176e commit 0f4e188
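
In practical terms, the commit lets a Relay module whose batch dimension is relay.Any() be partitioned for TensorRT and executed under the Relay VM, with one engine built per batch size observed at runtime. A minimal end-to-end sketch, assuming the partition_for_tensorrt helper from tvm.relay.op.contrib.tensorrt as it existed around this commit (exact flags may differ):

import tvm
from tvm import relay
from tvm.relay.op.contrib.tensorrt import partition_for_tensorrt

# A graph whose batch dimension is unknown until runtime.
x = relay.var("x", shape=(relay.Any(), 3, 224, 224), dtype="float32")
w = relay.var("w", shape=(16, 3, 3, 3), dtype="float32")
f = relay.Function([x, w], relay.nn.relu(relay.nn.conv2d(x, w)))
mod = tvm.IRModule.from_expr(f)

mod, config = partition_for_tensorrt(mod)
with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    # Dynamic shapes require the VM rather than the graph executor.
    vm_exec = relay.vm.compile(mod, target="cuda")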

File tree

4 files changed: +318 -27 lines changed

python/tvm/relay/op/contrib/tensorrt.py

Lines changed: 104 additions & 13 deletions
@@ -23,7 +23,7 @@
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem
-from tvm.relay.expr_functor import ExprMutator
+from tvm.relay.expr_functor import ExprMutator, ExprVisitor

 logger = logging.getLogger("TensorRT")


@@ -173,7 +173,7 @@ def check_dynamism(args, op_name):
     """
     for arg in args:
         if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
-            for dim_shape in arg.checked_type.shape:
+            for dim_shape in arg.checked_type.shape[1:]:
                 if isinstance(dim_shape, tvm.tir.expr.Any):
                     return True
        elif isinstance(arg, Tuple):
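
The [1:] slice is the heart of the change: only non-batch dimensions can disqualify an op, so a model whose sole dynamic dimension is the batch is no longer rejected. A minimal sketch of the relaxed predicate (hypothetical helper name, not part of the commit):

import tvm

def has_non_batch_dynamism(shape):
    # Only dims after the batch dim count as dynamic for TRT's purposes.
    return any(isinstance(dim, tvm.tir.expr.Any) for dim in shape[1:])

# (Any, 3, 224, 224) -> False: a dynamic batch alone is now acceptable.
# (4, Any, 224)      -> True:  dynamic non-batch dims still reject the op.
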
@@ -198,6 +198,21 @@ def _func_wrapper(expr):
         if any([x.checked_type.dtype != "float32" for x in args]):
             logger.info("Only float32 inputs are supported for TensorRT.")
             return False
+        if op_name == "multiply":
+            shapes = [
+                [
+                    int(x) if not isinstance(x, tvm.tir.expr.Any) else -1
+                    for x in arg.checked_type.shape
+                ]
+                for arg in args
+            ]
+            # Batched multiply operations don't work in implicit batch mode. The following
+            # shapes are excluded because they occur in the PT MaskRCNN model. The long-term
+            # solution is to switch to explicit batch mode once performance regressions are
+            # resolved.
+            if all(
+                [list(map(int, shape)) in [[300, 64, 7, 7], [300, 1, 1, 1]] for shape in shapes]
+            ):
+                return False
         return checker(attrs, args, op_name)

     return _func_wrapper
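
The guard fires only when every argument's Any-normalized shape matches one of the two PT MaskRCNN patterns. A hypothetical standalone predicate, for clarity:

def is_excluded_multiply(shapes):
    # True only for the MaskRCNN shape pair that breaks implicit batch mode.
    return all(shape in ([300, 64, 7, 7], [300, 1, 1, 1]) for shape in shapes)

# is_excluded_multiply([[300, 64, 7, 7], [300, 1, 1, 1]]) -> True  (rejected)
# is_excluded_multiply([[300, 64, 7, 7], [1, 64, 1, 1]])  -> False (allowed)
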
@@ -292,19 +307,26 @@ def add_annotate_fn(expr):  # pylint: disable=unused-variable
     """Check if add is supported by TensorRT."""

     args = expr.args
+
+    shapes = [
+        [int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
+        for arg in args
+    ]
+
     # RelayVM + TRT doesn't support scalar addition yet.
-    for arg in args:
-        if not arg.checked_type.shape:
+    for shape in shapes:
+        if len(shape) < 1:
             return False
+
     if any([x.checked_type.dtype != "float32" for x in args]):
         logger.info("Only float32 inputs are supported for TensorRT.")
         return False
     if (
         not get_tensorrt_use_implicit_batch_mode()
         and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
-        and args[0].checked_type.shape[0] == args[1].checked_type.shape[0]
-        and args[0].checked_type.shape[0] != 1
-        and (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3)
+        and shapes[0][0] == shapes[1][0]
+        and shapes[0][0] != 1
+        and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
     ):
         logger.info("add: bug in TRT with adding batched constants.")
         return False
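
Both this hunk and the multiply guard normalize Relay shapes the same way before comparing them. A hypothetical helper capturing the idiom (the commit inlines it instead):

import tvm

def to_concrete_shape(ty):
    # Known dims become ints; unknown (tvm.tir.expr.Any) dims become -1.
    return [int(d) if not isinstance(d, tvm.tir.expr.Any) else -1 for d in ty.shape]
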
@@ -592,11 +614,35 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
         logger.info("reshape: new shape dims must be explicit.")
         return False
     if get_tensorrt_use_implicit_batch_mode():
-        shape = list(map(int, args[0].checked_type.shape))
-        new_shape = list(map(int, attrs.newshape))
+        shape = args[0].checked_type.shape
+        new_shape = attrs.newshape
         if len(new_shape) == 0 or len(shape) == 0:
             logger.info("reshape: Can't reshape to or from scalar.")
             return False
+
+        dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])
+
+        if dynamic_reshape:
+            # Make sure that the batch dim is unmodified.
+            if int(new_shape[0]) < 0:
+                for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
+                    if not (
+                        isinstance(shape_val, (int, tvm.tir.expr.IntImm))
+                        and isinstance(new_shape_val, (int, tvm.tir.expr.IntImm))
+                        and int(shape_val) == int(new_shape_val)
+                    ):
+                        return False
+            elif int(new_shape[0]) > 0:
+                if not (
+                    isinstance(shape[0], (int, tvm.tir.expr.IntImm))
+                    and isinstance(new_shape[0], (int, tvm.tir.expr.IntImm))
+                    and int(shape[0]) == int(new_shape[0])
+                ):
+                    return False
+            return True
+        shape = list(map(int, shape))
+        new_shape = list(map(int, new_shape))
+
         # TRT cannot modify batch dimension.
         original_volume = np.prod(shape)
         # First, resolve 0.
@@ -607,6 +653,7 @@ def reshape_annotate_fn(expr):  # pylint: disable=unused-variable
         for i, value in enumerate(new_shape):
             if value == -1:
                 new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
+        # Remove batch dimension and see if volumes match
         if shape[0] != new_shape[0]:
             logger.info("reshape: can't modify batch dimension.")
             return False
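
Taken together, the dynamic branch enforces "reshape may never touch the batch dimension" without knowing the batch value. Illustrative cases, assuming the logic above:

# Input (Any, 7, 7, 512) -> newshape (-1, 7, 7, 512):
#   accepted - batch inferred (-1), all non-batch dims identical.
# Input (Any, 7, 7, 512) -> newshape (-1, 25088):
#   rejected - batch inferred but the remaining dims change.
# Input (Any, 7, 7, 512) -> newshape (32, 25088):
#   rejected - an explicit batch can't be validated against a dynamic one.
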
@@ -795,31 +842,73 @@ def conv3d_transpose_annotate_fn(expr):  # pylint: disable=unused-variable
     return True


+class IsComputeIntensiveGraph(ExprVisitor):
+    """
+    Visits the graph recursively and checks if it contains compute-heavy ops such as
+    convolutions (and their transposes), dense, and batch matmul.
+    """
+
+    def __init__(self):
+        ExprVisitor.__init__(self)
+        self.is_compute_intensive = False
+
+    def visit_call(self, call):
+        compute_intensive_ops = set(
+            [
+                "nn.conv2d",
+                "nn.conv2d_transpose",
+                "nn.conv3d",
+                "nn.conv3d_transpose",
+                "nn.dense",
+                "nn.batch_matmul",
+            ]
+        )
+        if isinstance(call.op, tvm.tir.op.Op):
+            if str(call.op) in compute_intensive_ops:
+                self.is_compute_intensive = True
+
+        return super().visit_call(call)
+
+    def is_graph_compute_intensive(self, subgraph) -> bool:
+        """
+        Recursively visits the graph and checks if it is compute-intensive.
+        """
+        self.visit(subgraph)
+        return self.is_compute_intensive
+
+
 def is_valid_subgraph(params, body):
     """Final check on whether the subgraph is valid and should be offloaded to TensorRT."""
     # Remove invalid subgraphs for implicit batch mode.
     if get_tensorrt_use_implicit_batch_mode():
         input_batch_sizes = []
         for var in params:
             # In implicit batch mode, all inputs must have same batch size
+            # TODO(codeislife99): Fix different dynamic batch size inputs
+
             if isinstance(var.checked_type, relay.TupleType):
                 for tupe_type in var.checked_type.fields:
                     # Scalar inputs not allowed
                     if len(tupe_type.shape) == 0:
                         logger.info("tensorrt: scalar inputs not supported")
                         return False
-                    input_batch_sizes.append(int(tupe_type.shape[0]))
+
+                    if not isinstance(tupe_type.shape[0], tvm.tir.expr.Any):
+                        input_batch_sizes.append(int(tupe_type.shape[0]))
             else:
                 # Scalar inputs not allowed
                 if len(var.checked_type.shape) == 0:
                     logger.info("tensorrt: scalar inputs not supported")
                     return False
-                input_batch_sizes.append(int(var.checked_type.shape[0]))
+                if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any):
+                    input_batch_sizes.append(int(var.checked_type.shape[0]))
         if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
             logger.info("tensorrt: inputs have different batch sizes")
             return False
-    # Remove subgraphs with no multiply-accumulates
-    if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0:
+    if (
+        get_tensorrt_remove_no_mac_subgraphs()
+        and not IsComputeIntensiveGraph().is_graph_compute_intensive(body)
+    ):
         return False
     return True
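
The structural check replaces the old MAC-count heuristic, presumably because MAC counts are not well defined once shapes contain Any. A quick sketch of exercising the new visitor on its own (hypothetical snippet, not from the commit):

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 3, 224, 224))
w = relay.var("w", shape=(16, 3, 3, 3))
f = relay.Function([x, w], relay.nn.conv2d(x, w))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(f))

# Contains nn.conv2d, so the subgraph survives remove_no_mac_subgraphs.
assert IsComputeIntensiveGraph().is_graph_compute_intensive(mod["main"].body)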

@@ -880,6 +969,8 @@ class RemoveDropout(ExprMutator):

     def visit_tuple_getitem(self, op):
         visit = super().visit_tuple_getitem(op)
+        if visit.index != 0:
+            return visit
         if (
             isinstance(visit.tuple_value, Call)
             and visit.tuple_value.op.name == "nn.dropout"
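
nn.dropout yields a 2-tuple (output, mask) and acts as an identity on the output at inference time; the new guard rewrites only TupleGetItem(..., 0) to the dropout's input and leaves other indices alone. A hypothetical before/after:

from tvm import relay

x = relay.var("x", shape=(1, 8))
y = relay.nn.dropout(x, rate=0.5)  # sugar for TupleGetItem(nn.dropout(x), 0)
f = relay.Function([x], y)

f = RemoveDropout().visit(f)       # body now reduces to just x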

src/relay/backend/utils.h

Lines changed: 1 addition & 2 deletions
@@ -160,8 +160,7 @@ inline std::vector<int64_t> GetIntShape(const Array<IndexExpr>& shape) {
   std::vector<int64_t> ret;
   for (const auto& dim : shape) {
     const int64_t* pval = tir::as_const_int(dim);
-    ICHECK(pval) << "Expect integer, but received: " << dim->GetTypeKey();
-    ret.push_back(*pval);
+    ret.push_back(pval ? *pval : -1);
   }
   return ret;
 }
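
GetIntShape now reports dynamic dimensions as -1 instead of aborting on the ICHECK, so downstream consumers (such as the TRT codegen) can treat -1 as "unknown until runtime". A rough Python analogue of the new behavior, for illustration only:

import tvm

def get_int_shape(shape):
    # Constant dims become plain ints; symbolic dims (e.g. Any) become -1.
    return [int(d) if isinstance(d, tvm.tir.IntImm) else -1 for d in shape]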

src/runtime/contrib/tensorrt/tensorrt_runtime.cc

Lines changed: 28 additions & 12 deletions
@@ -41,6 +41,13 @@ namespace tvm {
 namespace runtime {
 namespace contrib {

+struct PairHash {
+  template <class T1, class T2>
+  std::size_t operator()(const std::pair<T1, T2>& pair) const {
+    return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
+  }
+};
+
 using namespace tvm::runtime::json;

 class TensorRTRuntime : public JSONRuntimeBase {
@@ -105,12 +112,13 @@ class TensorRTRuntime : public JSONRuntimeBase {
   /*! \brief Run inference using built engine. */
   void Run() override {
     BuildEngine();
-    auto& engine_and_context = trt_engine_cache_.at(symbol_name_);
+    batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
+    if (batch_size_ == 0) return;
+    auto& engine_and_context = trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_));
     auto engine = engine_and_context.engine;
     auto context = engine_and_context.context;
     auto& device_buffers = engine_and_context.device_buffers;
     std::vector<void*> bindings(engine->getNbBindings(), nullptr);
-
     for (size_t i = 0; i < input_nodes_.size(); ++i) {
       auto nid = input_nodes_[i];
       if (nodes_[nid].GetOpType() == "input") {
@@ -169,10 +177,11 @@ class TensorRTRuntime : public JSONRuntimeBase {
    * do nothing.
    */
   void BuildEngine() {
-    if (trt_engine_cache_.count(symbol_name_)) return;
-    DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_;
+    batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
+    if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return;
+    DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
+               << " with batch size " << batch_size_;
     const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
-    batch_size_ = GetBatchSize();
     TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_,
                             use_fp16, batch_size_);

@@ -203,8 +212,9 @@ class TensorRTRuntime : public JSONRuntimeBase {
     }

     // Build engine.
-    trt_engine_cache_[symbol_name_] = builder.BuildEngine();
-    DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_;
+    trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)] = builder.BuildEngine();
+    DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_
+               << " with batch size " << batch_size_;
     CacheEngineToDisk();
   }

@@ -240,30 +250,35 @@ class TensorRTRuntime : public JSONRuntimeBase {
     helper.DeclareField("inputs", &engine_and_context.inputs);
     helper.DeclareField("outputs", &engine_and_context.outputs);
     helper.ReadAllFields(&reader);
-    trt_engine_cache_[symbol_name_] = engine_and_context;
+    const int batch_size = 1;
+    trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context;
     return true;
   }

   /*! \brief If TVM_TENSORRT_CACHE_DIR is set, will save the engine to that
    * directory so it can be loaded later.
    */
   void CacheEngineToDisk() {
+    batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
     std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string(""));
     if (cache_dir.empty()) return;
     std::string key = GetSubgraphKey();
     std::string path = cache_dir + "/" + key + ".plan";
     DLOG(INFO) << "Caching TensorRT engine to " << path;
     // Serialize engine to disk
-    nvinfer1::IHostMemory* serialized_engine = trt_engine_cache_[symbol_name_].engine->serialize();
+    nvinfer1::IHostMemory* serialized_engine =
+        trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].engine->serialize();
     SaveBinaryToFile(path, std::string(static_cast<const char*>(serialized_engine->data()),
                                        serialized_engine->size()));
     serialized_engine->destroy();
     // Serialize metadata
     std::ostringstream os;
     dmlc::JSONWriter writer(&os);
     writer.BeginObject();
-    writer.WriteObjectKeyValue("inputs", trt_engine_cache_[symbol_name_].inputs);
-    writer.WriteObjectKeyValue("outputs", trt_engine_cache_[symbol_name_].outputs);
+    writer.WriteObjectKeyValue("inputs",
+                               trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].inputs);
+    writer.WriteObjectKeyValue(
+        "outputs", trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].outputs);
     writer.EndObject();
     std::string meta_path = cache_dir + "/" + key + ".meta";
     SaveBinaryToFile(meta_path, os.str());
@@ -290,7 +305,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
   }

   /*! \brief Map of function name to TRT engine if built already. */
-  std::unordered_map<std::string, TensorRTEngineAndContext> trt_engine_cache_;
+  std::unordered_map<std::pair<std::string, int>, TensorRTEngineAndContext, PairHash>
+      trt_engine_cache_;

   /*! \brief TensorRT logger. */
   TensorRTLogger logger_;
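
The net effect on the runtime: the engine cache is keyed by (subgraph symbol, batch size), so the first Run() at a new batch size lazily builds an engine, later runs at that size reuse it, and a zero-size batch skips inference entirely. A minimal Python analogue of the scheme (hypothetical names; the C++ side needs the explicit PairHash above because std::unordered_map has no default hash for std::pair):

engine_cache = {}  # (symbol_name, batch_size) -> engine

def run(symbol_name, inputs, build_engine):
    batch_size = inputs[0].shape[0]
    if batch_size == 0:
        return None  # nothing to execute for an empty batch
    key = (symbol_name, batch_size)
    if key not in engine_cache:
        engine_cache[key] = build_engine(batch_size)  # lazy per-batch-size build
    return engine_cache[key].execute(inputs)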
