This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Torch Compile with Float8Linear #106

Closed

@drisspg

Description

Summary

I will be using this issue as a top-level tracker and link to sub-issues with smaller repros to tackle this problem.

PRs

#56: Brian has done some initial work on getting subclasses to compile for fp8.

Issues

Problem summaries

All of the problems below are based on this implementation of Float8Tensor:
#128

I added this repro script to surface compile issues: https://gist.github.com/drisspg/6e76d3d99dc932e2287f19123f6339d1
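For reference, the repro is roughly of the following shape; this is only a sketch from memory, and the Float8Linear import path and from_float API are assumptions rather than copies from the repo (the real script is in the gist above):

    import torch
    import torch.nn as nn

    # Sketch only: the import path and Float8Linear.from_float are assumptions.
    from float8_experimental.float8_linear import Float8Linear

    def main():
        m = Float8Linear.from_float(nn.Linear(16, 16, bias=False).cuda())
        x = torch.randn(4, 16, device="cuda", requires_grad=True)
        for backend in ("eager", "aot_eager", "inductor"):
            m_c = torch.compile(m, backend=backend)
            y_fp8 = m_c(x)
            y_fp8.sum().backward()

    if __name__ == "__main__":
        main()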

Backend = "eager"

  1. When attempting to compile Float8Linear with the "eager" backend, we currently fail due to problems during automatic_dynamic_dims creation. For a more detailed analysis and a potential fix see: Allow traceable_subclass_tensors to have multiple dynamic tensor attributes pytorch/pytorch#112185
  2. After the above PR there are no hard errors, but compiling with backend = "eager" gives the following two warnings:
[2023-10-27 09:36:36,705] [2/1] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-10-27 09:36:36,705] [2/1] torch._dynamo.variables.higher_order_ops: [ERROR] Unexpected type in sourceless builder <class 'torch.dtype'>

Adding the following to the sourceless builder:

    elif isinstance(value, torch.dtype):
        return ConstantVariable.create(value)

PR pytorch/pytorch#112284 cleans up both errors.
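For context, the kind of pattern that puts a bare torch.dtype in front of the sourceless builder looks roughly like the stand-in below; it is a self-contained toy, not the Float8 code itself:

    import torch

    # Toy stand-in: a user-defined autograd.Function that takes a torch.dtype
    # argument, which Dynamo has to wrap without a source while introspecting
    # trampoline_autograd_fwd.
    class CastWithDtype(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, dtype):
            ctx.orig_dtype = x.dtype
            return x.to(dtype)

        @staticmethod
        def backward(ctx, grad):
            return grad.to(ctx.orig_dtype), None

    @torch.compile(backend="eager")
    def f(x):
        return CastWithDtype.apply(x, torch.bfloat16)

    f(torch.randn(4, requires_grad=True))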

Graph Breaks
So, using TORCH_LOGS="graph_breaks" python ../scripts/fp8/eager_compile_debug.py, we were hitting graph breaks whenever we tried to construct fp8 tensors with the class method. I found that moving the construction to a function fixed the graph breaks, and we now have none for this script, see:
#131
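The shape of the change is roughly the following (the classmethod and constructor signatures here are assumptions, see #131 for the real diff):

    # Before: constructing through the classmethod graph-broke under Dynamo
    #   y_fp8 = Float8Tensor.to_float8(x, scale, torch.float8_e4m3fn)

    # After: a plain module-level function traces without a graph break
    def to_fp8_no_autograd(x, scale, float8_dtype):
        bits = (x * scale).to(float8_dtype)
        return Float8Tensor(bits, scale, x.dtype)

    y_fp8 = to_fp8_no_autograd(x, scale, torch.float8_e4m3fn)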

Backend = "aot_eager"

With the fix to avoid graph breaks, we now get a more helpful error message:

Traceback (most recent call last):
  File "/home/drisspg/meta/float8_experimental/../scripts/fp8/eager_compile_debug.py", line 41, in <module>
    main()
  File "/home/drisspg/meta/float8_experimental/../scripts/fp8/eager_compile_debug.py", line 31, in main
    y_fp8.sum().backward()
  File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 503, in backward
    torch.autograd.backward(
  File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 254, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/drisspg/meta/pytorch/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4009, in backward
    assert grad_output_types_ == CompiledFunction.metadata.output_types, f"""\
AssertionError: We incorrectly attempted to compile the backward with incorrect subclass metadata.
If you run into this error, please file an issue.
Expected grad_output types: [<class 'float8_experimental.float8_tensor.Float8Tensor'>]
Got grad_output types: [<class 'torch.Tensor'>]

I suspect this error is because, for the matmul, we output a regular tensor and not a tensor subclass, and then during backward we have the autograd function that converts the gradient to the different fp8 format.
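The asymmetry I mean looks roughly like this (class and helper names are assumptions, not exact repo code):

    # forward is a no-op on the plain tensor coming out of the matmul, while
    # backward converts the incoming gradient to the e5m2 fp8 format, so the
    # grad_output type seen at runtime differs from what was recorded.
    class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return x

        @staticmethod
        def backward(ctx, go):
            scale = torch.finfo(torch.float8_e5m2).max / go.abs().max()
            return to_fp8_no_autograd(go, scale, torch.float8_e5m2)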

Backend = "inductor"

With the tangle of PRs and changes, and by not running backward on the subclass linear, I can actually compile with inductor!

However, it fails when the "high_precision" dtype is not float32. I suspect this is because we are storing amax in fp32 (needed for scaled_mm), and the inductor select_scatter lowering produces the following error:

  File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 289, in wrapped
    out = decomp_fn(*args, **kwargs)
  File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 2219, in select_scatter
    assert x.get_dtype() == src.get_dtype()
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
  target: aten.select_scatter.default
  args[0]: TensorBox(StorageBox(
    InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[16], stride=[1]))
  ))
  args[1]: TensorBox(StorageBox(
    ComputedBuffer(name='buf0', layout=FlexibleLayout('cuda', torch.bfloat16, size=[], stride=[]), data=Reduction(
      'cuda',
      torch.bfloat16,
      def inner_fn(index, rindex):
          r0, r1 = rindex
          tmp0 = ops.load(primals_1, r1 + 16 * r0)
          tmp1 = ops.abs(tmp0)
          return tmp1
      ,
      ranges=[],
      reduction_ranges=[16, 16],
      reduction_type=max,
      origin_node=max_1,
      origins={max_1, abs_1}
    ))
  ))
  args[2]: 0
  args[3]: 0
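In other words, the likely culprit is an update of the following shape (buffer name and sizes are assumptions): the amax buffer is fp32, the reduction over a bf16 input produces a bf16 scalar, and inductor's select_scatter lowering asserts that the two dtypes match.

    # Illustrative only; buffer name and shapes are assumptions.
    amax_history = torch.zeros(16, dtype=torch.float32, device="cuda")
    x = torch.randn(16, 16, dtype=torch.bfloat16, device="cuda")

    # bf16 max scattered into an fp32 buffer: trips the dtype assert in
    # inductor's select_scatter lowering
    amax_history[0] = x.abs().max()

    # explicit upcast keeps the dtypes consistent
    amax_history[0] = x.abs().max().to(torch.float32)
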
Old error:

When attempting to compile with "aot_eager" with the above two fixes, we get:

  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4230, in <listcomp>
    return [convert(idx, x) for idx, x in enumerate(flat_args)]
  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4219, in convert
    assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
AssertionError:

UPDATE: I was able to trigger a more helpful error message by iterating through the fake modes of the inner tensors:
https://gist.github.com/drisspg/ed916d144e819d7eb0be6728e0e807a7
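The idea in that gist is roughly the following (a reconstruction, not the gist itself): instead of the bare assert all(...), walk the subclass's inner tensors and report the fake_mode each one carries.

    from torch._subclasses.fake_tensor import FakeTensor

    def report_fake_modes(t, expected_fake_mode):
        # __tensor_flatten__ is the traceable-subclass protocol: it returns the
        # names of the inner tensor attributes plus some context
        attrs, _ctx = t.__tensor_flatten__()
        for attr in attrs:
            inner = getattr(t, attr)
            mode = inner.fake_mode if isinstance(inner, FakeTensor) else None
            if mode is not expected_fake_mode:
                print(f"{type(t).__name__}.{attr}: fake_mode {mode} "
                      f"!= expected {expected_fake_mode}")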
