Torch Compile with Float8Linear #106
Description
Summary
I will be using this as a top-level tracker and link to sub-issues with smaller repros to tackle this problem.
PRs
#56 Brian has done some initial work getting subclasses to compile for fp8.
Issues
- [Compile] Error with FullGraph "eager" compile on TensorSubclass version #108
- [Compile] Error with compile for aot_eager Tensor subclass #117
Problem summaries
All the problems are based on this implementation of Float8Tensor:
#128
Added this repro script to surface compile issues: https://gist.github.com/drisspg/6e76d3d99dc932e2287f19123f6339d1
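For reference, the repro has roughly the following shape. This is a minimal sketch, not the exact gist; `Float8Linear.from_float` and the default arguments here are assumptions about the float8_experimental API.

```python
# Sketch of the repro: swap an nn.Linear for its fp8 counterpart, compile with a
# chosen backend, and run forward + backward.
import torch
from float8_experimental.float8_linear import Float8Linear

def main(backend: str = "eager"):
    x = torch.randn(16, 16, device="cuda", requires_grad=True)
    m = torch.nn.Linear(16, 16, bias=False, device="cuda")
    m_fp8 = Float8Linear.from_float(m)  # assumed constructor-from-nn.Linear helper

    m_fp8 = torch.compile(m_fp8, backend=backend, fullgraph=True)
    y_fp8 = m_fp8(x)
    y_fp8.sum().backward()

if __name__ == "__main__":
    main("eager")
```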
Backend = "eager"
- When attempting to compile Float8Linear with the "eager" backend, we currently fail during automatic_dynamic_dims creation. For a more detailed analysis and a potential fix see: Allow traceable_subclass_tensors to have multiple dynamic tensor attributes pytorch/pytorch#112185
- After the above PR there are no hard errors, but compiling with backend = "eager" gives the following two warnings:
[2023-10-27 09:36:36,705] [2/1] torch._dynamo.variables.higher_order_ops: [WARNING] speculate_subgraph: while introspecting the user-defined autograd.Function, we were unable to trace function `trampoline_autograd_fwd` into a single graph. This means that Dynamo was unable to prove safety for this API and will fall back to eager-mode PyTorch, which could lead to a slowdown.
[2023-10-27 09:36:36,705] [2/1] torch._dynamo.variables.higher_order_ops: [ERROR] Unexpected type in sourceless builder <class 'torch.dtype'>
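These messages come from Dynamo introspecting a user-defined autograd.Function and hitting a torch.dtype constant along the way. A minimal sketch of the kind of pattern that can trigger this (purely illustrative, not the actual float8_experimental code):

```python
import torch

class ToFloat8(torch.autograd.Function):
    # Illustrative stand-in for the fp8 cast: a torch.dtype constant flows
    # through the autograd.Function that Dynamo tries to introspect.
    @staticmethod
    def forward(ctx, tensor, float8_dtype):
        ctx.orig_dtype = tensor.dtype
        return tensor.to(float8_dtype)

    @staticmethod
    def backward(ctx, grad):
        return grad.to(ctx.orig_dtype), None

def f(x):
    return ToFloat8.apply(x, torch.float8_e4m3fn)

compiled = torch.compile(f, backend="eager")
```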
Adding the following to the sourceless builder:

elif isinstance(value, torch.dtype):
    return ConstantVariable.create(value)

PR pytorch/pytorch#112284 cleans up both errors.
Graph Breaks
So, using TORCH_LOGS="graph_breaks" python ../scripts/fp8/eager_compile_debug.py, we were getting graph breaks whenever we tried to construct fp8 tensors with the class method. I found that moving the construction to a function fixed the graph breaks, and we now have none for this script (a rough sketch of the change follows); see:
#131
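A rough illustration of the change; the real fix is in #131, and the classmethod name and Float8Tensor constructor arguments below are guesses for illustration only.

```python
import torch
from float8_experimental.float8_tensor import Float8Tensor  # path as it appears in the traceback below

# Before: constructing the subclass via a classmethod graph-broke under Dynamo, e.g.
#   y = Float8Tensor.from_float32(x, scale, torch.float8_e4m3fn)   # hypothetical name
# After (#131): routing construction through a plain function traces without breaks.
def to_float8(x: torch.Tensor, scale: torch.Tensor, float8_dtype: torch.dtype) -> Float8Tensor:
    bits_fp8 = (x * scale).to(float8_dtype)
    # Constructor arguments are a guess at the Float8Tensor signature (data, scale, orig_dtype).
    return Float8Tensor(bits_fp8, scale, x.dtype)
```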
Backend = "aot_eager"
With the fix removing all graph breaks, we now get a more helpful error message:
Traceback (most recent call last):
File "/home/drisspg/meta/float8_experimental/../scripts/fp8/eager_compile_debug.py", line 41, in <module>
main()
File "/home/drisspg/meta/float8_experimental/../scripts/fp8/eager_compile_debug.py", line 31, in main
y_fp8.sum().backward()
File "/home/drisspg/meta/pytorch/torch/_tensor.py", line 503, in backward
torch.autograd.backward(
File "/home/drisspg/meta/pytorch/torch/autograd/__init__.py", line 254, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/drisspg/meta/pytorch/torch/autograd/function.py", line 289, in apply
return user_fn(self, *args)
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4009, in backward
assert grad_output_types_ == CompiledFunction.metadata.output_types, f"""\
AssertionError: We incorrectly attempted to compile the backward with incorrect subclass metadata.
If you run into this error, please file an issue.
Expected grad_output types: [<class 'float8_experimental.float8_tensor.Float8Tensor'>]
Got grad_output types: [<class 'torch.Tensor'>]
I suspect this error is because matmul outputs a regular tensor and not a tensor subclass, and then during backward we have the autograd function that converts the gradient to the different fp8 format.
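A hedged sketch of that pattern (the class name and Float8Tensor constructor arguments are guesses, purely to illustrate why the compiled backward expects a Float8Tensor grad_output but receives a plain torch.Tensor):

```python
import torch
from float8_experimental.float8_tensor import Float8Tensor  # as in the traceback above

class NoopFwCastBw(torch.autograd.Function):
    # Sketch: the matmul path returns a regular tensor in forward, and only in
    # backward does an autograd.Function re-wrap the incoming gradient as an
    # e5m2 Float8Tensor, so the subclass metadata AOTAutograd recorded at
    # compile time disagrees with the types seen at runtime.
    @staticmethod
    def forward(ctx, tensor, grad_scale):
        ctx.save_for_backward(grad_scale)
        return tensor  # plain torch.Tensor, not a subclass

    @staticmethod
    def backward(ctx, grad_output):
        (grad_scale,) = ctx.saved_tensors
        bits = (grad_output * grad_scale).to(torch.float8_e5m2)
        return Float8Tensor(bits, grad_scale, grad_output.dtype), None  # constructor args are a guess
```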
Backend = "inductor"
With the tangle of PRs and changes above, and by not running backward on the subclass linear, I can actually compile with inductor!
However, it fails when the "high_precision" dtype is not float32. I suspect this is because we store amax in fp32 (needed for scaled_mm), and the inductor select_scatter lowering produces the following error (a workaround sketch follows the trace):
File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 289, in wrapped
out = decomp_fn(*args, **kwargs)
File "/home/drisspg/meta/pytorch/torch/_inductor/lowering.py", line 2219, in select_scatter
assert x.get_dtype() == src.get_dtype()
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AssertionError:
target: aten.select_scatter.default
args[0]: TensorBox(StorageBox(
InputBuffer(name='primals_3', layout=FixedLayout('cuda', torch.float32, size=[16], stride=[1]))
))
args[1]: TensorBox(StorageBox(
ComputedBuffer(name='buf0', layout=FlexibleLayout('cuda', torch.bfloat16, size=[], stride=[]), data=Reduction(
'cuda',
torch.bfloat16,
def inner_fn(index, rindex):
r0, r1 = rindex
tmp0 = ops.load(primals_1, r1 + 16 * r0)
tmp1 = ops.abs(tmp0)
return tmp1
,
ranges=[],
reduction_ranges=[16, 16],
reduction_type=max,
origin_node=max_1,
origins={max_1, abs_1}
))
))
args[2]: 0
args[3]: 0
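If the amax update in the module looks roughly like the sketch below, an explicit cast of the reduced max back to the buffer's dtype before the indexed write should sidestep the dtype assert in the select_scatter lowering. The function and buffer names here are illustrative, not the float8_experimental code.

```python
import torch

def update_amax_buffer(amax_buffer: torch.Tensor, x: torch.Tensor) -> None:
    # amax_buffer is kept in float32 (required for torch._scaled_mm), but x may
    # be bfloat16, so x.abs().max() is bfloat16; the indexed write below lowers
    # to aten.select_scatter, whose inductor lowering asserts matching dtypes.
    new_amax = x.abs().max()
    amax_buffer[0] = new_amax.to(amax_buffer.dtype)  # explicit cast avoids the assert
```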
Old error:
When attempting to compile for "aot_eager" with the above two fixes, we get:
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4230, in <listcomp>
return [convert(idx, x) for idx, x in enumerate(flat_args)]
File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4219, in convert
assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
AssertionError:
UPDATE: I was able to trigger a more helpful error message by iterating through the fake modes of the inner tensors:
https://gist.github.com/drisspg/ed916d144e819d7eb0be6728e0e807a7
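For reference, a rough sketch of that kind of check, assuming the subclass implements the traceable-subclass __tensor_flatten__ protocol; the helper name and exact flatten signature are assumptions, not the gist verbatim.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor

def report_inner_fake_modes(x: torch.Tensor, expected_fake_mode) -> None:
    # Walk the inner tensors of a traceable wrapper subclass and report which
    # one carries a fake_mode different from the expected one, instead of the
    # single all(...) assert in aot_autograd.
    attrs, _ = type(x).__tensor_flatten__(x)
    for attr in attrs:
        inner = getattr(x, attr)
        fake_mode = inner.fake_mode if isinstance(inner, FakeTensor) else None
        if fake_mode is not expected_fake_mode:
            print(f"{attr}: fake_mode {fake_mode} does not match expected {expected_fake_mode}")
```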