
Transformers save/load compatibility and inference kernels #3

Merged

merged 111 commits into main from transformers on Feb 7, 2024

Conversation

BlackSamorez
Collaborator

@BlackSamorez BlackSamorez commented Jan 16, 2024

This PR:

  • Merges QuantizedWeight and QuantizedLinear into one class
  • Decouples QuantizedLinear creation (with empty weights) from its initialisation with KMeans
  • Changes saving to save only the final state dict
  • Adds a custom modeling_llama.py, allowing the saved state dict to be loaded with load_pretrained (see the usage sketch below)
  • Adds a conversion script for previously saved models
  • Adds a custom matmul kernel written in Triton (maybe separate it into a different PR)
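For illustration, here is roughly what the save/load flow could look like from the user's side once this lands. This is a sketch only, assuming the standard transformers `from_pretrained` entry point with `trust_remote_code`; the repository name is taken from the README table further down in this PR, purely as an example.

```python
# Sketch only: loading an AQLM-quantized checkpoint through the transformers API.
# trust_remote_code lets transformers pick up the custom modeling code shipped
# with the checkpoint; the repository name is taken from the README table below.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x15-hf"  # example repo from the README

tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(
    repo,
    trust_remote_code=True,   # load the custom modeling code from the repo
    torch_dtype=torch.float16,
    device_map="auto",
)

inputs = tokenizer("AQLM is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```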

@BlackSamorez force-pushed the transformers branch 2 times, most recently from f366016 to f498eaf on January 16, 2024 at 11:49
src/aq.py Outdated

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.reconstruct_weight(), self.bias)
        # original_shape = input.shape
Collaborator

commented code; do we need it?

Collaborator Author

Removed

src/aq.py Outdated
@@ -153,7 +209,7 @@ def get_scales(self) -> torch.Tensor:
         else:  # train scale codebook only
             return self.scales_clusters.gather(1, self.scales_indices)[:, :, None, None]

-    def forward(self, selection: Union[slice, ellipsis, torch.Tensor] = ...):
+    def reconstruct_weight(self, selection: Union[slice, ellipsis, torch.Tensor] = ...):
Collaborator

I'd request that we keep it as "forward" for notebook and experiment compatibility.
Alternatively, would something like this work?

def forward(self, *args, **kwargs): return self.reconstruct_weight(*args, **kwargs)

Collaborator Author

It would make sense for this layer to behave like a normal nn.Linear layer, that is, to actually perform a forward pass on forward(). That way we won't have to change much code when replacing nn.Linear instances with it.

Collaborator

Ah, I get it now. Maybe it's best to move that to the QuantizedLinear class and keep QuantizedWeight as is?

Collaborator Author

Done, more or less.
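For reference, a minimal sketch of the split settled on above: QuantizedWeight keeps a forward that reconstructs the dequantized weight, while QuantizedLinear behaves like nn.Linear. The reconstruction below is a simplified single-codebook lookup, not AQLM's actual additive multi-codebook scheme, and all constructor arguments are assumptions.

```python
# Sketch of the QuantizedWeight / QuantizedLinear split discussed above.
# Simplified single-codebook reconstruction; not the actual AQLM implementation.
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizedWeight(nn.Module):
    """Holds codes and a codebook; reconstructs the dense weight on demand."""

    def __init__(self, codes: torch.Tensor, codebook: torch.Tensor):
        super().__init__()
        # codes: (out_features, in_features // group_size) integer indices
        # codebook: (codebook_size, group_size) learnable entries
        self.register_buffer("codes", codes)
        self.codebook = nn.Parameter(codebook)

    def reconstruct_weight(self) -> torch.Tensor:
        out_features, num_groups = self.codes.shape
        group_size = self.codebook.shape[1]
        # Look up each group's codebook entry and lay the groups out densely.
        return self.codebook[self.codes.long()].reshape(out_features, num_groups * group_size)

    def forward(self) -> torch.Tensor:
        # Kept as `forward` for notebook/experiment compatibility.
        return self.reconstruct_weight()


class QuantizedLinear(nn.Module):
    """Acts like nn.Linear so callers don't change when it replaces one."""

    def __init__(self, quantized_weight: QuantizedWeight, bias: Optional[torch.Tensor] = None):
        super().__init__()
        self.quantized_weight = quantized_weight
        self.bias = nn.Parameter(bias) if bias is not None else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.quantized_weight(), self.bias)
```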

Collaborator

@justheuristic left a comment

Awesome work!

While we wait for other reviews, can you please run an experiment on this branch to check that everything works and achieves roughly the same perplexity?

@BlackSamorez force-pushed the transformers branch 2 times, most recently from 2f8a838 to 0647577 on January 16, 2024 at 18:32
@BlackSamorez changed the title from "[WIP] Transformers save/load compatibility and inference kernels" to "Transformers save/load compatibility and inference kernels" on Jan 20, 2024
@AlexKoff88

It looks really awesome!

I wonder if you guys have considered extending your work to the case where there can be multiple sets of lookup tables for one weight tensor. I mean having something like two additive 8-bit lookups for a tile of 32x4096 weights. The idea is to make it more HW-friendly, so that both lookup tables can be stored in the shared memory of the GPU and the weights can be quickly prefetched before the MatMul. An 8-bit + 8-bit = 16-bit scheme to represent 8 weights should also be HW-friendly from both the unpacking and storage points of view.

What do you think?
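A toy illustration of the proposed layout (all shapes follow the numbers in the comment above; nothing here is implemented in this PR):

```python
# Toy illustration of the proposed scheme: two additive 8-bit lookup tables per
# weight tile, so 8 + 8 = 16 bits of indices encode each group of 8 weights.
import torch

out_features, in_features, group_size = 32, 4096, 8   # one 32x4096 tile
num_groups = in_features // group_size

# Two codebooks of 256 entries each: small enough to sit in GPU shared memory.
codebook_a = torch.randn(256, group_size)
codebook_b = torch.randn(256, group_size)

# One 8-bit index per codebook per group of weights.
codes_a = torch.randint(0, 256, (out_features, num_groups), dtype=torch.uint8)
codes_b = torch.randint(0, 256, (out_features, num_groups), dtype=torch.uint8)

# Dequantization: each group is the sum of one entry from each codebook.
weight = (codebook_a[codes_a.long()] + codebook_b[codes_b.long()]).reshape(
    out_features, in_features
)
print(weight.shape)  # torch.Size([32, 4096])
```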

@BlackSamorez
Collaborator Author

@AlexKoff88 Thanks!
We're working on optimizing inference for both GPUs and CPUs, and indeed, having smaller codebooks has the potential to greatly improve performance. We'll make sure to publish the code once we have reliable results in terms of both model compression quality and inference speed, so stay tuned!

from aqlm.utils import get_int_dtype


class QuantizedLinear(nn.Module):
Collaborator

Please create an issue, for later, to examine and potentially deduplicate the code here,

with the full understanding that we're about as likely to fix it as we are to never open that issue at all.

To save you time, here's one possible issue text:

In the current version, some code is duplicated between `src` and `inference_lib/src`. For instance, inference_lib/src/inference.py:QuantizedLinear resembles src/aq.py:QuantizedLinear.

If we have time, it would be nice to selectively merge some of them.

Collaborator Author

My idea was that the inference code should be completely separated from the quantisation code, so as not to break the latter, and because there isn't much overlap anyway. Those two classes serve very different purposes and share surprisingly little code.

Owner

I agree with @BlackSamorez on this one.

@@ -0,0 +1,61 @@
#include <torch/all.h>
Collaborator

Strong opinion: we need compilation instructions OR a promise that you'll add them with a fixed deadline.

If I missed the instructions somewhere, please direct me to them.

Collaborator Author

It compiles at runtime here; no deliberate compilation steps are needed.
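For context, PyTorch extensions can be JIT-compiled the first time they are imported. A minimal sketch of that mechanism follows; the file and extension names are placeholders, not necessarily what this PR uses.

```python
# Sketch of runtime (JIT) compilation of a CUDA extension with PyTorch.
# File and extension names are placeholders; the PR may wire this up differently.
import os
from torch.utils.cpp_extension import load

CUDA_FOLDER = os.path.dirname(os.path.abspath(__file__))

cuda_kernel = load(
    name="codebook_cuda",
    sources=[
        os.path.join(CUDA_FOLDER, "cuda_kernel.cpp"),
        os.path.join(CUDA_FOLDER, "cuda_kernel.cu"),
    ],
    verbose=True,  # prints the ninja/nvcc invocation on the first build
)
# The build is cached (by default under ~/.cache/torch_extensions), so later
# imports reuse the compiled artifact instead of recompiling.
```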

Collaborator

@justheuristic left a comment

LGTM.

In the future, it would be nice to minimize the amount of code we copy from transformers, but that can wait.

You did a gargantuan amount of work here :)

BlackSamorez and others added 13 commits February 6, 2024 22:50
Co-authored-by: justheuristic <justheuristic@gmail.com>
Owner

@Vahe1994 left a comment

Nice work!
I think the last major thing missing is instructions in README.md.
It would be nice to have there:

  1. A list of models on HF
  2. What to install and how
  3. Instructions on how to interact with the code (referring to the Colab notebook)

README.md Outdated
| Mixtral-8x7b| 1x15 | 4.61 | 12.6 | [Link](https://huggingface.co/BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x15-hf)|


### Dependencies
Collaborator

@justheuristic Feb 7, 2024

Minor: I'd still call this "Installation".

Why: there's another "Dependencies" section later, which could cause confusion.

Collaborator Author

done

@BlackSamorez merged commit e1292e2 into main on Feb 7, 2024
2 checks passed