Skip to content

Commit 7a018c5

Browse files
authored
【Hackathon 9th No.99】[Feature Enhancement]Add tensorrt backend for compiler (#237)
* add tensorrt for compiler * fix bug * fix code * Update test_compiler.py
1 parent ae2af30 commit 7a018c5

File tree

1 file changed

+68
-25
lines changed

1 file changed

+68
-25
lines changed

graph_net/torch/test_compiler.py

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,56 @@
1313
from contextlib import contextmanager
1414
import time
1515

16+
try:
17+
import torch_tensorrt
18+
except ImportError:
19+
torch_tensorrt = None
1620

17-
def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
21+
22+
class GraphCompilerBackend:
23+
def __call__(self, model):
24+
raise NotImplementedError()
25+
26+
def synchronize(self):
27+
raise NotImplementedError()
28+
29+
30+
class InductorBackend(GraphCompilerBackend):
31+
def __call__(self, model):
32+
return torch.compile(model, backend="inductor")
33+
34+
def synchronize(self):
35+
torch.cuda.synchronize()
36+
37+
38+
class TensorRTBackend(GraphCompilerBackend):
39+
def __call__(self, model):
40+
return torch.compile(model, backend="tensorrt")
41+
42+
def synchronize(self):
43+
torch.cuda.synchronize()
44+
45+
46+
registry_backend = {
47+
"inductor": InductorBackend(),
48+
"tensorrt": TensorRTBackend(),
49+
"default": InductorBackend(),
50+
}
51+
52+
53+
def load_class_from_file(
54+
args: argparse.Namespace, class_name: str
55+
) -> Type[torch.nn.Module]:
56+
file_path = f"{args.model_path}/model.py"
1857
file = Path(file_path).resolve()
1958
module_name = file.stem
2059

2160
with open(file_path, "r", encoding="utf-8") as f:
2261
original_code = f.read()
23-
import_stmt = "import torch"
24-
modified_code = f"{import_stmt}\n{original_code}"
62+
if args.device == "cuda":
63+
modified_code = original_code.replace("cpu", "cuda")
64+
else:
65+
modified_code = original_code
2566
spec = importlib.util.spec_from_loader(module_name, loader=None)
2667
module = importlib.util.module_from_spec(spec)
2768
sys.modules[module_name] = module
@@ -32,27 +73,23 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Modul
3273
return model_class
3374

3475

35-
def get_compiler(args):
36-
assert args.compiler == "default"
37-
return torch.compile
38-
39-
40-
def get_synchronizer_func(args):
41-
assert args.compiler == "default"
42-
return torch.cuda.synchronize
76+
def get_compiler_backend(args) -> GraphCompilerBackend:
77+
assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
78+
return registry_backend[args.compiler]
4379

4480

4581
def get_model(args):
46-
model_class = load_class_from_file(
47-
f"{args.model_path}/model.py", class_name="GraphModule"
48-
)
49-
return model_class()
82+
model_class = load_class_from_file(args, class_name="GraphModule")
83+
return model_class().to(torch.device(args.device))
5084

5185

5286
def get_input_dict(args):
5387
inputs_params = utils.load_converted_from_text(f"{args.model_path}")
5488
params = inputs_params["weight_info"]
55-
return {k: utils.replay_tensor(v) for k, v in params.items()}
89+
return {
90+
k: utils.replay_tensor(v).to(torch.device(args.device))
91+
for k, v in params.items()
92+
}
5693

5794

5895
@dataclass
@@ -71,15 +108,14 @@ def naive_timer(duration_box, get_synchronizer_func):
71108

72109

73110
def test_single_model(args):
74-
compiler = get_compiler(args)
75-
synchronizer_func = get_synchronizer_func(args)
111+
compiler = get_compiler_backend(args)
76112
input_dict = get_input_dict(args)
77113
model = get_model(args)
78114
compiled_model = compiler(model)
79115

80116
# eager
81117
eager_duration_box = DurationBox(-1)
82-
with naive_timer(eager_duration_box, synchronizer_func):
118+
with naive_timer(eager_duration_box, compiler.synchronize):
83119
expected_out = model(**input_dict)
84120

85121
# warmup
@@ -88,7 +124,7 @@ def test_single_model(args):
88124

89125
# compiled
90126
compiled_duration_box = DurationBox(-1)
91-
with naive_timer(compiled_duration_box, synchronizer_func):
127+
with naive_timer(compiled_duration_box, compiler.synchronize):
92128
compiled_out = compiled_model(**input_dict)
93129

94130
def print_cmp(key, func, **kwargs):
@@ -157,11 +193,11 @@ def test_multi_models(args):
157193
cmd = "".join(
158194
[
159195
sys.executable,
160-
"-m graph_net.torch.test_compiler",
161-
f"--model-path {model_path}",
162-
f"--compiler {args.compiler}",
163-
f"--warmup {args.warmup}",
164-
f"--log-prompt {args.log_prompt}",
196+
" -m graph_net.torch.test_compiler",
197+
f" --model-path {model_path}",
198+
f" --compiler {args.compiler}",
199+
f" --warmup {args.warmup}",
200+
f" --log-prompt {args.log_prompt}",
165201
]
166202
)
167203
cmd_ret = os.system(cmd)
@@ -212,6 +248,13 @@ def main(args):
212248
default="default",
213249
help="Path to customized compiler python file",
214250
)
251+
parser.add_argument(
252+
"--device",
253+
type=str,
254+
required=False,
255+
default="cpu",
256+
help="Device for testing the compiler",
257+
)
215258
parser.add_argument(
216259
"--warmup", type=int, required=False, default=5, help="Number of warmup steps"
217260
)

0 commit comments

Comments
 (0)