Graph breaks in Float8Linear code #563

Closed
@vkuzo

from @y-sq

Compiling with fullgraph=True fails when a Float8Linear layer (the delayed-scaling variant) is wrapped in a torch.nn.Sequential, for example model = nn.Sequential(Float8Linear).

The error trace:

### linear_float8
Sequential(
  (0): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (1): Sigmoid()
  (2): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (3): Sigmoid()
  (4): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (5): Sigmoid()
  (6): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (7): Sigmoid()
  (8): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (9): Sigmoid()
)

... ...

Traceback (most recent call last):
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__run_xar_main__.py", line 140, in <module>
    __invoke_main()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__run_xar_main__.py", line 87, in __invoke_main
    run_as_main(main_module, main_function)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__par__/meta_only/bootstrap.py", line 98, in run_as_main
    oss_run_as_main(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__par__/bootstrap.py", line 94, in run_as_main
    main()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 214, in invoke_main
    fire.Fire(main)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 191, in main
    float8_forw_backward_wrapper(input_tensor)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 178, in float8_forw_backward_wrapper
    out = float8_forw_backward(x)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/eval_frame.py", line 450, in _fn
    return fn(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 923, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/nn_module.py", line 272, in call_function
    tx.call_function(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/nn_module.py", line 716, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1330, in STORE_ATTR
    BuiltinVariable(setattr).call_function(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/builtin.py", line 687, in call_function
    result = handler(tx, *args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/builtin.py", line 1386, in call_setattr
    unimplemented(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: setattr(UserDefinedObjectVariable) <function Module.__setattr__ at 0x7f6fc5be4790>
from user code:
   File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 166, in float8_forw_backward
    out = linear_float8(x)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/float8_experimental/float8_linear.py", line 290, in forward
    self.float8_pre_forward(x)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/float8_experimental/float8_linear.py", line 272, in float8_pre_forward
    self.last_seen_input_dtype = x.dtype
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

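The failure originates in float8_pre_forward, which assigns self.last_seen_input_dtype = x.dtype on the module; Dynamo reports this setattr as Unsupported. It can be mitigated by setting enable_pre_and_post_forward = False, which skips the pre/post-forward part.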
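For illustration, here is a minimal sketch of the pattern in the trace and of the mitigation. ToyLinear is a hypothetical stand-in, not the real Float8Linear: it only mimics the attribute write in float8_pre_forward, and it takes enable_pre_and_post_forward as a constructor argument purely for the example (how the real flag is set is not shown in this issue). The sketch runs in eager mode and only checks that skipping the bookkeeping leaves the forward output unchanged.

```python
import torch
import torch.nn as nn


class ToyLinear(nn.Module):
    """Hypothetical stand-in for Float8Linear (not the real implementation)."""

    def __init__(self, in_features, out_features, enable_pre_and_post_forward=True):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Assumption: in this toy the flag is a plain per-module attribute.
        self.enable_pre_and_post_forward = enable_pre_and_post_forward
        self.last_seen_input_dtype = None

    def float8_pre_forward(self, x):
        if not self.enable_pre_and_post_forward:
            return
        # This write goes through Module.__setattr__, the call named in the
        # Unsupported error at the bottom of the trace.
        self.last_seen_input_dtype = x.dtype

    def forward(self, x):
        self.float8_pre_forward(x)
        return self.linear(x)


def build(enable):
    # Re-seed so both models get identical weights.
    torch.manual_seed(0)
    return nn.Sequential(ToyLinear(16, 16, enable), nn.Sigmoid())


x = torch.randn(4, 16)
with_hook = build(True)
without_hook = build(False)

out_a = with_hook(x)      # records the input dtype on the module
out_b = without_hook(x)   # skips the attribute write entirely
```

To probe for graph breaks, either model can be wrapped with torch.compile(model, fullgraph=True) and called on x. Whether the attribute write is still rejected depends on the PyTorch version, so no particular result is claimed here.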

A similar issue is reported in pytorch-labs/float8_experimental#172, where the graph breaks when the layer is wrapped by FSDP.

Copied from pytorch-labs/float8_experimental#237.
