Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for BitnetForCausalLM (new model / new datatype) #7931

Merged
merged 38 commits into from
Jun 23, 2024

Conversation

Eddie-Wang1120
Copy link
Contributor

@Eddie-Wang1120 Eddie-Wang1120 commented Jun 14, 2024

  • Self Reported Review Complexity:
    • Review Complexity : Low
    • Review Complexity : Medium
    • Review Complexity : High
  • I have read the contributing guidelines

PR Intro

This PR is to support BitnetForCausalLM for llama.cpp, which includes several points:

  • add new BitnetForCausalLM architecture
  • add new tensors FFN_SUB_NORM / ATTN_SUB_NORM (special tensors used in Bitnet FFN/ATTN)
  • add support for BitnetForCausalLM model conversion in convert-hf
  • add support for inference for models based on BitnetForCausalLM
  • add new datatype I2_S / I8_S (deprecated)
  • add new datatype Q2_2 (deprecated)
  • add Q2_2 quantization / matmul kernel (deprecated)

Note

This PR only contains BitNet model support.

Q2_2 / I2_S and I8_S are deprecated now, you can still try it by checkout to commit

Also many thanks to @compilade for a new 1.625bpw datatype Q1_3, can be found in compilade/bitnet-ternary

How to use Q2_2?

  • download model from BitnetForCausalLM
  • python convert-hf-to-gguf.py [model_path] --outtype f32
  • make quantize && make main
  • ./quantize [model_path]/ggml-model-f32.gguf [model_path]/ggml-model-q2_2.gguf Q2_2
  • ./main -m [model_path]/ggml-model-q2_2.gguf

Q2_2 Results

  model_size(MB) output_weight norm_weight other_weight input ppl
f16 6809.6 fp16 fp32 fp16 fp16 8.5419 +/- 0.05239
q4_0 1833.74 q8_0 fp32 q4_0 q8_0 8.7610 +/- 0.05444
iq4_nl 1848.57 q8_0 fp32 iq4_nl q8_0 8.5797 +/- 0.0289
q2_2 873.66 q8_0 fp32 q2_2 q8_0 8.5495 +/- 0.05248

(test by llama.cpp on wikitext-2)

  threads test t/s
f16 8 tgl128 7.29 ± 0.04
q4_0 8 tgl128 16.45 ± 0.03
iq4_nl 8 tgl128 19.41 ± 0.06
q2_2 8 tgl128 22.34 ± 0.82

(test by llama.cpp llama-bench on 12th Gen Intel(R) Core(TM) i5-12500H)

Why add I2_S and I8_S?

Bitnet does not use per-channel but per-tensor quantization both for activation (int8) and weight (1, 0, -1). This means each activation or weight for special matmul operations (attn_q / attn_k / attn_v / attn_o / ffn_up / ffn_gate / ffn_down) only has one scale. However, it seems that quantization types in llama.cpp all use block as basic unit, which is suitable for per-channel quantization, but not working with per-tensor quantization. In this case, I designed two new datatypes for 2bit and 8bit per-tensor quantization respectively, called I2_s and I8_s, can solve the problem.

How to use I2_S and I8_S?

  • git checkout 569a03e
  • download model from BitnetForCausalLM
  • python convert-hf-to-gguf.py [model_path] --outtype f32
  • make quantize && make main
  • ./quantize [model_path]/ggml-model-f32.gguf [model_path]/ggml-model-i2_s.gguf I2_S 1
  • ./main -m [model_path]/ggml-model-i2_s.gguf

I2_S I8_S Results

  model_size(MB) output_weight norm_weight other_weight input ppl
f16 6809.6 fp16 fp32 fp16 fp16 8.5419 +/- 0.05239
q4_0 1833.74 q8_0 fp32 q4_0 q8_0 8.7610 +/- 0.05444
iq4_nl 1848.57 q8_0 fp32 iq4_nl q8_0 8.5797 +/- 0.0289
i2_s 873.66 q8_0 fp32 i2_s i8_s 8.5535 +/- 0.05268

(test by llama.cpp on wikitext-2)

  threads test t/s
f16 8 tgl128 6.83 ± 0.03
q4_0 8 tgl128 19.30 ± 0.06
iq4_nl 8 tgl128 22.03 ± 0.04
i2_s 8 tgl128 26.35 ± 0.06

(test by llama.cpp llama-bench on 13th Gen Intel(R) Core(TM) i5-13400F)

I2_S has lower ppl with model size more than twice as small as q4_0 and iq4_nl, also has a inference speed improvements than q4_0 and iq4_nl.

