This repository was archived by the owner on Aug 7, 2024. It is now read-only.

[not for land] move scaled matmul logic to __torch_dispatch__ #28

Closed
wants to merge 52 commits

Conversation

vkuzo
Contributor

@vkuzo vkuzo commented Aug 11, 2023

Summary:

This is a quick prototype of what the float8 matmul logic would look like if we moved it from a torch.autograd.Function subclass to Float8Tensor.__torch_dispatch__.

Pros:

  • if we go with this UX, we can eventually remove Float8Linear. Note that the casts will have to be moved to module hooks.
  • we get the reshape logic (translating from torch.matmul to torch.mm) for free instead of having to reimplement it

Cons:

  • the code is (arguably) harder to follow than a familiar torch.autograd.Function which handles the state management
  • because of delayed scaling, we need to attach buffer references to Float8Tensor objects; this works but deviates from how people usually use __torch_dispatch__ (see the sketch below)

At this point I do not expect to land this any time soon, as keeping the code as is seems simpler. Putting this up in case we want to discuss or revisit this later.
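
For discussion, here is a minimal sketch of the shape this could take. This is not the code in this PR; the attribute names (`_data`, `_scale`, `_amax_buffer`) and the dequantize-and-mm body are illustrative assumptions.

```python
import torch

class Float8TensorSketch(torch.Tensor):
    """Illustrative wrapper subclass; the real Float8Tensor carries more state."""

    @staticmethod
    def __new__(cls, data, scale, orig_dtype):
        # Outer tensor advertises the original precision; `_data` holds the
        # raw float8 payload and `_scale` the scale used to produce it.
        self = torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=orig_dtype, device=data.device
        )
        self._data = data
        self._scale = scale
        # Reference to an amax buffer that delayed scaling needs to update
        # from inside dispatch; this is the awkward part mentioned in the cons.
        self._amax_buffer = None
        return self

    def dequantize(self):
        return self._data.to(self.dtype) / self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.mm.default:
            x, w = args
            x_hp = x.dequantize() if isinstance(x, cls) else x
            w_hp = w.dequantize() if isinstance(w, cls) else w
            # Record the amax needed by delayed scaling, then do the matmul
            # (shown as dequantize-and-mm; a real kernel would stay in fp8).
            if isinstance(x, cls) and x._amax_buffer is not None:
                x._amax_buffer.fill_(x_hp.abs().max())
            return torch.mm(x_hp, w_hp)
        raise NotImplementedError(f"{func} is not handled in this sketch")
```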

Test Plan:

python tests/test.py -k test_linear

Reviewers:

Subscribers:

Tasks:

Tags:

vkuzo and others added 30 commits July 19, 2023 15:32
Summary:

This is a copy of
facebookexperimental/protoquant#23

Many things will change based on recent discussions!

Test Plan:

```
python float8_playground/test.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Python-only float8 data type + bare bones UEX
Summary:

skipped it before, going back to it now

this will be useful for transformer block

Test Plan:

```
python float8_playground/test.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
add bias support to float8linear
Summary:

forgot on last PR

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Now that pytorch/pytorch#104242 landed, we can
stop emulation - this simplifies the code quite a bit.

Test Plan:

```
python float8_playground/test.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
switch from emulated to real float8 dtypes
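
As context, a minimal example of the kind of cast this enables; the scaled round-trip below is illustrative, not this repo's API.

```python
import torch

# With native float8 dtypes, a scaled cast replaces the previous emulation.
x = torch.randn(16, 32)
scale = torch.finfo(torch.float8_e4m3fn).max / x.abs().max()
x_fp8 = (x * scale).to(torch.float8_e4m3fn)    # quantize to real float8 bits
x_roundtrip = x_fp8.to(torch.float32) / scale  # dequantize to compare
print((x - x_roundtrip).abs().max())
```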
Summary:

Adds a check that using `Float8Linear` on a `SAM` encoder
results in reasonable spot check accuracy.

Note that grad accuracy is not tested as it is all over the place; this is
probably expected, but I'm saving the investigation for later. Specifically,
the layernorm and positional encoding grads have a large error in the fp8
version vs the reference.

Test Plan:

```
python float8_playground/test.py
with-proxy python float8_playground/test_sam.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
add numerical test on SAM encoder
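
For reference, a spot check of this kind is often expressed as an SQNR comparison; the helper and threshold below are assumptions, not the actual test in this repo.

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means closer to the reference.
    return 20 * torch.log10(ref.norm() / (ref - actual).norm())

# Hypothetical usage against a SAM encoder output:
#   assert sqnr_db(ref_out, fp8_out) > 20.0
```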
Summary:

Don't test CPU for now to maximize dev speed; we should add that back
later.

Requires pytorch/pytorch#105807

Test Plan:

```
python float8_playground/test.py
python float8_playground/test_sam.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Just having a copy of this in a script instead of a notebook is useful.

Test Plan:

```
python te_examples/quickstart_guide.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
add a repro of TransformerEngine's quickstart guide
Summary:

Creates a simple image classification network and finetunes it
on MNIST.  Baseline is fp32, and fp8 training can be enabled with
a flag.  Verified that fp8 training converges on this simple example.
Note that fp8 compute is emulated for now as we don't have a hookup
to the real fp8 matmul kernel yet.

Test Plan:

```
with-proxy python finetune/mnist.py --batch-size 4096
// https://gist.github.com/vkuzo/0e8cbb3df1f0610e528ac3ad15da3ace
with-proxy python finetune/mnist.py --batch-size 4096 --use-pt-fp
// https://gist.github.com/vkuzo/99b0cf2c1492a5f605c9f028f12340c3
```

Reviewers:

Subscribers:

Tasks:

Tags:
add simple finetuning check with fp8 emulation
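
A sketch of what the fp8 flag could toggle, i.e. swapping `nn.Linear` modules for their float8 counterpart; the `from_float` constructor is an assumption about the API, not necessarily what this repo exposes.

```python
import copy
import torch.nn as nn

def swap_linear_with_float8(model: nn.Module, float8_linear_cls) -> nn.Module:
    # Return a copy of `model` with every nn.Linear replaced by the float8
    # version; `float8_linear_cls.from_float` is assumed to exist.
    model = copy.deepcopy(model)

    def _swap(module: nn.Module) -> None:
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                setattr(module, name, float8_linear_cls.from_float(child))
            else:
                _swap(child)

    _swap(model)
    return model
```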
Summary:

Adds a test for numerical equivalence of single GPU vs FSDP for a toy
model.

Note: this is not related to fp8 yet, a future PR will add a test that
this still holds for fp8.

Test Plan:

```
./float8_playground/test_fsdp.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

```
python tests/test.py
with-proxy python tests/test_sam.py
./tests/test_fsdp.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
refactor tests into separate dir
Summary:

Note that we can't match the weight gradient and its scale
because of gradient reduction.

Test Plan:

```
./tests/test_fsdp.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
clean up float8 differentiable constructor
Summary:

before: Float8Tensor always converted back to float32
after: Float8Tensor remembers original dtype

this will be useful for autocast support

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
Make Float8Tensor remember original precision
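
A minimal sketch of the change (attribute names are illustrative):

```python
import torch

class Float8Payload:
    # Before: the original dtype was not stored, so converting back always
    # produced float32. After: it is captured at construction time, which is
    # what autocast support needs.
    def __init__(self, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        self._data = data              # raw float8 payload
        self._scale = scale
        self._orig_dtype = orig_dtype  # remembered original precision

    def to_original_precision(self) -> torch.Tensor:
        return self._data.to(self._orig_dtype) / self._scale
```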
Summary:

Adding grads and casting back to fp8 doesn't have a clear
use case yet; we can add this back if needed. For now, simplify.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
make grad addition happen in original precision
Summary:

We just duplicate the autocast logic for F.linear so that Float8Linear
does the right thing.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo and others added 22 commits August 4, 2023 14:28
enable autocast support for single GPU
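
A sketch of the duplicated autocast handling: a simplified approximation of what autocast does for F.linear, assuming the module checks the autocast state in its own forward.

```python
import torch

def maybe_apply_autocast(x: torch.Tensor) -> torch.Tensor:
    # Float8Linear bypasses the normal F.linear autocast path, so mirror it
    # manually: cast the input to the autocast dtype (e.g. bfloat16) before
    # the float8 cast happens.
    if torch.is_autocast_enabled():
        autocast_dtype = torch.get_autocast_gpu_dtype()
        return x.to(autocast_dtype)
    return x
```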
Summary:

Before: scale buffers were stored, amax calculation was hidden
After: amax buffers are stored, scale calculation is stateless

This is a refactor to make it easier to enable delayed scaling.
No functionality change in this PR, just a refactor.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
switch from scale buffers to amax buffers
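
The stateless scale calculation described above amounts to something like the following (a sketch; margin/clamping details are assumptions):

```python
import torch

def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # Only the amax is stored in a buffer; the scale is derived on the fly so
    # that the largest observed magnitude maps to the float8 dtype's max value.
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
```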
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Before: all scaling was done just-in-time
After:
1. scaling is done in a delayed fashion with a history of 1
2. there is special logic to populate initial amaxes (TE doesn't have
this)

A future PR will add windowed calculation

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
switch from just-in-time scaling to delayed scaling
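
A sketch of the delayed-scaling flow with a history of 1 (buffer handling and clamping are illustrative):

```python
import torch

def delayed_scaling_cast(x: torch.Tensor, amax_buffer: torch.Tensor,
                         float8_dtype=torch.float8_e4m3fn):
    # Use the amax recorded on the previous iteration to derive this
    # iteration's scale, then record the current amax for the next one.
    # On the very first iteration the buffer has to be seeded, which is the
    # "initial amax" logic mentioned above.
    fp8_max = torch.finfo(float8_dtype).max
    scale = fp8_max / torch.clamp(amax_buffer, min=1e-12)
    x_fp8 = (x * scale).clamp(-fp8_max, fp8_max).to(float8_dtype)
    amax_buffer.fill_(x.abs().max())
    return x_fp8, scale
```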
Summary:

this simplifies some of the code in `Float8Linear`, which is nice
for keeping things readable

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
make Float8Tensor reshape'able and t'able
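
In dispatch terms this means handling view-style ops on the wrapper, roughly like the sketch below (the op coverage and the constructor signature are assumptions):

```python
import torch

# Inside Float8Tensor.__torch_dispatch__: forward view-style ops to the inner
# float8 payload and re-wrap, so the tensor can be reshaped/transposed without
# leaving float8.
VIEW_LIKE_OPS = {
    torch.ops.aten.view.default,
    torch.ops.aten.t.default,
}

def handle_view_like(cls, func, args, kwargs):
    self, *rest = args
    new_data = func(self._data, *rest, **(kwargs or {}))
    # Same scale and original dtype, new shape/strides.
    return cls(new_data, self._scale, self.dtype)
```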
Summary:

we can use `__torch_dispatch__` to convert to original precision
automatically

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
simplify reference output calculation
Summary:

Adds a convenience Python API to wrap the aten float8 API.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
add python api to wrap aten api
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

Bias in fp8 is not supported according to
https://docs.nvidia.com/cuda/cublas/#id99, so remove it from
this codebase to simplify.

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

InfiniBand used to work but no longer does; work around it for now.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Summary:

This is cleaner and will enable an easier test of moving
`float8_linear` logic to `__torch_dispatch__`

Test Plan:

```
with-proxy ./tests/test_everything.sh
```

Reviewers:

Subscribers:

Tasks:

Tags:
move fp8 cast of `grad_out` out to `Float8Linear`
Summary:

Test Plan:

```
python tests/test.py -k test_linear
```

Reviewers:

Subscribers:

Tasks:

Tags:
@facebook-github-bot facebook-github-bot added the CLA Signed label Aug 11, 2023
@vkuzo
Contributor Author

vkuzo commented Aug 11, 2023

cc @albanD, @drisspg, @bdhirsh: this is a quick prototype of a more torch_dispatch'ey version of the Float8 code. @albanD and I discussed an earlier version of this offline, and for now I'm planning to stick with torch.autograd.Function.

@vkuzo vkuzo changed the title [wip] move scaled matmul logic to __torch_dispatch__ [not for land] move scaled matmul logic to __torch_dispatch__ Aug 11, 2023
@drisspg drisspg mentioned this pull request Oct 21, 2023
2 tasks
facebook-github-bot pushed a commit that referenced this pull request Nov 1, 2023
Summary:
We use the dispatching mechanism for mm.

## TODO
- [x] Hook the amax_buffer onto float8_tensor so it can be filled under dispatch
- [x] Update emulate path

# Note
Vasiliy has already started this here:
#28

Some things have changed since then, though: we now output in higher precision by default. However, I still need to replicate the amax_buffer filling here and store it on the float8_tensor passed in.

Corresponding core changes to get as far as possible in compile for aot_eager
pytorch/pytorch#111735

``` Shell
Checking against fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13cd1bd0>
attr=_data
attr_fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13c34d00>
attr=_scale
attr_fake_mode=<torch._subclasses.fake_tensor.FakeTensorMode object at 0x7f4c13c34d00>
```

### Current Compile Progress
- backend = "eager_only", full_graph = False: ✅
- backend = "eager_only", full_graph = False: ❌
``` Shell
E       torch._dynamo.exc.Unsupported: call_function UserDefinedObjectVariable(to_float8) [TensorVariable(), TensorVariable(), ConstantVariable(dtype), TensorVariable()] {}
```
- backend = "aot_eager", full_graph = False: ❌
``` Shell
  File "/home/drisspg/meta/pytorch/torch/_functorch/aot_autograd.py", line 4187, in convert
    assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
torch._dynamo.exc.BackendCompilerFailed: backend='aot_eager' raised:
AssertionError:
```

Pull Request resolved: #128

Reviewed By: bdhirsh, y-sq

Differential Revision: D50901900

Pulled By: drisspg

fbshipit-source-id: 64626bc652b70bfbabff2ab26e999324d1463e1d
@drisspg drisspg closed this Nov 16, 2023