Dynamic Batch Support for TRT #6955

Merged: 34 commits (Dec 1, 2020). Changes from 26 commits.
110 changes: 96 additions & 14 deletions python/tvm/relay/op/contrib/tensorrt.py
@@ -23,7 +23,7 @@
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from tvm.relay.expr import Call, Constant, Tuple, GlobalVar, Var, TupleGetItem
from tvm.relay.expr_functor import ExprMutator
from tvm.relay.expr_functor import ExprMutator, ExprVisitor

logger = logging.getLogger("TensorRT")

@@ -173,7 +173,7 @@ def check_dynamism(args, op_name):
"""
for arg in args:
if isinstance(arg, (Call, Var, Constant, TupleGetItem)):
for dim_shape in arg.checked_type.shape:
for dim_shape in arg.checked_type.shape[1:]:
if isinstance(dim_shape, tvm.tir.expr.Any):
return True
elif isinstance(arg, Tuple):
@@ -198,6 +198,18 @@ def _func_wrapper(expr):
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if op_name == "multiply":
shapes = [
[
int(x) if not isinstance(x, tvm.tir.expr.Any) else -1
for x in arg.checked_type.shape
]
for arg in args
]
if all(
[list(map(int, shape)) in [[300, 64, 7, 7], [300, 1, 1, 1]] for shape in shapes]
):
return False
return checker(attrs, args, op_name)

return _func_wrapper
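
Both this hunk and the ones below normalize argument shapes to plain Python ints, mapping dynamic (Any) dimensions to -1 so they can be compared directly. A minimal standalone sketch of that convention, outside this diff and with an illustrative helper name (assumes a regular TVM build):

import tvm
from tvm import relay

def normalized_shape(tensor_type):
    # Map a relay TensorType shape to ints, using -1 for dynamic (Any) dims.
    return [
        int(dim) if not isinstance(dim, tvm.tir.expr.Any) else -1
        for dim in tensor_type.shape
    ]

x = relay.var("x", shape=(relay.Any(), 64, 7, 7), dtype="float32")
print(normalized_shape(x.type_annotation))  # [-1, 64, 7, 7]
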
@@ -292,19 +304,26 @@ def add_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if add is supported by TensorRT."""

args = expr.args

shapes = [
[int(x) if not isinstance(x, tvm.tir.expr.Any) else -1 for x in arg.checked_type.shape]
for arg in args
]

# RelayVM + TRT doesn't support scalar addition yet.
for arg in args:
if not arg.checked_type.shape:
for shape in shapes:
if len(shape) < 1:
return False

if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if (
not get_tensorrt_use_implicit_batch_mode()
and (isinstance(args[0], Constant) or isinstance(args[1], Constant))
and args[0].checked_type.shape[0] == args[1].checked_type.shape[0]
and args[0].checked_type.shape[0] != 1
and (len(args[0].checked_type.shape) > 3 or len(args[1].checked_type.shape) > 3)
and shapes[0][0] == shapes[1][0]
and shapes[0][0] != 1
and (len(shapes[0]) > 3 or len(shapes[1]) > 3)
):
logger.info("add: bug in TRT with adding batched constants.")
return False
@@ -592,11 +611,35 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable
logger.info("reshape: new shape dims must be explicit.")
return False
if get_tensorrt_use_implicit_batch_mode():
shape = list(map(int, args[0].checked_type.shape))
new_shape = list(map(int, attrs.newshape))
shape = args[0].checked_type.shape
new_shape = attrs.newshape
if len(new_shape) == 0 or len(shape) == 0:
logger.info("reshape: Can't reshape to or from scalar.")
return False

dynamic_reshape = any([isinstance(x, tvm.tir.expr.Any) for x in shape])

if dynamic_reshape:
# Make sure that the batch dim is unmodified.
if int(new_shape[0]) < 0:
for shape_val, new_shape_val in zip(shape[1:], new_shape[1:]):
if not (
isinstance(shape_val, int)
and isinstance(new_shape_val, int)
and int(shape_val) == int(new_shape_val)
):
return False
elif int(new_shape[0]) > 0:
if not (
isinstance(shape[0], int)
and isinstance(new_shape[0], int)
and int(shape[0]) == int(new_shape[0])
):
return False
return True
shape = list(map(int, shape))
new_shape = list(map(int, new_shape))

# TRT cannot modify batch dimension.
original_volume = np.prod(shape)
# First, resolve 0.
@@ -607,8 +650,9 @@ def reshape_annotate_fn(expr): # pylint: disable=unused-variable
for i, value in enumerate(new_shape):
if value == -1:
new_shape[i] = original_volume // np.prod([x for x in new_shape if x != -1])
# Remove batch dimension and see if volumes match
if shape[0] != new_shape[0]:
logger.info("reshape: can't modify batch dimension.")
print("reshape: can't modify batch dimension.")
return False
return True
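
Taken together, the two reshape hunks above mean that with a dynamic input shape, the reshape is offloaded to TensorRT only when the batch dimension is provably unchanged. A rough pure-Python sketch of that decision, with a hypothetical helper name and the -1 convention for dynamic dims:

def reshape_keeps_batch_dim(shape, new_shape):
    # shape / new_shape are lists of ints with -1 marking a dynamic dimension.
    if new_shape[0] < 0:
        # Leading -1 in newshape: every remaining dim must match statically.
        return all(s >= 0 and s == n for s, n in zip(shape[1:], new_shape[1:]))
    # Explicit leading dim: it must equal a static input batch dim.
    return shape[0] >= 0 and shape[0] == new_shape[0]

print(reshape_keeps_batch_dim([-1, 64, 7, 7], [-1, 64, 7, 7]))  # True: batch dim untouched
print(reshape_keeps_batch_dim([-1, 64, 7, 7], [-1, 3136]))      # False: trailing dims change
print(reshape_keeps_batch_dim([-1, 64, 7, 7], [300, 3136]))     # False: batch dim is unknown
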

@@ -795,6 +839,38 @@ def conv3d_transpose_annotate_fn(expr): # pylint: disable=unused-variable
return True


class IsComputeIntensiveGraph(ExprVisitor):
"""
Visits the graph recursively and checks if it contains compute-heavy ops such as
convolutions (and their transposes), dense, and batch matmul.
"""

def __init__(self):
ExprVisitor.__init__(self)
self.is_compute_intensive = False

def visit_call(self, call):
heavy_ops = set(
[
"nn.conv2d",
"nn.conv2d_transpose",
"nn.conv3d",
"nn.conv3d_transpose",
"nn.dense",
"nn.batch_matmul",
]
)
if isinstance(call.op, tvm.tir.op.Op):
if str(call.op) in heavy_ops:
self.is_compute_intensive = True

return super().visit_call(call)

def is_graph_compute_intensive(self, subgraph):
"""Visit the graph and return whether it contains any compute-intensive ops."""
self.visit(subgraph)
return self.is_compute_intensive


def is_valid_subgraph(params, body):
"""Final check on whether the subgraph is valid and should be offloaded to TensorRT."""
# Remove invalid subgraphs for implicit batch mode.
@@ -808,18 +884,22 @@ def is_valid_subgraph(params, body):
if len(tupe_type.shape) == 0:
logger.info("tensorrt: scalar inputs not supported")
return False
input_batch_sizes.append(int(tupe_type.shape[0]))
if not isinstance(tupe_type.shape[0], tvm.tir.expr.Any):
input_batch_sizes.append(int(tupe_type.shape[0]))
else:
# Scalar inputs not allowed
if len(var.checked_type.shape) == 0:
logger.info("tensorrt: scalar inputs not supported")
return False
input_batch_sizes.append(int(var.checked_type.shape[0]))
if not isinstance(var.checked_type.shape[0], tvm.tir.expr.Any):
input_batch_sizes.append(int(var.checked_type.shape[0]))
if len(input_batch_sizes) > 1 and len(set(input_batch_sizes)) != 1:
logger.info("tensorrt: inputs have different batch sizes")
return False
# Remove subgraphs with no multiply-accumulates
if get_tensorrt_remove_no_mac_subgraphs() and relay.analysis.get_total_mac_number(body) == 0:
if (
get_tensorrt_remove_no_mac_subgraphs()
and not IsComputeIntensiveGraph().is_graph_compute_intensive(body)
):
return False
return True

@@ -880,6 +960,8 @@ class RemoveDropout(ExprMutator):

def visit_tuple_getitem(self, op):
visit = super().visit_tuple_getitem(op)
if visit.index != 0:
return visit
if (
isinstance(visit.tuple_value, Call)
and visit.tuple_value.op.name == "nn.dropout"
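
The net effect of the tensorrt.py changes is that a Relay graph whose batch dimension is relay.Any() can now be annotated and partitioned for TensorRT. A rough usage sketch, not taken from this PR, assuming a TVM build with the TensorRT integration and using the module's existing partition_for_tensorrt entry point:

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib import tensorrt

# Small conv + relu network with a dynamic batch dimension.
data = relay.var("data", shape=(relay.Any(), 3, 224, 224), dtype="float32")
weight = relay.var("weight", shape=(16, 3, 3, 3), dtype="float32")
out = relay.nn.relu(relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1)))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))

params = {"weight": np.random.uniform(-1, 1, (16, 3, 3, 3)).astype("float32")}

# Bind the params, annotate supported ops, and partition TensorRT subgraphs.
mod, config = tensorrt.partition_for_tensorrt(mod, params)
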
3 changes: 1 addition & 2 deletions src/relay/backend/utils.h
@@ -160,8 +160,7 @@ inline std::vector<int64_t> GetIntShape(const Array<IndexExpr>& shape) {
std::vector<int64_t> ret;
for (const auto& dim : shape) {
const int64_t* pval = tir::as_const_int(dim);
ICHECK(pval) << "Expect integer, but received: " << dim->GetTypeKey();
ret.push_back(*pval);
ret.push_back(pval ? *pval : -1);
}
return ret;
}
39 changes: 27 additions & 12 deletions src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -41,6 +41,13 @@ namespace tvm {
namespace runtime {
namespace contrib {

struct PairHash {
template <class T1, class T2>
std::size_t operator()(const std::pair<T1, T2>& pair) const {
return std::hash<T1>()(pair.first) ^ std::hash<T2>()(pair.second);
}
};

using namespace tvm::runtime::json;

class TensorRTRuntime : public JSONRuntimeBase {
@@ -105,12 +112,12 @@ class TensorRTRuntime : public JSONRuntimeBase {
/*! \brief Run inference using built engine. */
void Run() override {
BuildEngine();
auto& engine_and_context = trt_engine_cache_.at(symbol_name_);
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
auto& engine_and_context = trt_engine_cache_.at(std::make_pair(symbol_name_, batch_size_));
auto engine = engine_and_context.engine;
auto context = engine_and_context.context;
auto& device_buffers = engine_and_context.device_buffers;
std::vector<void*> bindings(engine->getNbBindings(), nullptr);

for (size_t i = 0; i < input_nodes_.size(); ++i) {
auto nid = input_nodes_[i];
if (nodes_[nid].GetOpType() == "input") {
@@ -169,10 +176,11 @@
* do nothing.
*/
void BuildEngine() {
if (trt_engine_cache_.count(symbol_name_)) return;
DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_;
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
if (trt_engine_cache_.count(std::make_pair(symbol_name_, batch_size_))) return;
DLOG(INFO) << "Building new TensorRT engine for subgraph " << symbol_name_
<< " with batch size " << batch_size_;
const bool use_fp16 = dmlc::GetEnv("TVM_TENSORRT_USE_FP16", false);
batch_size_ = GetBatchSize();
TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, use_implicit_batch_,
use_fp16, batch_size_);

@@ -203,8 +211,9 @@
}

// Build engine.
trt_engine_cache_[symbol_name_] = builder.BuildEngine();
DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_;
trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)] = builder.BuildEngine();
DLOG(INFO) << "Finished building TensorRT engine for subgraph " << symbol_name_
<< " with batch size " << batch_size_;
CacheEngineToDisk();
}

@@ -240,30 +249,35 @@
helper.DeclareField("inputs", &engine_and_context.inputs);
helper.DeclareField("outputs", &engine_and_context.outputs);
helper.ReadAllFields(&reader);
trt_engine_cache_[symbol_name_] = engine_and_context;
const int batch_size = 1;
trt_engine_cache_[std::make_pair(symbol_name_, batch_size)] = engine_and_context;
return true;
}

/*! \brief If TVM_TENSORRT_CACHE_DIR is set, will save the engine to that
* directory so it can be loaded later.
*/
void CacheEngineToDisk() {
batch_size_ = data_entry_[input_var_eid_[0]]->shape[0];
std::string cache_dir = dmlc::GetEnv("TVM_TENSORRT_CACHE_DIR", std::string(""));
if (cache_dir.empty()) return;
std::string key = GetSubgraphKey();
std::string path = cache_dir + "/" + key + ".plan";
DLOG(INFO) << "Caching TensorRT engine to " << path;
// Serialize engine to disk
nvinfer1::IHostMemory* serialized_engine = trt_engine_cache_[symbol_name_].engine->serialize();
nvinfer1::IHostMemory* serialized_engine =
trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].engine->serialize();
SaveBinaryToFile(path, std::string(static_cast<const char*>(serialized_engine->data()),
serialized_engine->size()));
serialized_engine->destroy();
// Serialize metadata
std::ostringstream os;
dmlc::JSONWriter writer(&os);
writer.BeginObject();
writer.WriteObjectKeyValue("inputs", trt_engine_cache_[symbol_name_].inputs);
writer.WriteObjectKeyValue("outputs", trt_engine_cache_[symbol_name_].outputs);
writer.WriteObjectKeyValue("inputs",
trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].inputs);
writer.WriteObjectKeyValue(
"outputs", trt_engine_cache_[std::make_pair(symbol_name_, batch_size_)].outputs);
writer.EndObject();
std::string meta_path = cache_dir + "/" + key + ".meta";
SaveBinaryToFile(meta_path, os.str());
@@ -290,7 +304,8 @@ class TensorRTRuntime : public JSONRuntimeBase {
}

/*! \brief Map of function name to TRT engine if built already. */
std::unordered_map<std::string, TensorRTEngineAndContext> trt_engine_cache_;
std::unordered_map<std::pair<std::string, int>, TensorRTEngineAndContext, PairHash>
trt_engine_cache_;

/*! \brief TensorRT logger. */
TensorRTLogger logger_;
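
On the runtime side, keying trt_engine_cache_ on (symbol_name, batch_size) means one compiled module can be fed different batch sizes: the first run at a new batch size builds an engine (and caches it to disk if TVM_TENSORRT_CACHE_DIR is set), and later runs at that size reuse it. A rough continuation of the earlier sketch, assuming a CUDA and TensorRT enabled build and the mod, params, and config from that snippet:

import numpy as np
import tvm
from tvm import relay

with tvm.transform.PassContext(opt_level=3, config={"relay.ext.tensorrt.options": config}):
    vm_exec = relay.vm.compile(mod, target="cuda")

vm = tvm.runtime.vm.VirtualMachine(vm_exec, tvm.gpu(0))
for batch_size in (1, 4, 4, 8):
    # The second run with batch size 4 reuses the engine built on its first run.
    x = np.random.uniform(-1, 1, (batch_size, 3, 224, 224)).astype("float32")
    result = vm.invoke("main", x)
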