Questions

Will llama.cpp support non-block quantization datatype? @ggerganov I tried my best but new datatype can't merge into llama.cpp without special judgment (src0->type == GGML_TYPE_I2_S). It would be so great if llama.cpp could support it.

TODO

  • support SIMD in i2s_i8s vec_dot kernel (x86)
  • fix non-block special judgment issue

@github-actions github-actions bot added examples python python script changes ggml changes relating to the ggml tensor library for machine learning labels Jun 14, 2024
convert-hf-to-gguf.py Outdated Show resolved Hide resolved
@bartowski1182
Copy link
Contributor

Am I understanding correctly that these new quant types (I2_S, I8_S) will ONLY work with bitnet models, and not across all models?

The code itself doesn't imply that (but also doesn't include I8_S in QUANT_OPTIONS) so just want to clarify

@Dampfinchen
Copy link

Oh god, it's finally happening. You are doing the lord's work! Super excited for this.

@JackCloudman
Copy link

JackCloudman commented Jun 14, 2024

First, good work! I've been trying out the build with CUDA support and encountered an error. Here are the steps I followed and the results:

  • Successfully compiled quantize and main using make LLAMA_CUDA=1 CUDA_DOCKER_ARCH=all.
  • Converted and quantized the model from https://huggingface.co/1bitLLM/bitnet_b1_58-3B into f32.
  • Attempted to run the compiled binary with ./main -m 1bit-i2_s.gguf -ngl 1.

However, when executing the main command, the following error was produced:

llama_new_context_with_model: n_ctx = 2048
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init: CUDA_Host KV buffer size = 625.00 MiB
llama_kv_cache_init: CUDA0 KV buffer size = 25.00 MiB
llama_new_context_with_model: KV self size = 650.00 MiB, K (f16): 325.00 MiB, V (f16): 325.00 MiB
llama_new_context_with_model: CUDA_Host output buffer size = 0.12 MiB
GGML_ASSERT: ggml.c:3613: view_src == NULL || data_size == 0 || data_size + view_offs <= ggml_nbytes(view_src)
Could not attach to process. If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user. For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.

========

Interestingly, when compiling and running with just CPU support, everything works fine. It seems like there might be an issue specifically related to CUDA integration. Any insights or help would be greatly appreciated!

@Eddie-Wang1120
Copy link
Contributor Author

you should probably try to pull in the changes form master, the binary names where changed, so test will naturally fail.

Thanks for the advice, already merge into the master and fix the whitespace.

Copy link
Collaborator

@slaren slaren left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx_split should only be used for matrices.

llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
llama.cpp Outdated Show resolved Hide resolved
@Eddie-Wang1120
Copy link
Contributor Author

After changing a little bit of the llm_build_kqv() code, now BitNet can reuse the llm_build_kv() to construct the model, please review.

@gonzalo-santamaria-iic
Copy link

gonzalo-santamaria-iic commented Jun 21, 2024

Hello! I don't know if this BitNet-b1.58 is a good reproduction of the archiecture proposed in the original research. They said in pag 5 section 3 that all RMSNorm before Attention and SwiGLU (MLP?) should be removed, but it seems that both layers are still present in the decoder block:

hidden_states = self.input_layernorm(hidden_states)

and

hidden_states = self.post_attention_layernorm(hidden_states)

Furthermore, it is not entirely clear to me whether this RMSNorm should be a parameter-free layer, otherwise this would create conflicts with the inference proposed on pag 6 section 3, since in this case there is only one identical RMSNorm for all. I don't know if this is entirely true, I left an issue in the HuggingFace thread to get the doubt solved.

Anyone who can help clarify this? Thanks for the great work you are doing to integrate new technologies into useful libraries like llama.cpp 😄

@Eddie-Wang1120

This comment was marked as resolved.

@flatsiedatsie
Copy link

I don't know if this BitNet-b1.58 is a good reproduction of the archiecture proposed

Does it matter? Does implementing this flavour limit a future implementation of the originally proposed version?

"if it works, it ain't stupid"


// input for next layer
inpL = cur;
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe should extend llm_build_ffn() to support _scale tensors and reuse it here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little worried that if I change the llm_build_ffn api, for example adding ffn_gate_scale / ffn_up_scale / ffn_down_scale to the function parameters, than I have to change the code for all models which use llm_build_ffn, seems not the things I should do in this PR. If supporting _scales tensor is neccessary, I can contribute a new PR and make this change suits all model after this PR merged.

