feat: Support exporting Torch-TRT compiled GraphModules #3262

Open · wants to merge 12 commits into base: lluo/save_remove_inputs
17 changes: 17 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -318,6 +318,23 @@ void TRTEngine::verify_serialization_fmt(const std::vector<std::string>& seriali
<< ")");
}

FlattenedState TRTEngine::__obj_flatten__() {
Collaborator review comment: I'm not sure why this needs to be a method vs. reusing the functions in register_jit_hooks.

// Serialize TensorRT engine
auto serialized_trt_engine = make_trt(this->cuda_engine->serialize());
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

return std::tuple(
std::tuple("version", ABI_VERSION),
std::tuple("name", this->name),
std::tuple("device_info", this->device_info.serialize()),
std::tuple("serialized_engine", base64_encode(trt_engine)),
std::tuple("in_binding_names", this->in_binding_names),
std::tuple("out_binding_names", this->out_binding_names),
std::tuple("hardware_compatible", this->hardware_compatible),
std::tuple("serialized_metadata", this->serialized_metadata),
std::tuple("target_platform", this->target_platform.serialize()));
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
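
For orientation, here is a minimal Python-side sketch of what this flatten hook produces once the class is registered. It assumes `trt_module` is a `TorchTensorRTModule` produced by the dynamo compile path with the C++ runtime; the field order mirrors the `FlattenedState` tuple declared in TRTEngine.h below.

```python
# Sketch only: `trt_module` is assumed to be a TorchTensorRTModule created by
# torch_tensorrt.dynamo.compile with the C++ runtime (use_python_runtime=False).
trt_module.setup_engine()                   # materializes torch.classes.tensorrt.Engine
flat = trt_module.engine.__obj_flatten__()  # tuple of (field_name, value) pairs

# Expected fields, in order: version, name, device_info, serialized_engine (base64),
# in_binding_names, out_binding_names, hardware_compatible, serialized_metadata,
# target_platform
for field_name, value in flat:
    print(field_name, type(value))
```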
14 changes: 14 additions & 0 deletions core/runtime/TRTEngine.h
@@ -19,6 +19,17 @@ namespace torch_tensorrt {
namespace core {
namespace runtime {

using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // ABI_VERSION
std::tuple<std::string, std::string>, // name
std::tuple<std::string, std::string>, // device
std::tuple<std::string, std::string>, // engine
std::tuple<std::string, std::vector<std::string>>, // input binding names
std::tuple<std::string, std::vector<std::string>>, // output binding names
std::tuple<std::string, bool>, // HW compatibility
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
std::shared_ptr<nvinfer1::IRuntime> rt;
@@ -74,6 +85,9 @@ struct TRTEngine : torch::CustomClassHolder {
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';

// Serde re-export functionality
FlattenedState __obj_flatten__();

// CUDAGraph-Related Functionality
at::cuda::CUDAGraph cudagraph = {};
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
3 changes: 2 additions & 1 deletion core/runtime/register_jit_hooks.cpp
@@ -7,7 +7,6 @@
namespace torch_tensorrt {
namespace core {
namespace runtime {
namespace {

std::string serialize_bindings(const std::vector<std::string>& bindings) {
std::stringstream ss;
@@ -66,6 +65,7 @@ std::string base64_decode(const std::string& in) {
return out;
}

namespace {
// TODO: Implement a call method
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
// auto input_vec = inputs.vec();
@@ -80,6 +80,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
// TODO: .def("run", &TRTEngine::Run)
.def("__str__", &TRTEngine::to_str)
.def("__repr__", &TRTEngine::to_str)
.def("__obj_flatten__", &TRTEngine::__obj_flatten__)
.def("enable_profiling", &TRTEngine::enable_profiling)
.def("disable_profiling", &TRTEngine::disable_profiling)
.def_readwrite("profile_path_prefix", &TRTEngine::profile_path_prefix)
3 changes: 3 additions & 0 deletions core/runtime/runtime.h
@@ -33,6 +33,9 @@ typedef enum {
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
} SerializedInfoIndex;

std::string base64_encode(const std::string& in);
std::string base64_decode(const std::string& in);

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
const RTDevice& curr_device = RTDevice(),
12 changes: 6 additions & 6 deletions py/torch_tensorrt/_compile.py
@@ -528,10 +528,10 @@ def save(
exp_program = export(module, arg_inputs, kwarg_inputs)
torch.export.save(exp_program, file_path)
else:
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
# from torch._higher_order_ops.torchbind import enable_torchbind_tracing

with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
torch.export.save(exp_program, file_path)
# with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
torch.export.save(exp_program, file_path)
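
A hedged end-to-end sketch of the export flow this change targets. The toy model, shapes, and the `ir="dynamo"` / `inputs=` arguments are illustrative assumptions rather than taken from this diff; only the `torch_tensorrt.save` call and the `torch.export` round-trip correspond to the code above.

```python
import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(1, 3, 32, 32).cuda()]

# Compile to a GraphModule whose TRT engines are TorchBind objects (illustrative args)
trt_gm = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Export and save without the enable_torchbind_tracing() context, per the change above
torch_tensorrt.save(trt_gm, "trt.ep", inputs=inputs)

# Reload the ExportedProgram and run it
reloaded = torch.export.load("trt.ep").module()
out = reloaded(*inputs)
```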
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -180,8 +180,10 @@ def setup_engine(self) -> None:
"""
if self.engine is not None:
return

self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())

@staticmethod
def encode_metadata(metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
dumped_metadata = pickle.dumps(metadata)
@@ -270,7 +272,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
(i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
list(input_tensors), self.engine
)
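
For the metadata helper touched above, a small self-contained round-trip sketch of the pickle-plus-base64 scheme it appears to use; the free functions here are written out for illustration and are not the module's actual methods.

```python
import base64
import pickle
from typing import Any

def encode_metadata(metadata: Any) -> str:
    # pickle the metadata, then base64-encode so it can travel as a plain string
    return base64.b64encode(pickle.dumps(metadata)).decode("utf-8")

def decode_metadata(encoded: str) -> Any:
    # inverse of encode_metadata
    return pickle.loads(base64.b64decode(encoded.encode("utf-8")))

meta = {"input_shapes": [(1, 3, 32, 32)], "hardware_compatible": False}
assert decode_metadata(encode_metadata(meta)) == meta
```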
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/runtime/__init__.py
@@ -4,3 +4,4 @@
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: F401
TorchTensorRTModule,
)
from torch_tensorrt.dynamo.runtime.register_fake_class import FakeTRTEngine
27 changes: 27 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/register_fake_class.py
@@ -0,0 +1,27 @@
import base64
from typing import Any

import torch


@torch.library.register_fake("tensorrt::execute_engine")
def fake_execute_engine(inputs, trt_engine):
return trt_engine(inputs)


# namespace::class_name
@torch._library.register_fake_class("tensorrt::Engine")
class FakeTRTEngine:
Collaborator review comment: Could we move the TorchTensorRTModule impl here?

def __init__(self) -> None:
pass

@classmethod
def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
engine_info = [info[1] for info in flattened_tq]
engine_info[3] = base64.b64decode(engine_info[3]) # decode engine
engine_info[4] = str(engine_info[4][0]) # input names
engine_info[5] = str(engine_info[5][0]) # output names
engine_info[6] = str(int(engine_info[6])) # hw compatible
trt_engine = torch.classes.tensorrt.Engine(engine_info)
return trt_engine
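
For context, a sketch of the payload that `__obj_unflatten__` receives (the (name, value) pairs emitted by `TRTEngine::__obj_flatten__`), using placeholder values; it mirrors the name-stripping and decoding above without constructing a real engine.

```python
import base64

# Placeholder payload in FlattenedState order (all values are illustrative only)
flattened_tq = [
    ("version", "6"),
    ("name", "_run_on_acc_0_engine"),
    ("device_info", "<serialized device info>"),
    ("serialized_engine", base64.b64encode(b"<raw TensorRT engine bytes>").decode()),
    ("in_binding_names", ["input_0"]),
    ("out_binding_names", ["output_0"]),
    ("hardware_compatible", False),
    ("serialized_metadata", ""),
    ("target_platform", "<serialized platform>"),
]

# Mirror of the unflatten logic: drop the field names, decode the engine payload,
# and stringify the binding names and the HW-compat flag; the real code then passes
# engine_info to torch.classes.tensorrt.Engine(...), omitted here because the
# payload is fake.
engine_info = [value for _, value in flattened_tq]
engine_info[3] = base64.b64decode(engine_info[3])
engine_info[4] = str(engine_info[4][0])
engine_info[5] = str(engine_info[5][0])
engine_info[6] = str(int(engine_info[6]))
```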