Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fp8 bwd #108

Draft
wants to merge 29 commits into
base: main_perf
Choose a base branch
from
Draft

fp8 bwd #108

wants to merge 29 commits into from

Conversation

micmelesse
Copy link
Collaborator

No description provided.

alexkranias-amd and others added 29 commits December 9, 2024 10:09
feat: added fp32 output to input_helper

passing

feat: fp8 tests. small amount of error

added fp8e5m2 type

note: RuntimeError: "abs_cuda" not implemented for 'Float8_e4m3fnuz'

enabled fp8 GEMMs

fix: error down to < 0.1

added another fp8 dtype

best accuracy is with no scaling

improved accuracy to within < 0.02. issue related to torch side casting

fix: passes if we allow v to be fp16 instead of fp8. otherwise we have error < 0.1

all error is < 0.07

feat: added per head scaling tensors

progress towards implementing scaling tensors in kernel

save

issue: error caused by acc += tl.dot(p.to(v.type.element_ty), v)
Error:
UnboundLocalError: local variable 'q_scale_stride_z' referenced before
assignment.

Fix:
Initialize 'q_scale_stride_z' and 'kv_scale_stride_z' before
assignment.
Warning: I don't know if this is the correct thing to do.
Warning - 2 test cases are failing due to this change:
AssertionError: Tensor-likes are not close!

FAILED test.py::test_op_prefill_fwd_impl[False-dtype1-True-bshd-0.0-False-4-6-6-1024-1023-32]
Mismatched elements: 1 / 786432 (0.0%)
Greatest absolute difference: 0.14855387806892395 at index (0, 309, 2, 18) (up to 0.1009 allowed)
Greatest relative difference: 0.28865116834640503 at index (0, 309, 2, 18) (up to 0.09128 allowed)

FAILED test.py::test_op_prefill_fwd_impl[False-dtype1-False-bshd-0.0-False-4-6-6-1024-1023-32]
Mismatched elements: 1 / 786432 (0.0%)
Greatest absolute difference: 0.14855387806892395 at index (0, 309, 2, 18) (up to 0.1009 allowed)
Greatest relative difference: 0.28865116834640503 at index (0, 309, 2, 18) (up to 0.09128 allowed)
Two tests are still failling.
* Do not track gradients for scale factors.
* Handle maximum absolute value equals to zero in per batch / head
  scaling method.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants