
Script to convert Grok-1 weights from raw JAX pickle files. #7058

Open · wants to merge 13 commits into base: master

Conversation

heiner commented May 3, 2024

This adds a script to convert the raw weights in the pickle files to GGUF format. This allows using @arki05's work in #6204 directly from the Grok-1 torrent.

Code is based on @foldl's conversion script in chatllm.cpp, which in turn is based on @chu-tianxiang's gist.

Main ideas to avoid excessive memory usage:

  • Parse the pickle files using mmap.
  • Use PyTorch "meta" tensors to compute shapes and dtypes of results without having to do all conversions beforehand (see the sketch below).
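
A minimal sketch of the meta-tensor idea (the expert shape is the one quoted later in this thread; the dtype is illustrative):

import torch

# A "meta" tensor records shape and dtype but allocates no storage, so reshapes,
# splits and dtype decisions can be planned without loading hundreds of GB of weights.
w = torch.empty(8 * 32768, 6144, dtype=torch.bfloat16, device="meta")
experts = w.view(8, 32768, 6144)      # still on the meta device, no memory allocated
print(experts.shape, experts.dtype)   # torch.Size([8, 32768, 6144]) torch.bfloat16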

Note that I couldn't run the full model due to RAM constraints, so it's possible I mixed up some tensor names.

slaren (Collaborator) commented May 3, 2024

Does this merge the experts into a single tensor?

heiner (Author) commented May 3, 2024

> Does this merge the experts into a single tensor?

It does the opposite -- in the raw data, the 8 experts are part of the same tensor. This splits them, which is also what the chatllm.cpp script does.

If there is a way to keep them within one tensor I'm happy to make that change.

slaren (Collaborator) commented May 3, 2024

The preferred way to export the expert tensors is as a single 3D tensor for all the experts. It is still possible to use one tensor per expert for backwards compatibility, but it forces the model weights to be copied to a buffer while loading, rather than using them directly from the memory mapped file. For large models like grok, I think it is especially important to be able to avoid this copy and use mmap.
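
Concretely, the stacking could look roughly like this (a sketch with hypothetical variable names; the shapes are the per-expert sizes quoted below):

import torch

# Hypothetical list of the 8 per-expert matrices for one projection of one layer.
expert_weights = [torch.empty(32768, 6144, device="meta") for _ in range(8)]

# One 3D tensor (n_expert, rows, cols) per projection lets llama.cpp use the
# weights straight from the memory-mapped file instead of copying each expert.
ffn_gate_exps = torch.stack(expert_weights, dim=0)  # shape: (8, 32768, 6144)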

heiner (Author) commented May 3, 2024

Understood. That will actually make the script simpler. Would you happen to know the tensor names I should use in this case? Currently, with splitting, they are:

| blk.{layer}.ffn_gate.{expert}.weight        | torch.Size([32768, 6144])  | Q4_0    |
| blk.{layer}.ffn_down.{expert}.weight        | torch.Size([6144, 32768])  | Q4_0    |
| blk.{layer}.ffn_up.{expert}.weight          | torch.Size([32768, 6144])  | Q4_0    |

slaren (Collaborator) commented May 3, 2024

The tensor names are defined in gguf-py:

MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",

It would be good to use these constants rather than hardcoding the names.
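
For illustration, a full tensor name could then be built roughly like this (a sketch; the ".weight" suffix and the bid variable are assumptions here):

import gguf

bid = 0  # layer index
name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.FFN_GATE_EXP].format(bid=bid) + ".weight"
# -> "blk.0.ffn_gate_exps.weight"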

heiner added a commit to heiner/llama.cpp that referenced this pull request May 3, 2024
As per ggerganov#7058 (comment).
This helps avoid a memcopy when running.
heiner (Author) commented May 3, 2024

Thanks!

I have updated the branch to no longer split the MoE weights into separate tensors. That simplifies the script, as it's now one weight per file. The original script permuted the order in which these weights are written for some reason; I stopped doing that, so there's now only one list of weight names.

I also switched to the values in the gguf.TENSOR_NAMES dict, as per your suggestion. I'm not sure that's a clear improvement ("Explicit is better than implicit."), especially in view of code like name.endswith("attn_k") and name.endswith("_exps"), but it's also not much worse.

PTAL.

foldl (Contributor) commented May 3, 2024

@heiner, the name of my project is ChatLLM.cpp, not ChatLLM.ccp 😄

mofosyne added the 'python (python script changes)' label May 9, 2024
ggerganov previously approved these changes May 9, 2024

ggerganov (Owner) left a comment:

We can merge after lint fixes

mofosyne added the 'Review Complexity: Medium (generally requires more time to grok but manageable by beginner to medium expertise)' label May 9, 2024
ggerganov (Owner) commented:

Hm, I tested Q4_0 conversion and it does not seem to work:

python3 convert_grok.py -i ~/Data/huggingface/grok-1/ckpt/ --vocab_dir ../grok-1 -o x.gguf -t q4_0

make -j && ./main -m x.gguf -p "I believe the meaning of life is" -ngl 0

...

[BOS] I believe the meaning of life is it:000000000000000 of0000000000000000000000000000000000000000000000

Might need more work

ggerganov marked this pull request as draft May 9, 2024 14:03
ggerganov dismissed their stale review May 9, 2024 14:03

Did not work on my machine

heiner (Author) commented May 9, 2024

> [BOS] I believe the meaning of life is it:000000000000000 of0000000000000000000000000000000000000000000000

My apologies. As I said above, I couldn't actually test running the full model on my setup. I will address @foldl's suggestions.

Would you happen to have something like the SHA-1 of each tensor of a checkpoint based on the HF weights? Otherwise I can download those and run that conversion for comparison.

mofosyne added the 'enhancement (new feature or request)' label May 9, 2024
heiner (Author) commented May 9, 2024

Thanks @foldl for the hints. It's quite possible I mixed something else up as well, e.g., swapped two tensors with the same shape and dtype. Would you happen to have a tensor-name-to-hash mapping for a correct conversion?

foldl (Contributor) commented May 9, 2024

@heiner You need to compare the result against #6204, although I don't think the SHA-1s would match.

chatllm.cpp permutes the k_proj/q_proj weights, so the SHA-1s would not match either.
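
For context, the q/k permutation used by llama.cpp's LLaMA-family converters looks roughly like the sketch below; whether ChatLLM.cpp applies exactly this transform is an assumption, but any such reordering changes the hashes while leaving the model mathematically equivalent:

import torch

def permute(w: torch.Tensor, n_head: int) -> torch.Tensor:
    # Reorder the rows of a q_proj/k_proj matrix to match the rotary-embedding
    # layout expected by llama.cpp; values are only shuffled, so checksums differ.
    return (w.reshape(n_head, 2, w.shape[0] // n_head // 2, *w.shape[1:])
             .swapaxes(1, 2)
             .reshape(w.shape))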

heiner (Author) commented May 10, 2024

Thanks. I have removed the multiplication by embedding_multiplier_scale again and converted the full model with all 8 experts. The output is not great, but also not as bad as before (gist with full output):

./build/bin/main -m grok.bin -p "I believe the meaning of life is" -s 2 -n 10 -ngl 0
(...)
[BOS] I believe the meaning of life is important it could possibly the general of the general I

It's likely something else is wrong, but I'm unsure what, and the multi-hour iteration time makes it infeasible to just try things at random.

heiner force-pushed the master branch 2 times, most recently from ee5921e to cd38c87 (May 23, 2024 09:12)
heiner added a commit to heiner/llama.cpp that referenced this pull request May 23, 2024
As per ggerganov#7058 (comment).
This helps avoid a memcopy when running.
heiner (Author) commented May 23, 2024

I added two more fixes.

I then compared the output of this PR with Arki05/Grok-1-GGUF/Q8_0 from HF via this script.

All tensors are exactly the same now.

(I have to np.stack the expert tensors from this download on axis 0, as they are split there.)
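
The comparison itself can be sketched roughly as follows (hypothetical file names; this assumes the gguf.GGUFReader API and elides the expert re-stacking mentioned above):

import numpy as np
import gguf

ours = gguf.GGUFReader("grok-this-pr-q8_0.gguf")      # hypothetical path
theirs = gguf.GGUFReader("arki05-grok-1-q8_0.gguf")   # hypothetical path

theirs_by_name = {t.name: t.data for t in theirs.tensors}
for t in ours.tensors:
    other = theirs_by_name.get(t.name)
    if other is None or not np.array_equal(t.data, other):
        print("mismatch:", t.name)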

The changes:

  1. The Grok-1 pickle files use lexicographic order (0 < 1 < 10 < 11 < ... < 2 < 20 < ...), so the layers were incorrectly ordered.

  2. PyTorch's rounding mode is not away from zero in halfway cases, unlike roundf(3). This made for a difference of 1 in ~0.1% of the int8 entries compared to Arki05/Grok-1-GGUF/Q8_0. Using the new gguf.quantize_q8_0 fixes this for Q8_0 (at the cost of increased conversion time).
    Edit: Thanks to @compilade I have now added a PyTorch version of quantize_q8_0 (PyTorch is useful here since it allows figuring out shapes and dtypes via meta tensors). Minimal sketches of both fixes follow after this list.
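
Both fixes are small. A sketch of the ordering fix (the shard names here are hypothetical):

import re

def natural_key(name):
    # Compare digit runs numerically so that layer 2 sorts before layer 10;
    # a plain string sort gives 0 < 1 < 10 < 11 < ... < 2 < 20 < ...
    return [int(p) if p.isdigit() else p for p in re.split(r"(\d+)", name)]

shards = ["layer_10", "layer_2", "layer_1"]  # hypothetical names
print(sorted(shards, key=natural_key))       # ['layer_1', 'layer_2', 'layer_10']

And an untested sketch of the rounding issue, i.e. a PyTorch Q8_0 quantizer that rounds halfway cases away from zero like roundf(3); this illustrates the idea rather than reproducing the exact code in the PR:

import numpy as np
import torch

QK8_0 = 32  # values per Q8_0 block

def quantize_q8_0_torch(x: torch.Tensor) -> np.ndarray:
    # Each Q8_0 block stores one float16 scale followed by 32 int8 values.
    blocks = x.to(torch.float32).reshape(-1, QK8_0)
    d = blocks.abs().amax(dim=-1, keepdim=True) / 127.0
    inv_d = torch.where(d > 0, 1.0 / d, torch.zeros_like(d))
    scaled = blocks * inv_d
    # torch.round() rounds half to even; emulate roundf's away-from-zero rounding,
    # which is where the ~0.1% discrepancy described above came from.
    q = torch.sign(scaled) * torch.floor(scaled.abs() + 0.5)
    out = np.empty(blocks.shape[0], dtype=[("d", "<f2"), ("qs", "i1", (QK8_0,))])
    out["d"] = d.squeeze(-1).to(torch.float16).numpy()
    out["qs"] = q.to(torch.int8).numpy()
    return out  # raw block records; the caller can write the underlying bytes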

Unfortunately, I cannot run the Arki05/Grok-1-GGUF/Q8_0 weights on my MacBook as it OOMs. I can run a two-expert version of this PR (very slowly, several minutes per token), but the output is not great:

$ ./build/bin/main -m grok.bin -p "The answer to life the universe and everything is" -s 1 -n 4 -ngl 2
...
[BOS] The answer to life the universe and everything is gifted for the of

Could someone with the right hardware run Arki05/Grok-1-GGUF/Q8_0 and see if it's any better? If it is, perhaps I missed some header setting (I didn't see any difference that seemed relevant). Otherwise, I believe this conversion is as good as the quantization allows.

heiner marked this pull request as ready for review May 23, 2024 13:42
foldl (Contributor) commented May 24, 2024

> The Grok-1 pickle files use lexicographic order ...

Nice catch, 😄

heiner (Author) commented May 25, 2024

(The Docker image test failure is unrelated to the changes in this PR.)

mofosyne (Collaborator) commented:

> (The Docker image test failure is unrelated to the changes in this PR.)

Yup, it's bypassed for now. Please rebase on top of current master.

heiner (Author) commented May 25, 2024

> (The Docker image test failure is unrelated to the changes in this PR.)
>
> Yup, it's bypassed for now. Please rebase on top of current master.

Done.

ggerganov (Owner) commented:

Does the conversion work correctly now? I can run some tests if you need confirmation.

heiner (Author) commented May 27, 2024

The actual generations still don't look great, e.g. for Q8_0:

$ ./build/bin/main -m grok.bin -p "The answer to life the universe and everything is" -s 1 -n 4 -ngl 2
...
[BOS] The answer to life the universe and everything is gifted for the of

What would be very useful is if you could run both this PR's conversion (in Q8_0) and Arki05/Grok-1-GGUF/Q8_0 and let me know whether there's any difference in the output, and if so, what could plausibly cause it. I have verified that the actual weights are the same in both cases, so any difference would presumably come from the KV metadata settings.
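
If helpful, a quick way to narrow that down could be to diff the metadata keys of the two files (hypothetical paths, assuming the gguf.GGUFReader API):

import gguf

ours = gguf.GGUFReader("grok-this-pr-q8_0.gguf")      # hypothetical path
theirs = gguf.GGUFReader("arki05-grok-1-q8_0.gguf")   # hypothetical path
print("keys only in this PR:", sorted(set(ours.fields) - set(theirs.fields)))
print("keys only in Arki05: ", sorted(set(theirs.fields) - set(ours.fields)))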
