This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Graph breaks in Float8Linear code #237

Closed
@y-sq

Description

Fullgraph compilation fails when the Float8Linear layer (the delayed-scaling one) is wrapped in a torch.nn.Sequential, for example, model = nn.Sequential(Float8Linear).

The error trace:

### linear_float8
Sequential(
  (0): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (1): Sigmoid()
  (2): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (3): Sigmoid()
  (4): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (5): Sigmoid()
  (6): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (7): Sigmoid()
  (8): Float8Linear(in_features=8192, out_features=8192, bias=True)
  (9): Sigmoid()
)

... ...

Traceback (most recent call last):
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__run_xar_main__.py", line 140, in <module>
    __invoke_main()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__run_xar_main__.py", line 87, in __invoke_main
    run_as_main(main_module, main_function)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__par__/meta_only/bootstrap.py", line 98, in run_as_main
    oss_run_as_main(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/__par__/bootstrap.py", line 94, in run_as_main
    main()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 214, in invoke_main
    fire.Fire(main)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 191, in main
    float8_forw_backward_wrapper(input_tensor)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 178, in float8_forw_backward_wrapper
    out = float8_forw_backward(x)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/eval_frame.py", line 450, in _fn
    return fn(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 923, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state, skip=1)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 400, in _convert_frame_assert
    return _compile(
  File "/usr/local/fbcode/platform010/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 676, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/utils.py", line 262, in time_wrapper
    r = func(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 535, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/bytecode_transformation.py", line 1036, in transform_code_object
    transformations(instructions, code_options)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 165, in _fn
    return fn(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/convert_frame.py", line 500, in transform
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2149, in run
    super().run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/lazy.py", line 94, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/nn_module.py", line 272, in call_function
    tx.call_function(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/nn_module.py", line 716, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 489, in wrapper
    return inner_fn(self, inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1219, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 674, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 335, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 289, in call_function
    return super().call_function(tx, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/functions.py", line 90, in call_function
    return tx.inline_user_function_return(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 680, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2285, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 2399, in inline_call_
    tracer.run()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 810, in run
    and self.step()
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 773, in step
    getattr(self, inst.opname)(inst)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/symbolic_convert.py", line 1330, in STORE_ATTR
    BuiltinVariable(setattr).call_function(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/builtin.py", line 687, in call_function
    result = handler(tx, *args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/variables/builtin.py", line 1386, in call_setattr
    unimplemented(
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/_dynamo/exc.py", line 190, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: setattr(UserDefinedObjectVariable) <function Module.__setattr__ at 0x7f6fc5be4790>
from user code:
   File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/profile_linear_float8.py", line 166, in float8_forw_backward
    out = linear_float8(x)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/float8_experimental/float8_linear.py", line 290, in forward
    self.float8_pre_forward(x)
  File "/mnt/xarfuse/uid-231059/52af9f57-seed-nspid4026531836_cgpid20420211-ns-4026531841/float8_experimental/float8_linear.py", line 272, in float8_pre_forward
    self.last_seen_input_dtype = x.dtype
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

The issue originates in float8_pre_forward (the assignment self.last_seen_input_dtype = x.dtype at the end of the trace) and can be mitigated by setting enable_pre_and_post_forward = False, which skips the pre/post-forward steps.
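To illustrate the pattern behind the mitigation, here is a minimal sketch, not the float8_experimental code. ToyConfig and ToyLinear are hypothetical stand-ins that only mirror the attribute assignment seen in the trace and a flag that skips it. How the real flag is exposed in the library is not shown in this issue, so treat the names below as assumptions.

```python
import torch
import torch.nn as nn


class ToyConfig:
    # Hypothetical stand-in for the enable_pre_and_post_forward setting.
    enable_pre_and_post_forward = True


class ToyLinear(nn.Linear):
    """Hypothetical stand-in for Float8Linear; only the bookkeeping is mirrored."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_seen_input_dtype = None

    def pre_forward(self, x):
        # When the flag is off, skip the bookkeeping entirely, so no
        # Module.__setattr__ call happens inside forward.
        if not ToyConfig.enable_pre_and_post_forward:
            return
        # Same kind of assignment as the one reported in the trace.
        self.last_seen_input_dtype = x.dtype

    def forward(self, x):
        self.pre_forward(x)
        return super().forward(x)


def build_model(features=8):
    # Wrap the layer in nn.Sequential, as in the report.
    return nn.Sequential(ToyLinear(features, features), nn.Sigmoid())
```

With the flag off, forward no longer mutates module state, which is the operation Dynamo reported as unsupported here. Whether torch.compile(build_model(), fullgraph=True) actually fails on this toy module with the flag on depends on the PyTorch version, so the sketch makes no claim about that.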

A similar issue is mentioned in #172, where the graph breaks if the module is wrapped by FSDP.
