Transformers save/load compatibility and inference kernels #3
Conversation
src/aq.py (outdated)

```python
def forward(self, input: torch.Tensor) -> torch.Tensor:
    return F.linear(input, self.reconstruct_weight(), self.bias)
    # original_shape = input.shape
```
Commented-out code; do we need it?
Removed
src/aq.py (outdated)

```diff
@@ -153,7 +209,7 @@ def get_scales(self) -> torch.Tensor:
     else:  # train scale codebook only
         return self.scales_clusters.gather(1, self.scales_indices)[:, :, None, None]

-    def forward(self, selection: Union[slice, ellipsis, torch.Tensor] = ...):
+    def reconstruct_weight(self, selection: Union[slice, ellipsis, torch.Tensor] = ...):
```
I'd request that we keep it as "forward" for notebook and experiment compatibility.
Alternatively:
`def forward(self, *args, **kwargs): return self.reconstruct_weight(*args, **kwargs)`?
It would make sense for this layer to behave like a normal nn.Linear layer: that is, to actually perform a forward pass on forward(). That way we won't have to change much code when replacing nn.Linear modules with it.
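For illustration, a minimal sketch of why matching nn.Linear's interface pays off: a generic replacement pass needs no knowledge of the surrounding model. The helper and the `make_quantized` factory here are hypothetical, not part of this PR.

```python
import torch.nn as nn

def replace_linears(module: nn.Module, make_quantized) -> None:
    """Recursively swap every nn.Linear in a model for its quantized drop-in."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_quantized(child))
        else:
            replace_linears(child, make_quantized)
```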
Ah, I get it now. Maybe it's best to move that to the QuantizedLinear class and keep QuantizedWeight as is?
Done, more or less.
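Roughly, the resulting split looks like this. This is only a sketch: `reconstruct_weight` is stubbed as a stored dense tensor, whereas the real QuantizedWeight decodes codes and codebooks on demand.

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class QuantizedWeight(nn.Module):
    """Owns the compressed representation; calling it returns the dequantized weight."""
    def __init__(self, reference_weight: torch.Tensor):
        super().__init__()
        # stand-in for codes + codebooks; the real class reconstructs from them
        self.register_buffer("dense_stub", reference_weight.clone())

    def reconstruct_weight(self, selection=...) -> torch.Tensor:
        return self.dense_stub[selection]

    forward = reconstruct_weight  # calling QuantizedWeight() yields the weight tensor

class QuantizedLinear(nn.Module):
    """Behaves like nn.Linear: reconstruct the weight, then apply F.linear."""
    def __init__(self, quantized_weight: QuantizedWeight, bias: Optional[torch.Tensor] = None):
        super().__init__()
        self.quantized_weight = quantized_weight
        self.bias = bias

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.quantized_weight(), self.bias)
```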
Awesome work!
While we wait for other reviews, can you please run an experiment on this branch to check that everything works and achieves roughly the same perplexity?
It looks awesome, really! I wonder if you have considered extending your work to the case where there can be multiple sets of lookups for one weight tensor. I mean having something like two additive 8-bit lookups for a tile of 32x4096 weights. The idea is to make it more HW-friendly, so that both lookups can be stored in the GPU's shared memory and weights can be quickly prefetched before the MatMul. An 8-bit + 8-bit = 16-bit scheme to represent 8 weights should also be HW-friendly from both the unpacking and storage points of view. What do you think?
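For concreteness, a rough sketch of the decode path this suggests, assuming two 256-entry codebooks of 8-element vectors per tile (all shapes are illustrative assumptions):

```python
import torch

# illustrative shapes: a 32x4096 tile, two 8-bit codes per group of 8 weights
tile_rows, tile_cols, group = 32, 4096, 8
codebook_a = torch.randn(256, group)  # small enough for GPU shared memory
codebook_b = torch.randn(256, group)
codes_a = torch.randint(256, (tile_rows, tile_cols // group))  # 8 bits per group
codes_b = torch.randint(256, (tile_rows, tile_cols // group))  # 8 more bits

# additive reconstruction: every group of 8 weights is the sum of two lookups
tile = (codebook_a[codes_a] + codebook_b[codes_b]).reshape(tile_rows, tile_cols)
```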
@AlexKoff88 Thanks!
```python
from aqlm.utils import get_int_dtype


class QuantizedLinear(nn.Module):
```
Please create an issue, for later, to examine and potentially deduplicate the code here, with the full understanding that we're about as likely to fix it as to never open it.
To save you time, here's one possible issue text:
> In the current version, we're duplicating some of the code between `src` and `inference_lib/src`. For instance, `inference_lib/src/inference.py:QuantizedLinear` resembles `src/aq.py:QuantizedLinear`. If we have time, it would be nice to selectively merge some of them.
My idea was that the inference code should be completely separated from the quantisation code, both so as not to break the latter and because there isn't much overlap anyway. Those two classes serve very different purposes and share surprisingly little code.
I agree with @BlackSamorez on this one.
```diff
@@ -0,0 +1,61 @@
+#include <torch/all.h>
```
Strong opinion: we need compilation instructions, OR a promise that you'll add them by a fixed deadline.
If I missed the instructions somewhere, please direct me to them.
It compiles at runtime here; no deliberate compilation steps are needed.
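For readers unfamiliar with the mechanism: PyTorch's JIT extension loader compiles C++/CUDA sources on first use and caches the build, so no separate build step is required. A minimal sketch; the module name and source paths below are placeholders, not this PR's actual files:

```python
from torch.utils.cpp_extension import load

# compiles and caches the extension on first call; later calls reuse the build
aqlm_ops = load(
    name="aqlm_cuda_kernel",                        # placeholder module name
    sources=["cuda_kernel.cpp", "cuda_kernel.cu"],  # placeholder source paths
    verbose=True,
)
```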
LGTM.
In the future, it would be nice to minimize the amount of code we copy from transformers, but that can wait.
You did a gargantuan amount of work here :)
Co-authored-by: justheuristic <justheuristic@gmail.com>
Nice work!
I think the last major thing missing is instructions in the README.md. It would be nice to have there:
- a list of models on HF
- what to install and how
- instructions on how to interact with the code (referring to the Colab notebook)
README.md (outdated)

```markdown
| Mixtral-8x7b | 1x15 | 4.61 | 12.6 | [Link](https://huggingface.co/BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x15-hf) |

### Dependencies
```
Minor: I'd still call this "Installation".
Why: there's another "Dependencies" section later, which could cause confusion.
done
This PR:
- merges `QuantizedWeight` and `QuantizedLinear` into one class
- supports `QuantizedLinear` creation with empty weights and `QuantizedLinear` initialisation with KMeans
- adds `modeling_llama.py`, allowing the saved state dict to be loaded with `load_pretrained`
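A hedged sketch of the save/load round trip this enables, using plain PyTorch state-dict handling for illustration; the PR's actual `load_pretrained` entry point may differ:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 16))  # stand-in for the quantized model

torch.save(model.state_dict(), "llama_aqlm.pt")         # save after quantization
state = torch.load("llama_aqlm.pt", map_location="cpu")
model.load_state_dict(state)                            # weights restored in place
```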