@@ -310,21 +310,16 @@ def forward(self, x, y, z, a, b):
310310 {0 : seq_len3 },
311311 ),
312312 )
313- with torchtrt .dynamo .Debugger (
314- log_level = "debug" ,
315- capture_fx_graph_after = ["remove_num_users_is_0_nodes" ],
316- logging_dir = "/home/profile/logging/moe" ,
317- engine_builder_monitor = False ,
318- ):
319- trt_mod = torchtrt .dynamo .compile (
320- ep ,
321- inputs ,
322- enabled_precisions = {torch .float16 },
323- min_block_size = 1 ,
324- use_explicit_typing = False ,
325- use_fp32_acc = False ,
326- disable_tf32 = True ,
327- )
313+
314+ trt_mod = torchtrt .dynamo .compile (
315+ ep ,
316+ inputs ,
317+ enabled_precisions = {torch .float16 },
318+ min_block_size = 1 ,
319+ use_explicit_typing = False ,
320+ use_fp32_acc = False ,
321+ disable_tf32 = True ,
322+ )
328323 result = trt_mod (* inputs )
329324 assert torch .allclose (result , torch_output , atol = 1e-4 , rtol = 1e-4 )
330325
@@ -350,17 +345,16 @@ def forward(self, source_tensor, indices_tensor, value_tensor):
350345 (source_tensor , indices_tensor , value_tensor ),
351346 dynamic_shapes = ({0 : dim1 }, {0 : dim1 }, {0 : dim2 }),
352347 )
353- with torchtrt .dynamo .Debugger (log_level = "debug" ):
354- trt_engine = torchtrt .dynamo .compile (
355- ep ,
356- inputs = (source_tensor , indices_tensor , value_tensor ),
357- enabled_precisions = {torch .float32 },
358- min_block_size = 1 ,
359- use_explicit_typing = False ,
360- use_fp32_acc = False ,
361- disable_tf32 = True ,
362- use_python_runtime = True ,
363- )
348+ trt_engine = torchtrt .dynamo .compile (
349+ ep ,
350+ inputs = (source_tensor , indices_tensor , value_tensor ),
351+ enabled_precisions = {torch .float32 },
352+ min_block_size = 1 ,
353+ use_explicit_typing = False ,
354+ use_fp32_acc = False ,
355+ disable_tf32 = True ,
356+ use_python_runtime = True ,
357+ )
364358 result = trt_engine (source_tensor , indices_tensor , value_tensor )
365359
366360 torch .allclose (result , torch_output , atol = 1e-4 , rtol = 1e-4 )
0 commit comments