int8wo Embedding Quant #1167
Conversation
HDCharles commented on Oct 25, 2024 (edited)
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1167
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit e251e67 with merge base 4b563f2.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Force-pushed: 573f863 → 3149b5b
Force-pushed: 3149b5b → e9205db
Force-pushed: e9205db → cee8235
idx = args[0]
int_data, scale, zero_point = args[1].tensor_impl.get_plain()
# only the default embedding kwargs are supported on this path
assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None
assert not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"] == 2.0
sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx]
Are there any restrictions on idx for this to be valid?
not as far as our tests show
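The same index tensor is applied to each plain component (int_data, scale, zero_point), so the slice stays consistent for any idx shape, including repeated rows. A rough sketch of that invariant, assuming per-row scales for simplicity (the PR's actual layout is groupwise with group_size=64, so only the shapes differ):

import torch

vocab, dim = 32, 8
int_data = torch.randint(-128, 128, (vocab, dim), dtype=torch.int8)
scale = torch.rand(vocab, 1)            # one scale per row (simplifying assumption)
zero_point = torch.zeros(vocab, 1)      # symmetric quantization -> zero offset

idx = torch.tensor([[1, 5, 5], [0, 31, 2]])   # arbitrary shape and repeated rows are fine
# Indexing each component with the same idx and then dequantizing matches
# dequantizing the whole table first and then indexing.
looked_up = (int_data[idx].to(torch.float32) - zero_point[idx]) * scale[idx]
reference = ((int_data.to(torch.float32) - zero_point) * scale)[idx]
assert torch.equal(looked_up, reference)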
Summary:
Added int8 embedding quant to torchAO; speeds up inference on our llama benchmark from 107.8 -> 108.5 tok/s on A100.

Expected API:
quantize_(model, int8_weight_only(group_size=64), filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding))

Test Plan:
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --quantization embed-int8wo --compile
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile
python test_integration.py -k "test_weight_only_groupwise_embedding_quant"

Reviewers:

Subscribers:

Tasks:

Tags:
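A minimal sketch of the usage described in the summary above; the quantize_ / int8_weight_only call and the filter_fn are taken from the summary, while the toy model and its sizes are illustrative:

import torch
from torchao.quantization import quantize_, int8_weight_only

# Toy model containing an embedding layer (sizes are illustrative).
model = torch.nn.Sequential(
    torch.nn.Embedding(1024, 64),
    torch.nn.Linear(64, 64),
)

# Apply int8 weight-only quantization to the embedding layers only, via filter_fn.
quantize_(
    model,
    int8_weight_only(group_size=64),
    filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding),
)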
Force-pushed: cee8235 → e251e67