Bug Description
When I try to compile roberta-base, I get the following error:
---------------------------------------------------------------------------
TorchRuntimeError Traceback (most recent call last)
Cell In[2], line 17
14 input_ids = torch.stack([torch.tensor(input) for input in input_ids])
15 attention_mask = torch.ones_like(input_ids)
---> 17 model = torch_tensorrt.compile(model, inputs = (input_ids, attention_mask))
File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/_compile.py:266, in compile(module, ir, inputs, arg_inputs, kwarg_inputs, enabled_precisions, **kwargs)
263 torchtrt_arg_inputs = prepare_inputs(arg_inputs)
264 torchtrt_kwarg_inputs = prepare_inputs(kwarg_inputs)
--> 266 exp_program = dynamo_trace(
267 module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
268 )
269 trt_graph_module = dynamo_compile(
270 exp_program,
271 arg_inputs=torchtrt_arg_inputs,
272 enabled_precisions=enabled_precisions_set,
273 **kwargs,
274 )
275 return trt_graph_module
File ~/dev/.venv/lib/python3.12/site-packages/torch_tensorrt/dynamo/_tracer.py:83, in trace(mod, inputs, arg_inputs, kwarg_inputs, **kwargs)
81 dynamic_shapes = get_dynamic_shapes_args(mod, arg_inputs)
82 dynamic_shapes.update(get_dynamic_shapes_kwargs(kwarg_inputs))
---> 83 exp_program = export(
84 mod,
85 tuple(torch_arg_inputs),
86 kwargs=torch_kwarg_inputs,
87 dynamic_shapes=dynamic_shapes,
88 )
90 return exp_program
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/__init__.py:270, in export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature)
264 if isinstance(mod, torch.jit.ScriptModule):
265 raise ValueError(
266 "Exporting a ScriptModule is not supported. "
267 "Maybe try converting your ScriptModule to an ExportedProgram "
268 "using `TS2EPConverter(mod, args, kwargs).convert()` instead."
269 )
--> 270 return _export(
271 mod,
272 args,
273 kwargs,
274 dynamic_shapes,
275 strict=strict,
276 preserve_module_call_signature=preserve_module_call_signature,
277 pre_dispatch=True,
278 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1017, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
1010 else:
1011 log_export_usage(
1012 event="export.error.unclassified",
1013 type=error_type,
1014 message=str(e),
1015 flags=_EXPORT_FLAGS,
1016 )
-> 1017 raise e
1018 finally:
1019 _EXPORT_FLAGS = None
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:990, in _log_export_wrapper.<locals>.wrapper(*args, **kwargs)
988 try:
989 start = time.time()
--> 990 ep = fn(*args, **kwargs)
991 end = time.time()
992 log_export_usage(
993 event="export.time",
994 metrics=end - start,
995 flags=_EXPORT_FLAGS,
996 **get_ep_stats(ep),
997 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/exported_program.py:114, in _disable_prexisiting_fake_mode.<locals>.wrapper(*args, **kwargs)
111 @functools.wraps(fn)
112 def wrapper(*args, **kwargs):
113 with unset_fake_temporarily():
--> 114 return fn(*args, **kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1880, in _export(mod, args, kwargs, dynamic_shapes, strict, preserve_module_call_signature, pre_dispatch, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)
1877 # Call the appropriate export function based on the strictness of tracing.
1878 export_func = _strict_export if strict else _non_strict_export
-> 1880 export_artifact = export_func( # type: ignore[operator]
1881 mod,
1882 args,
1883 kwargs,
1884 dynamic_shapes,
1885 preserve_module_call_signature,
1886 pre_dispatch,
1887 original_state_dict,
1888 original_in_spec,
1889 allow_complex_guards_as_runtime_asserts,
1890 _is_torch_jit_trace,
1891 )
1892 export_graph_signature: ExportGraphSignature = export_artifact.aten.sig
1894 forward_arg_names = (
1895 _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None
1896 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1224, in _strict_export(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace)
1211 def _strict_export(
1212 mod: torch.nn.Module,
1213 args: Tuple[Any, ...],
(...)
1221 _is_torch_jit_trace: bool,
1222 ) -> ExportArtifact:
1223 lower_to_aten = functools.partial(_export_to_aten_ir, pre_dispatch=pre_dispatch)
-> 1224 return _strict_export_lower_to_aten_ir(
1225 mod=mod,
1226 args=args,
1227 kwargs=kwargs,
1228 dynamic_shapes=dynamic_shapes,
1229 preserve_module_call_signature=preserve_module_call_signature,
1230 pre_dispatch=pre_dispatch,
1231 original_state_dict=original_state_dict,
1232 orig_in_spec=orig_in_spec,
1233 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
1234 _is_torch_jit_trace=_is_torch_jit_trace,
1235 lower_to_aten_callback=lower_to_aten,
1236 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:1252, in _strict_export_lower_to_aten_ir(mod, args, kwargs, dynamic_shapes, preserve_module_call_signature, pre_dispatch, original_state_dict, orig_in_spec, allow_complex_guards_as_runtime_asserts, _is_torch_jit_trace, lower_to_aten_callback)
1239 def _strict_export_lower_to_aten_ir(
1240 mod: torch.nn.Module,
1241 args: Tuple[Any, ...],
(...)
1250 lower_to_aten_callback: Callable,
1251 ) -> ExportArtifact:
-> 1252 gm_torch_level = _export_to_torch_ir(
1253 mod,
1254 args,
1255 kwargs,
1256 dynamic_shapes,
1257 preserve_module_call_signature=preserve_module_call_signature,
1258 restore_fqn=False, # don't need to restore because we will do it later
1259 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
1260 _log_export_usage=False,
1261 )
1263 # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo.
1264 (
1265 fake_args,
1266 fake_kwargs,
1267 dynamo_fake_mode,
1268 ) = _extract_fake_inputs(gm_torch_level, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/export/_trace.py:560, in _export_to_torch_ir(f, args, kwargs, dynamic_shapes, preserve_module_call_signature, disable_constraint_solver, allow_complex_guards_as_runtime_asserts, restore_fqn, _log_export_usage, same_signature)
556 module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
557 with _wrap_submodules(
558 f, preserve_module_call_signature, module_call_specs
559 ), _ignore_backend_decomps():
--> 560 gm_torch_level, _ = torch._dynamo.export(
561 f,
562 dynamic_shapes=transformed_dynamic_shapes, # type: ignore[arg-type]
563 tracing_mode="symbolic",
564 disable_constraint_solver=disable_constraint_solver,
565 # currently the following 2 flags are tied together for export purposes,
566 # but untangle for sake of dynamo export api
567 prefer_deferred_runtime_asserts_over_guards=True,
568 allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts,
569 _log_export_usage=_log_export_usage,
570 same_signature=same_signature,
571 )(
572 *args,
573 **kwargs,
574 )
575 except (ConstraintViolationError, ValueRangeError) as e:
576 raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:1432, in export.<locals>.inner(*args, **kwargs)
1430 # TODO(voz): We may have instances of `f` that mutate inputs, we should track sideeffects and reject.
1431 try:
-> 1432 result_traced = opt_f(*args, **kwargs)
1433 except ConstraintViolationError as e:
1434 constraint_violation_error = e
File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:465, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
460 saved_dynamic_layer_stack_depth = (
461 torch._C._functorch.get_dynamic_layer_stack_depth()
462 )
464 try:
--> 465 return fn(*args, **kwargs)
466 finally:
467 # Restore the dynamic layer stack depth if necessary.
468 torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
469 saved_dynamic_layer_stack_depth
470 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1269, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
1263 return hijacked_callback(
1264 frame, cache_entry, self.hooks, frame_state
1265 )
1267 with compile_lock, _disable_current_modes():
1268 # skip=1: skip this frame
-> 1269 return self._torchdynamo_orig_callable(
1270 frame, cache_entry, self.hooks, frame_state, skip=1
1271 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:526, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
510 compile_id = CompileId(frame_id, frame_compile_id)
512 signpost_event(
513 "dynamo",
514 "_convert_frame_assert._compile",
(...)
523 },
524 )
--> 526 return _compile(
527 frame.f_code,
528 frame.f_globals,
529 frame.f_locals,
530 frame.f_builtins,
531 self._torchdynamo_orig_callable,
532 self._one_graph,
533 self._export,
534 self._export_constraints,
535 hooks,
536 cache_entry,
537 cache_size,
538 frame,
539 frame_state=frame_state,
540 compile_id=compile_id,
541 skip=skip + 1,
542 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:924, in _compile(code, globals, locals, builtins, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
922 guarded_code = None
923 try:
--> 924 guarded_code = compile_inner(code, one_graph, hooks, transform)
925 return guarded_code
926 except Exception as e:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:666, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
664 with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"):
665 with CompileTimeInstructionCounter.record():
--> 666 return _compile_inner(code, one_graph, hooks, transform)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_utils_internal.py:87, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
84 kwargs["skip"] = kwargs["skip"] + 1
86 if not StrobelightCompileTimeProfiler.enabled:
---> 87 return function(*args, **kwargs)
89 return StrobelightCompileTimeProfiler.profile_compile_time(
90 function, phase_name, *args, **kwargs
91 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:699, in _compile.<locals>._compile_inner(code, one_graph, hooks, transform)
697 CompileContext.get().attempt = attempt
698 try:
--> 699 out_code = transform_code_object(code, transform)
700 break
701 except exc.RestartAnalysis as e:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1322, in transform_code_object(code, transformations, safe)
1319 instructions = cleaned_instructions(code, safe)
1320 propagate_line_nums(instructions)
-> 1322 transformations(instructions, code_options)
1323 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:219, in preserve_global_state.<locals>._fn(*args, **kwargs)
215 exit_stack.enter_context(
216 torch.fx._symbolic_trace._maybe_revert_all_patches()
217 )
218 try:
--> 219 return fn(*args, **kwargs)
220 finally:
221 cleanup.close()
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:634, in _compile.<locals>.transform(instructions, code_options)
632 try:
633 with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 634 tracer.run()
635 except exc.UnspecializeRestartAnalysis:
636 speculation_log.clear()
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2796, in InstructionTranslator.run(self)
2795 def run(self):
-> 2796 super().run()
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:983, in InstructionTranslatorBase.run(self)
981 try:
982 self.output.push_tx(self)
--> 983 while self.step():
984 pass
985 except BackendCompilerFailed:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:895, in InstructionTranslatorBase.step(self)
892 self.update_block_stack(inst)
894 try:
--> 895 self.dispatch_table[inst.opcode](self, inst)
896 return not self.output.should_exit
897 except exc.ObservedException as e:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:582, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
580 return handle_graph_break(self, inst, speculation.reason)
581 try:
--> 582 return inner_fn(self, inst)
583 except Unsupported as excp:
584 if self.generic_context_manager_depth > 0:
585 # We don't support graph break under GenericContextWrappingVariable,
586 # If there is, we roll back to the checkpoint and fall back.
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2279, in InstructionTranslatorBase.CALL(self, inst)
2277 @break_graph_if_unsupported(push=1)
2278 def CALL(self, inst):
-> 2279 self._call(inst)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2273, in InstructionTranslatorBase._call(self, inst, call_kw)
2268 kwargs = {}
2270 try:
2271 # if call_function fails, need to set kw_names to None, otherwise
2272 # a subsequent call may have self.kw_names set to an old value
-> 2273 self.call_function(fn, args, kwargs)
2274 finally:
2275 self.kw_names = None
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:830, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
829 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs))
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:442, in NNModuleVariable.call_function(self, tx, args, kwargs)
440 else:
441 assert istype(fn, types.FunctionType)
--> 442 return tx.inline_user_function_return(
443 variables.UserFunctionVariable(fn, source=fn_source),
444 args,
445 kwargs,
446 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:836, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
832 def inline_user_function_return(self, fn, args, kwargs):
833 """
834 A call to some user defined function by inlining it.
835 """
--> 836 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3011, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
3008 @classmethod
3009 def inline_call(cls, parent, func, args, kwargs):
3010 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011 return cls.inline_call_(parent, func, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3139, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
3137 try:
3138 with strict_ctx:
-> 3139 tracer.run()
3140 except exc.ObservedException as e:
3141 msg = f"Observed exception DURING INLING {code} : {e}"
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:983, in InstructionTranslatorBase.run(self)
981 try:
982 self.output.push_tx(self)
--> 983 while self.step():
984 pass
985 except BackendCompilerFailed:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:895, in InstructionTranslatorBase.step(self)
892 self.update_block_stack(inst)
894 try:
--> 895 self.dispatch_table[inst.opcode](self, inst)
896 return not self.output.should_exit
897 except exc.ObservedException as e:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:582, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
580 return handle_graph_break(self, inst, speculation.reason)
581 try:
--> 582 return inner_fn(self, inst)
583 except Unsupported as excp:
584 if self.generic_context_manager_depth > 0:
585 # We don't support graph break under GenericContextWrappingVariable,
586 # If there is, we roll back to the checkpoint and fall back.
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1680, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
1678 # Map to a dictionary of str -> VariableTracker
1679 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1680 self.call_function(fn, argsvars.items, kwargsvars)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:830, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
829 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs))
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:385, in UserMethodVariable.call_function(self, tx, args, kwargs)
383 fn = getattr(self.obj.value, self.fn.__name__)
384 return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
--> 385 return super().call_function(tx, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:324, in UserFunctionVariable.call_function(self, tx, args, kwargs)
319 if self.is_constant:
320 return invoke_and_store_as_constant(
321 tx, self.fn, self.get_name(), args, kwargs
322 )
--> 324 return super().call_function(tx, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:111, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
105 def call_function(
106 self,
107 tx: "InstructionTranslator",
108 args: "List[VariableTracker]",
109 kwargs: "Dict[str, VariableTracker]",
110 ) -> "VariableTracker":
--> 111 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:836, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
832 def inline_user_function_return(self, fn, args, kwargs):
833 """
834 A call to some user defined function by inlining it.
835 """
--> 836 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3011, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
3008 @classmethod
3009 def inline_call(cls, parent, func, args, kwargs):
3010 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011 return cls.inline_call_(parent, func, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3139, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
3137 try:
3138 with strict_ctx:
-> 3139 tracer.run()
3140 except exc.ObservedException as e:
3141 msg = f"Observed exception DURING INLING {code} : {e}"
[... skipping similar frames: InstructionTranslatorBase.run at line 983 (1 times), InstructionTranslatorBase.step at line 895 (1 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (1 times)]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2279, in InstructionTranslatorBase.CALL(self, inst)
2277 @break_graph_if_unsupported(push=1)
2278 def CALL(self, inst):
-> 2279 self._call(inst)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2273, in InstructionTranslatorBase._call(self, inst, call_kw)
2268 kwargs = {}
2270 try:
2271 # if call_function fails, need to set kw_names to None, otherwise
2272 # a subsequent call may have self.kw_names set to an old value
-> 2273 self.call_function(fn, args, kwargs)
2274 finally:
2275 self.kw_names = None
[... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (1 times)]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:442, in NNModuleVariable.call_function(self, tx, args, kwargs)
440 else:
441 assert istype(fn, types.FunctionType)
--> 442 return tx.inline_user_function_return(
443 variables.UserFunctionVariable(fn, source=fn_source),
444 args,
445 kwargs,
446 )
[... skipping similar frames: InliningInstructionTranslator.inline_call at line 3011 (1 times), InliningInstructionTranslator.inline_call_ at line 3139 (1 times), InstructionTranslatorBase.inline_user_function_return at line 836 (1 times), InstructionTranslatorBase.run at line 983 (1 times), InstructionTranslatorBase.step at line 895 (1 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (1 times)]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1680, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
1678 # Map to a dictionary of str -> VariableTracker
1679 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1680 self.call_function(fn, argsvars.items, kwargsvars)
[... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (1 times)]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:385, in UserMethodVariable.call_function(self, tx, args, kwargs)
383 fn = getattr(self.obj.value, self.fn.__name__)
384 return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
--> 385 return super().call_function(tx, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:324, in UserFunctionVariable.call_function(self, tx, args, kwargs)
319 if self.is_constant:
320 return invoke_and_store_as_constant(
321 tx, self.fn, self.get_name(), args, kwargs
322 )
--> 324 return super().call_function(tx, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:111, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
105 def call_function(
106 self,
107 tx: "InstructionTranslator",
108 args: "List[VariableTracker]",
109 kwargs: "Dict[str, VariableTracker]",
110 ) -> "VariableTracker":
--> 111 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (5 times), InliningInstructionTranslator.inline_call at line 3011 (5 times), InliningInstructionTranslator.inline_call_ at line 3139 (5 times), InstructionTranslatorBase.inline_user_function_return at line 836 (5 times), InstructionTranslatorBase.run at line 983 (5 times), InstructionTranslatorBase.step at line 895 (5 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (5 times), InstructionTranslatorBase.CALL at line 2279 (3 times), InstructionTranslatorBase._call at line 2273 (3 times), InstructionTranslatorBase.CALL_FUNCTION_EX at line 1680 (2 times), NNModuleVariable.call_function at line 442 (2 times), UserMethodVariable.call_function at line 385 (2 times), UserFunctionVariable.call_function at line 324 (2 times), BaseUserFunctionVariable.call_function at line 111 (2 times)]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:442, in NNModuleVariable.call_function(self, tx, args, kwargs)
440 else:
441 assert istype(fn, types.FunctionType)
--> 442 return tx.inline_user_function_return(
443 variables.UserFunctionVariable(fn, source=fn_source),
444 args,
445 kwargs,
446 )
[... skipping similar frames: InliningInstructionTranslator.inline_call at line 3011 (1 times), InliningInstructionTranslator.inline_call_ at line 3139 (1 times), InstructionTranslatorBase.inline_user_function_return at line 836 (1 times), InstructionTranslatorBase.run at line 983 (1 times), InstructionTranslatorBase.step at line 895 (1 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 582 (1 times)]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1680, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
1678 # Map to a dictionary of str -> VariableTracker
1679 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1680 self.call_function(fn, argsvars.items, kwargsvars)
[... skipping similar frames: InstructionTranslatorBase.call_function at line 830 (1 times)]
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:385, in UserMethodVariable.call_function(self, tx, args, kwargs)
383 fn = getattr(self.obj.value, self.fn.__name__)
384 return invoke_and_store_as_constant(tx, fn, self.get_name(), args, kwargs)
--> 385 return super().call_function(tx, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:324, in UserFunctionVariable.call_function(self, tx, args, kwargs)
319 if self.is_constant:
320 return invoke_and_store_as_constant(
321 tx, self.fn, self.get_name(), args, kwargs
322 )
--> 324 return super().call_function(tx, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:111, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
105 def call_function(
106 self,
107 tx: "InstructionTranslator",
108 args: "List[VariableTracker]",
109 kwargs: "Dict[str, VariableTracker]",
110 ) -> "VariableTracker":
--> 111 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:836, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
832 def inline_user_function_return(self, fn, args, kwargs):
833 """
834 A call to some user defined function by inlining it.
835 """
--> 836 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3011, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
3008 @classmethod
3009 def inline_call(cls, parent, func, args, kwargs):
3010 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3011 return cls.inline_call_(parent, func, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3139, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
3137 try:
3138 with strict_ctx:
-> 3139 tracer.run()
3140 except exc.ObservedException as e:
3141 msg = f"Observed exception DURING INLING {code} : {e}"
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:983, in InstructionTranslatorBase.run(self)
981 try:
982 self.output.push_tx(self)
--> 983 while self.step():
984 pass
985 except BackendCompilerFailed:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:895, in InstructionTranslatorBase.step(self)
892 self.update_block_stack(inst)
894 try:
--> 895 self.dispatch_table[inst.opcode](self, inst)
896 return not self.output.should_exit
897 except exc.ObservedException as e:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:582, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
580 return handle_graph_break(self, inst, speculation.reason)
581 try:
--> 582 return inner_fn(self, inst)
583 except Unsupported as excp:
584 if self.generic_context_manager_depth > 0:
585 # We don't support graph break under GenericContextWrappingVariable,
586 # If there is, we roll back to the checkpoint and fall back.
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2279, in InstructionTranslatorBase.CALL(self, inst)
2277 @break_graph_if_unsupported(push=1)
2278 def CALL(self, inst):
-> 2279 self._call(inst)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2273, in InstructionTranslatorBase._call(self, inst, call_kw)
2268 kwargs = {}
2270 try:
2271 # if call_function fails, need to set kw_names to None, otherwise
2272 # a subsequent call may have self.kw_names set to an old value
-> 2273 self.call_function(fn, args, kwargs)
2274 finally:
2275 self.kw_names = None
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:830, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
828 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
829 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 830 self.push(fn.call_function(self, args, kwargs))
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py:897, in TorchInGraphFunctionVariable.call_function(self, tx, args, kwargs)
888 if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
889 # Calling fake tensor propagation can mutate the out= tensor in
890 # tx.output.tracked_fakes. tracked_fakes are used to apply
(...)
893 # guards. So save the shape now, and check later if it has
894 # changed. If it has, graph break.
895 fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
--> 897 tensor_variable = wrap_fx_proxy(
898 tx=tx,
899 proxy=tx.output.create_proxy(
900 "call_function",
901 fn_,
902 *proxy_args_kwargs(args, kwargs),
903 ),
904 )
906 if (
907 isinstance(tensor_variable, TensorVariable)
908 and "requires_grad" in kwargs
909 and kwargs["requires_grad"].as_python_constant()
910 ):
911 unimplemented(
912 """factory functions that return tensors that require grad are not supported.
913 Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
914 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2037, in wrap_fx_proxy(tx, proxy, example_value, subclass_type, **options)
2029 kwargs = {
2030 "tx": tx,
2031 "proxy": proxy,
(...)
2034 **options,
2035 }
2036 if subclass_type is None:
-> 2037 return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
2038 else:
2039 result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2124, in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, subclass_type, **options)
2119 with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
2120 # with preserve_rng_state():
2121 if example_value is None:
2122 # only allow_non_graph_fake in this instance because we handle the non-fake
2123 # cases properly below.
-> 2124 example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
2126 # Handle recursive calls here
2127 elif maybe_get_fake_mode(example_value) is tx.fake_mode:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2082, in get_fake_value(node, tx, allow_non_graph_fake)
2079 elif isinstance(cause, TypeError) and "argument" in str(cause):
2080 unimplemented(f"TypeError {node.target}: {cause}")
-> 2082 raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
2084 if not allow_non_graph_fake:
2085 _ = pytree.tree_map_only(
2086 torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
2087 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2017, in get_fake_value(node, tx, allow_non_graph_fake)
2015 try:
2016 with tx.fake_mode, enable_python_dispatcher():
-> 2017 ret_val = wrap_fake_exception(
2018 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
2019 )
2020 except Unsupported:
2021 raise
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:1574, in wrap_fake_exception(fn)
1572 def wrap_fake_exception(fn):
1573 try:
-> 1574 return fn()
1575 except UnsupportedFakeTensorException as e:
1576 from .exc import unimplemented
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2018, in get_fake_value.<locals>.<lambda>()
2015 try:
2016 with tx.fake_mode, enable_python_dispatcher():
2017 ret_val = wrap_fake_exception(
-> 2018 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
2019 )
2020 except Unsupported:
2021 raise
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2150, in run_node(tracer, node, args, kwargs, nnmodule)
2148 unimplemented(make_error_message(e), from_exc=e)
2149 except Exception as e:
-> 2150 raise RuntimeError(make_error_message(e)).with_traceback(
2151 e.__traceback__
2152 ) from e
2154 raise AssertionError(op)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_dynamo/utils.py:2132, in run_node(tracer, node, args, kwargs, nnmodule)
2130 try:
2131 if op == "call_function":
-> 2132 return node.target(*args, **kwargs)
2133 elif op == "call_method":
2134 return getattr(args[0], node.target)(*args[1:], **kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/utils/_stats.py:21, in count.<locals>.wrapper(*args, **kwargs)
19 simple_call_counter[fn.__qualname__] = 0
20 simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 21 return fn(*args, **kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1238, in FakeTensorMode.__torch_dispatch__(self, func, types, args, kwargs)
1234 assert (
1235 torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
1236 ), func
1237 try:
-> 1238 return self.dispatch(func, types, args, kwargs)
1239 except TypeError:
1240 log.exception("fake tensor raised TypeError")
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1692, in FakeTensorMode.dispatch(self, func, types, args, kwargs)
1689 return func(*args, **kwargs)
1691 if self.cache_enabled:
-> 1692 return self._cached_dispatch_impl(func, types, args, kwargs)
1693 else:
1694 return self._dispatch_impl(func, types, args, kwargs)
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1339, in FakeTensorMode._cached_dispatch_impl(self, func, types, args, kwargs)
1337 else:
1338 self._validate_cache_key(func, args, kwargs)
-> 1339 output = self._dispatch_impl(func, types, args, kwargs)
1340 entry = self._make_cache_entry(state, key, func, args, kwargs, output)
1341 key.strip_shape_env()
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2021, in FakeTensorMode._dispatch_impl(self, func, types, args, kwargs)
2017 log.exception("failed while attempting to run meta for %s", func)
2018 raise
2020 return maybe_propagate_real_tensors(
-> 2021 self.wrap_meta_outputs_with_default_device_logic(
2022 r, func, flat_args, device=kwargs.get("device")
2023 )
2024 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2143, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device)
2140 else:
2141 return e
-> 2143 return tree_map(wrap, r)
File ~/dev/.venv/lib/python3.12/site-packages/torch/utils/_pytree.py:964, in tree_map(func, tree, is_leaf, *rests)
962 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
963 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 964 return treespec.unflatten(map(func, *flat_args))
File ~/dev/.venv/lib/python3.12/site-packages/torch/utils/_pytree.py:803, in TreeSpec.unflatten(self, leaves)
801 def unflatten(self, leaves: Iterable[Any]) -> PyTree:
802 if not isinstance(leaves, (list, tuple)):
--> 803 leaves = list(leaves)
804 if len(leaves) != self.num_leaves:
805 raise ValueError(
806 f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
807 f"but the spec refers to a pytree that holds {self.num_leaves} "
808 f"items ({self}).",
809 )
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2121, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic.<locals>.wrap(e)
2115 return e
2117 if common_device is None:
2118 (
2119 common_device,
2120 has_scalar_only_inputs,
-> 2121 ) = FakeTensor._find_common_device(func, flat_args)
2123 is_our_fake = self.is_our_fake(e)
2124 if is_our_fake:
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:872, in FakeTensor._find_common_device(func, flat_args)
867 raise RuntimeError(
868 f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
869 )
871 for arg in flat_args:
--> 872 merge_devices(arg)
874 # some functions that allow Python numbers to bind to Tensors
875 # if we have failed to find a device, and we're running one of these operators,
876 # we must have scalar only inputs
877 if should_allow_numbers_as_tensors(func) and common_device is None:
878 # ops with scalar only inputs always have result on cpu
File ~/dev/.venv/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:867, in FakeTensor._find_common_device.<locals>.merge_devices(t)
863 return
865 # mismatching devices of non-zero dim tensors, throw
866 # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
--> 867 raise RuntimeError(
868 f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
869 )
TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., size=(128, 12, 4, 64), grad_fn=<PermuteBackward0>), FakeTensor(..., size=(128, 12, 4, 64), grad_fn=<PermuteBackward0>), FakeTensor(..., size=(128, 12, 4, 64), grad_fn=<PermuteBackward0>)), **{'attn_mask': FakeTensor(..., device='cuda:0', size=(128, 1, 4, 4)), 'dropout_p': 0.0, 'is_causal': False}):
Unhandled FakeTensor Device Propagation for aten._scaled_dot_product_flash_attention_for_cpu.default, found two different devices cpu, cuda:0
from user code:
File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 1318, in forward
outputs = self.roberta(
File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 976, in forward
encoder_outputs = self.encoder(
File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 631, in forward
layer_outputs = layer_module(
File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 520, in forward
self_attention_outputs = self.attention(
File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 447, in forward
self_outputs = self.self(
File "/home/dev/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/home/dev/.venv/lib/python3.12/site-packages/transformers/models/roberta/modeling_roberta.py", line 370, in forward
attn_output = torch.nn.functional.scaled_dot_product_attention(
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
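For context, the op that fails device propagation is the CPU flash-attention kernel (aten._scaled_dot_product_flash_attention_for_cpu.default): the traced query/key/value fake tensors are treated as CPU tensors while the attn_mask fake tensor sits on cuda:0, so fake-tensor device propagation sees two devices. The same check fires in plain eager mode; here is a minimal sketch (shapes taken from the error message above, offered only to illustrate the mismatch):
import torch
import torch.nn.functional as F

# query/key/value on CPU, attn_mask on CUDA -- the same device split the
# fake-tensor trace reports (cpu vs cuda:0)
q = k = v = torch.randn(128, 12, 4, 64)         # CPU tensors
mask = torch.ones(128, 1, 4, 4, device='cuda')  # cuda:0

# Raises a device-mismatch RuntimeError, analogous to the
# "Unhandled FakeTensor Device Propagation" failure above
F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)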
To Reproduce
Run:
import torch
import torch_tensorrt
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# BEGIN CONFIG #
MODEL_DIR = 'roberta-base'
# END CONFIG #

# Load roberta-base with the SDPA attention implementation and move it to the GPU
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, attn_implementation='sdpa')
model = model.to('cuda')
model = model.eval()

# Build a batch of 128 identical short sequences, all on the GPU
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
input_ids = [tokenizer.encode('Hello world')] * 128
input_ids = torch.stack([torch.tensor(ids) for ids in input_ids]).to('cuda')
attention_mask = torch.ones_like(input_ids)

model = torch_tensorrt.compile(model, inputs=(input_ids, attention_mask))
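As a possible workaround (an assumption on my side, not verified on this setup): loading the model with the eager attention implementation sidesteps torch.nn.functional.scaled_dot_product_attention, which is where the trace fails:
# Hypothetical workaround, untested here: fall back to eager attention so the
# export never reaches scaled_dot_product_attention.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, attn_implementation='eager')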
Expected behavior
Compilation succeeds.
Environment
- OS: WSL 2
- Torch-TensorRT version: 2.5.0
- PyTorch version: 2.5.1
- CUDA version: 12.4
- Python version: 3.12.5