 from contextlib import contextmanager
 import time
 
+try:
+    import torch_tensorrt
+except ImportError:
+    torch_tensorrt = None
 
-def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
+
+class GraphCompilerBackend:
+    def __call__(self, model):
+        raise NotImplementedError()
+
+    def synchronize(self):
+        raise NotImplementedError()
+
+
+class InductorBackend(GraphCompilerBackend):
+    def __call__(self, model):
+        return torch.compile(model, backend="inductor")
+
+    def synchronize(self):
+        torch.cuda.synchronize()
+
+
+class TensorRTBackend(GraphCompilerBackend):
+    def __call__(self, model):
+        return torch.compile(model, backend="tensorrt")
+
+    def synchronize(self):
+        torch.cuda.synchronize()
+
+
+registry_backend = {
+    "inductor": InductorBackend(),
+    "tensorrt": TensorRTBackend(),
+    "default": InductorBackend(),
+}
+
+
+def load_class_from_file(
+    args: argparse.Namespace, class_name: str
+) -> Type[torch.nn.Module]:
+    file_path = f"{args.model_path}/model.py"
     file = Path(file_path).resolve()
     module_name = file.stem
 
     with open(file_path, "r", encoding="utf-8") as f:
         original_code = f.read()
-    import_stmt = "import torch"
-    modified_code = f"{import_stmt}\n{original_code}"
+    if args.device == "cuda":
+        modified_code = original_code.replace("cpu", "cuda")
+    else:
+        modified_code = original_code
     spec = importlib.util.spec_from_loader(module_name, loader=None)
     module = importlib.util.module_from_spec(spec)
     sys.modules[module_name] = module
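Note: the "import torch_tensorrt" above is needed only for its side effect of registering the "tensorrt" backend with torch.compile; the ImportError fallback keeps the script importable on machines without TensorRT installed. The registry also makes new backends a small, local addition. A minimal sketch of such an extension (CudaGraphsBackend is hypothetical and not part of this diff; "cudagraphs" is one of torch.compile's built-in backends):

    # Hypothetical extension, not part of this diff.
    class CudaGraphsBackend(GraphCompilerBackend):
        def __call__(self, model):
            return torch.compile(model, backend="cudagraphs")

        def synchronize(self):
            torch.cuda.synchronize()

    registry_backend["cudagraphs"] = CudaGraphsBackend()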
@@ -32,27 +73,23 @@ def load_class_from_file(file_path: str, class_name: str) -> Type[torch.nn.Module]:
     return model_class
 
 
-def get_compiler(args):
-    assert args.compiler == "default"
-    return torch.compile
-
-
-def get_synchronizer_func(args):
-    assert args.compiler == "default"
-    return torch.cuda.synchronize
+def get_compiler_backend(args) -> GraphCompilerBackend:
+    assert args.compiler in registry_backend, f"Unknown compiler: {args.compiler}"
+    return registry_backend[args.compiler]
 
 
 def get_model(args):
-    model_class = load_class_from_file(
-        f"{args.model_path}/model.py", class_name="GraphModule"
-    )
-    return model_class()
+    model_class = load_class_from_file(args, class_name="GraphModule")
+    return model_class().to(torch.device(args.device))
 
 
 def get_input_dict(args):
     inputs_params = utils.load_converted_from_text(f"{args.model_path}")
     params = inputs_params["weight_info"]
-    return {k: utils.replay_tensor(v) for k, v in params.items()}
+    return {
+        k: utils.replay_tensor(v).to(torch.device(args.device))
+        for k, v in params.items()
+    }
 
 
 @dataclass
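Note: after this hunk, get_model moves the loaded GraphModule onto args.device and get_input_dict moves every replayed tensor there as well, so model and inputs always agree on device. An illustrative end-to-end flow (the Namespace values below are placeholders, not from this diff):

    # Illustrative only; argument values are placeholders.
    args = argparse.Namespace(
        compiler="tensorrt", device="cuda", model_path="/path/to/model_dir"
    )
    backend = get_compiler_backend(args)  # TensorRTBackend from the registry
    model = get_model(args)               # GraphModule moved to args.device
    inputs = get_input_dict(args)         # replayed tensors on args.device
    compiled = backend(model)             # torch.compile(..., backend="tensorrt")
    out = compiled(**inputs)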
@@ -71,15 +108,14 @@ def naive_timer(duration_box, get_synchronizer_func):
 
 
 def test_single_model(args):
-    compiler = get_compiler(args)
-    synchronizer_func = get_synchronizer_func(args)
+    compiler = get_compiler_backend(args)
     input_dict = get_input_dict(args)
     model = get_model(args)
     compiled_model = compiler(model)
 
     # eager
     eager_duration_box = DurationBox(-1)
-    with naive_timer(eager_duration_box, synchronizer_func):
+    with naive_timer(eager_duration_box, compiler.synchronize):
         expected_out = model(**input_dict)
 
     # warmup
@@ -88,7 +124,7 @@ def test_single_model(args):
 
     # compiled
     compiled_duration_box = DurationBox(-1)
-    with naive_timer(compiled_duration_box, synchronizer_func):
+    with naive_timer(compiled_duration_box, compiler.synchronize):
         compiled_out = compiled_model(**input_dict)
 
     def print_cmp(key, func, **kwargs):
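Note: passing compiler.synchronize keeps the timing logic backend-agnostic. On CUDA, kernel launches are asynchronous, so a wall-clock timer that never synchronizes would mostly measure launch overhead rather than execution time. The body of naive_timer is outside this diff; it presumably synchronizes around the timed region, roughly like this sketch (parameter and field names here are assumptions):

    # Sketch only; the real naive_timer body is not shown in this diff.
    @contextmanager
    def naive_timer(duration_box, synchronize):
        synchronize()                 # drain work queued before the region
        start = time.time()
        yield
        synchronize()                 # wait for the timed work to finish
        duration_box.value = time.time() - start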
@@ -157,11 +193,11 @@ def test_multi_models(args):
     cmd = "".join(
         [
             sys.executable,
-            "-m graph_net.torch.test_compiler",
-            f"--model-path {model_path}",
-            f"--compiler {args.compiler}",
-            f"--warmup {args.warmup}",
-            f"--log-prompt {args.log_prompt}",
+            " -m graph_net.torch.test_compiler",
+            f" --model-path {model_path}",
+            f" --compiler {args.compiler}",
+            f" --warmup {args.warmup}",
+            f" --log-prompt {args.log_prompt}",
         ]
     )
     cmd_ret = os.system(cmd)
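Note: this hunk fixes a command-construction bug. "".join concatenates with no separator, so the old fragments produced one run-together string that os.system could not execute; prepending a space to every fragment after sys.executable restores valid shell syntax. Quick illustration ("python3" stands in for sys.executable):

    # "".join inserts no separator between fragments.
    "".join(["python3", "-m mod", "--flag 1"])    # 'python3-m mod--flag 1'  (broken)
    "".join(["python3", " -m mod", " --flag 1"])  # 'python3 -m mod --flag 1'

Joining unprefixed fragments with " ".join(...) would be an equivalent and arguably cleaner fix.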
@@ -212,6 +248,13 @@ def main(args):
         default="default",
         help="Path to customized compiler python file",
     )
+    parser.add_argument(
+        "--device",
+        type=str,
+        required=False,
+        default="cpu",
+        help="Device for testing the compiler",
+    )
     parser.add_argument(
         "--warmup", type=int, required=False, default=5, help="Number of warmup steps"
     )
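Note: with the new --device flag, a representative invocation after this change looks like the following (the model path is a placeholder):

    python -m graph_net.torch.test_compiler \
        --model-path /path/to/model_dir \
        --compiler tensorrt \
        --device cuda \
        --warmup 5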