llama.cpp Outdated
cb(Kcur, "Kcur", il);

cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
nullptr, model.layers[il].bo,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing model.layers[il].bo here seems incorrect. I think it should be added below after the projection block:

cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
cb(cur, "attn_o_out", il);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed the code about model.layers[il].bo may cause some misunderstanding of BitNet, even though at this time model.layers[il].bo is nullptr. Already change it in lateset commit, please review. @ggerganov

@flatsiedatsie
Copy link

flatsiedatsie commented Jun 23, 2024

This is probably not relevant, but I just excitedly tried to compile Lllama.cpp from this PR through Wllama (Emscriptem/Wasm), to see if I could run "Bitnet in the Browser".

I tried the ggml-model-q2_k-pad.gguf BitNet model that GGerganov provided.

Unfortunately I got this error:

llama_model_load: error loading model: check_tensor_dims: tensor 'token_embd.weight' has wrong shape; expected  3200, 32002, got  3328, 32002,     1,     1

I also tried bitnet_b1_58-3B.q2_2.gguf and saw a more general crash.

I then tried to compile the current Llama.cpp version from a few minutes ago, without the Wllama wrapper.

But then I saw an error too, the same one on both models:

Screenshot 2024-06-23 at 18 31 25

Mac OS, M1 pro

@ggerganov ggerganov merged commit e112b61 into ggerganov:master Jun 23, 2024
65 checks passed
@Dampfinchen
Copy link

Dampfinchen commented Jun 24, 2024

Truly q2_2 looks absolutely insane. Splendid work!

I do wonder if that also makes fine tuning models on lower end hardware possible, as the quantized models were previously not of high enough quality to fine tune on and fp16 had to be used instead.

I think this is a great opportunity to change the need for fp16 models when it comes to training models.

Nexesenex pushed a commit to Nexesenex/croco.cpp that referenced this pull request Jun 24, 2024
* hf bitnet v1

* hf bitnet e2e v2

* finish bitnet e2e

* finish f16 hf bitnet e2e

* remove unsed

* finish bitnet i2 e2e

* move i2s to quantize v1

* move i2 to quantize

* clean code

* clean code 2

* fix codestyle

* fix code

* fix

* fix code

* fix merge

* remove unused

* change table name

* fix whitespace

* delete redundant

* i2_s to absmax

* finish i2_s/i8_s vec_dot x86 simd

* i2s->q22

* fix code

* remove block scale

* add dequantize

* fix seq

* update avx2

* remove q2_2

* remove q22_grid

* fix whitespace

* reuse llm_build_kv

* fix bo

---------

Co-authored-by: root <root@wangjinheng>
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Jun 30, 2024
* hf bitnet v1

* hf bitnet e2e v2

* finish bitnet e2e

* finish f16 hf bitnet e2e

* remove unsed

* finish bitnet i2 e2e

* move i2s to quantize v1

* move i2 to quantize

* clean code

* clean code 2

* fix codestyle

* fix code

* fix

* fix code

* fix merge

* remove unused

* change table name

* fix whitespace

* delete redundant

* i2_s to absmax

* finish i2_s/i8_s vec_dot x86 simd

* i2s->q22

* fix code

* remove block scale

* add dequantize

* fix seq

* update avx2

* remove q2_2

* remove q22_grid

* fix whitespace

* reuse llm_build_kv

* fix bo

---------

Co-authored-by: root <root@wangjinheng>
MagnusS0 pushed a commit to MagnusS0/llama.cpp-normistral-tokenizer that referenced this pull request Jul 1, 2024
* hf bitnet v1

* hf bitnet e2e v2

* finish bitnet e2e

* finish f16 hf bitnet e2e

* remove unsed

* finish bitnet i2 e2e

* move i2s to quantize v1

* move i2 to quantize

* clean code

* clean code 2

* fix codestyle

* fix code

* fix

* fix code

* fix merge

* remove unused

* change table name

* fix whitespace

* delete redundant

* i2_s to absmax

* finish i2_s/i8_s vec_dot x86 simd

* i2s->q22

* fix code

* remove block scale

* add dequantize

* fix seq

* update avx2

* remove q2_2

* remove q22_grid

* fix whitespace

* reuse llm_build_kv

* fix bo

---------

Co-authored-by: root <root@wangjinheng>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
examples ggml changes relating to the ggml tensor library for machine learning python python script changes Tensor Encoding Scheme https://github.com/ggerganov/llama.cpp/wiki/Tensor-Encoding-Schemes
Projects
None yet
Development

Successfully merging this pull request may close these issues.