ClipQKV #197
Conversation
Gonna leave this to others to review + approve, but could I suggest adding some unit tests to check equality between the torch/causal/triton attention blocks?
Co-authored-by: Abhi Venigalla <77638579+abhi-mosaic@users.noreply.github.com>
created issue: https://mosaicml.atlassian.net/browse/RESEARCH-468
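The suggestion above is to add unit tests checking equality between the attention implementations. A minimal sketch of such a test, using two toy implementations of scaled dot-product attention as stand-ins (in the actual repo this would compare the torch/causal/triton attention blocks; all names here are illustrative):

```python
import torch


def attn_ref(q, k, v):
    # Straightforward scaled dot-product attention as the reference.
    scale = q.size(-1) ** -0.5
    scores = torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)
    return scores @ v


def attn_einsum(q, k, v):
    # The same math written with einsum, standing in for a second implementation.
    scale = q.size(-1) ** -0.5
    w = torch.softmax(torch.einsum("bqd,bkd->bqk", q, k) * scale, dim=-1)
    return torch.einsum("bqk,bkd->bqd", w, v)


def test_attn_impls_match():
    # Feed both implementations the same random inputs and compare outputs.
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 8) for _ in range(3))
    assert torch.allclose(attn_ref(q, k, v), attn_einsum(q, k, v), atol=1e-6)
```

A real version of this test would parametrize over `attn_impl` and, per the discussion below, be gated to GPU (e.g. skipped or xfailed on CPU) for the triton variant.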
LGTM
nvm, I misread the output. Looks like your added tests are failing.
Yeah, they're GPU tests. I need to xfail them on CPU. Forgot to do that...
LGTM after gating the tests to be on GPU only
Force-pushed from 1b02a56 to 7b6f3ae
Force-pushed from 49392da to 62d9749
Looks good!
This PR implements ClipQKV and adds `qk_ln` support in the triton attn variant.

Findings:
- On its own, Alibi does not fully solve the stability issue; when clipping is added to QKV, training becomes more stable.
- Without Alibi, clipping QKV is still very stable (it just doesn't perform as well).
- I tried clipping values of {0.1, 1, 2, 10}; {0.1, 1} were too aggressive; {2, 10} are shown. We can see that with the higher clipping value (10) the network is slightly less stable and gets pseudo-loss spikes, from which it recovers gracefully.
- ClipQKV+Alibi outperforms QKLN.
- ClipQKV+Alibi has exactly the same performance as QKLN+Alibi.
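The clipping described above can be sketched as an elementwise clamp on the fused QKV activations. This is a hedged sketch, not the PR's actual code; `clip_qkv` and `clip_val` are illustrative names:

```python
import torch


def clip_qkv(qkv: torch.Tensor, clip_val: float = 2.0) -> torch.Tensor:
    """Clamp query/key/value activations to [-clip_val, clip_val] elementwise.

    Illustrative sketch: bounds the QKV projection outputs before attention,
    which is the mechanism the experiments above vary over {0.1, 1, 2, 10}.
    """
    return qkv.clamp(min=-clip_val, max=clip_val)


# Example: outliers beyond +/-2 are squashed to the clip boundary,
# while in-range values pass through unchanged.
qkv = torch.tensor([[-10.0, 0.5, 3.0]])
clipped = clip_qkv(qkv, clip_val=2.0)  # -> [[-2.0, 0.5, 2.0]]
```

A tighter `clip_val` bounds the attention logits more aggressively, which matches the observation above that very small values (0.1, 1) over-constrain the model while a loose value (10) permits pseudo-loss spikes.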