Closed
Description
from @y-sq
Fullgraph compile breaks if the `Float8Linear` layer (the delayed scaling one) is wrapped by a `torch.nn.Sequential`, for example, `model = nn.Sequential(Float8Linear)`.
The error trace:
### linear_float8
Sequential(
(0): Float8Linear(in_features=8192, out_features=8192, bias=True)
(1): Sigmoid()
(2): Float8Linear(in_features=8192, out_features=8192, bias=True)
(3): Sigmoid()
(4): Float8Linear(in_features=8192, out_features=8192, bias=True)
(5): Sigmoid()
(6): Float8Linear(in_features=8192, out_features=8192, bias=True)
(7): Sigmoid()
(8): Float8Linear(in_features=8192, out_features=8192, bias=True)
(9): Sigmoid()
)
... ...
Traceback (most recent call last):
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__run_xar_main__.py", line 140, in <module>
__invoke_main()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__run_xar_main__.py", line 87, in __invoke_main
run_as_main(main_module, main_function)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__par__/meta_only/bootstrap.py", line 98, in run_as_main
oss_run_as_main(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__par__/bootstrap.py", line 94, in run_as_main
main()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 214, in invoke_main
fire.Fire(main)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 191, in main
float8_forw_backward_wrapper(input_tensor)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 178, in float8_forw_backward_wrapper
out = float8_forw_backward(x)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/eval_frame.py", line 450, in _fn
return fn(*args, **kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 923, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
return _compile(
File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 676, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/utils.py", line 262, in time_wrapper
r = func(*args, **kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 535, in compile_inner
out_code = transform_code_object(code, transform)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
transformations(instructions, code_options)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 165, in _fn
return fn(*args, **kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 500, in transform
tracer.run()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2149, in run
super().run()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
return inner_fn(self, inst)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
return getattr(self.realize(), name)(*args, **kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/nn_module.py", line 272, in call_function
tx.call_function(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/nn_module.py", line 716, in call_function
return variables.UserFunctionVariable(fn, source=source).call_function(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
return super().call_function(tx, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
tracer.run()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
return inner_fn(self, inst)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION_EX
self.call_function(fn, argsvars.items, kwargsvars)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 335, in call_function
return super().call_function(tx, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
return super().call_function(tx, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
tracer.run()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
return inner_fn(self, inst)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 335, in call_function
return super().call_function(tx, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
return super().call_function(tx, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
return tx.inline_user_function_return(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
tracer.run()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
and self.step()
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
getattr(self, inst.opname)(inst)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1330, in STORE_ATTR
BuiltinVariable(setattr).call_function(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/builtin.py", line 687, in call_function
result = handler(tx, *args, **kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/builtin.py", line 1386, in call_setattr
unimplemented(
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/exc.py", line 190, in unimplemented
raise Unsupported(msg)
torch._dynamo.exc.Unsupported: setattr(UserDefinedObjectVariable) <function Module.__setattr__ at 0x7f6fc5be4790>
from user code:
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 166, in float8_forw_backward
out = linear_float8(x)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/nn/modules/module.py", line 1536, in _call_impl
return forward_call(*args, **kwargs)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/float8_experimental/float8_linear.py", line 290, in forward
self.float8_pre_forward(x)
File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/float8_experimental/float8_linear.py", line 272, in float8_pre_forward
self.last_seen_input_dtype = x.dtype
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
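The issue is related to `float8_pre_forward`: the assignment `self.last_seen_input_dtype = x.dtype` is a `setattr` on the module, which Dynamo reports as `Unsupported` during tracing.
It can be mitigated by setting `enable_pre_and_post_forward = False`
to skip the pre/post_forward part.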
A similar issue is mentioned in pytorch-labs/float8_experimental#172, where the graph breaks if the module is wrapped by FSDP.
Copied from pytorch-labs/float8_experimental#237.