CUDA: assert when using batch size less than 129 #5140

Closed
ikawrakow opened this issue Jan 26, 2024 · 13 comments
Labels: bug (Something isn't working), stale

@ikawrakow (Contributor)

The model fully fits in VRAM with plenty of room to spare. Running Mistral-7B (any quantization) on a 16 GB GPU:

./perplexity -m <model> -f test.wiki.raw -t 1 -ngl 100 -b 64
main: build = 1971 (1182cf4d)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1706267800
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 CUDA devices:
  Device 0: NVIDIA GeForce RTX 4080, compute capability 8.9, VMM: yes
llama_model_loader: loaded meta data with 20 key-value pairs and 291 tensors from junk2.bin (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.name str              = models
llama_model_loader: - kv   2:                       llama.context_length u32              = 32768
llama_model_loader: - kv   3:                     llama.embedding_length u32              = 4096
llama_model_loader: - kv   4:                          llama.block_count u32              = 32
llama_model_loader: - kv   5:                  llama.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 llama.rope.dimension_count u32              = 128
llama_model_loader: - kv   7:                 llama.attention.head_count u32              = 32
llama_model_loader: - kv   8:              llama.attention.head_count_kv u32              = 8
llama_model_loader: - kv   9:     llama.attention.layer_norm_rms_epsilon f32              = 0.000010
llama_model_loader: - kv  10:                       llama.rope.freq_base f32              = 10000.000000
llama_model_loader: - kv  11:                          general.file_type u32              = 2
llama_model_loader: - kv  12:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  13:                      tokenizer.ggml.tokens arr[str,32000]   = ["<unk>", "<s>", "</s>", "<0x00>", "<...
llama_model_loader: - kv  14:                      tokenizer.ggml.scores arr[f32,32000]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  15:                  tokenizer.ggml.token_type arr[i32,32000]   = [2, 3, 3, 6, 6, 6, 6, 6, 6, 6, 6, 6, ...
llama_model_loader: - kv  16:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  17:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  18:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  19:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:   65 tensors
llama_model_loader: - type q4_0:  225 tensors
llama_model_loader: - type q6_K:    1 tensors
llm_load_vocab: special tokens definition check successful ( 259/32000 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 32000
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 32768
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 8
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 4
llm_load_print_meta: n_embd_k_gqa     = 1024
llm_load_print_meta: n_embd_v_gqa     = 1024
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 32768
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 7B
llm_load_print_meta: model ftype      = Q4_0
llm_load_print_meta: model params     = 7.24 B
llm_load_print_meta: model size       = 3.83 GiB (4.54 BPW) 
llm_load_print_meta: general.name     = models
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.22 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:      CUDA0 buffer size =  3847.55 MiB
..................................................................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_kv_cache_init:      CUDA0 KV buffer size =    64.00 MiB
llama_new_context_with_model: KV self size  =   64.00 MiB, K (f16):   32.00 MiB, V (f16):   32.00 MiB
llama_new_context_with_model:  CUDA_Host input buffer size   =     0.56 MiB
llama_new_context_with_model:      CUDA0 compute buffer size =     4.56 MiB
llama_new_context_with_model:  CUDA_Host compute buffer size =     0.50 MiB
llama_new_context_with_model: graph splits (measure): 3

system_info: n_threads = 1 / 32 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | 
perplexity: tokenizing the input ..
perplexity: tokenization took 561.854 ms
perplexity: calculating perplexity over 642 chunks, batch_size=32
ggml_tallocr_alloc: not enough space in the buffer to allocate ffn_up-0 (needed 1835008, largest block available 1359872)
GGML_ASSERT: /home/iwan/other/llama.cpp/ggml-alloc.c:114: !"not enough space in the buffer"
Could not attach to process.  If your uid matches the uid of the target
process, check the setting of /proc/sys/kernel/yama/ptrace_scope, or try
again as the root user.  For more details, see /etc/sysctl.d/10-ptrace.conf
ptrace: Operation not permitted.
No stack.
The program is not being run.
Aborted (core dumped)

This is reproducible for any batch size <= 128.

ikawrakow added the bug label on Jan 26, 2024
@ggerganov (Owner)

Probably related to the following issue: #5086

@slaren (Collaborator) commented Jan 26, 2024

As I noted in that PR, this is a fundamental issue with the way ggml-alloc works when the tensor sizes don't exactly match the sizes used to measure the buffer, and fixing it will definitely require significant changes. I outlined the solution here:

In the long term, a more robust solution is needed, such as always assigning the same offset within the buffer to the tensors, regardless of their size; then it would always work as long as the tensors are never larger than during measure. This should also make ggml-alloc faster during inference, since we could skip the whole allocation process and simply reuse the same allocations obtained during measure, and it might also allow a more exhaustive search for a more optimal way to allocate tensors during measure, since that would only happen during initialization.

This will require changing the way tensors are allocated, such that they are always assigned the same memory addresses regardless of their size. It will also introduce the restriction that all graphs must have the exact same topology as the graph used to measure the buffer size, and that tensor sizes must never be larger than the sizes in the measure graph. A consequence is that the way the K-shift is handled in llama.cpp will need to be changed.

For now, if this use case is important, what we can do is add some margin to the sizes of the compute buffers (5-10% should do it). However, it is impossible for me to test every combination of model, batch size, and context size, all of which affect the sizes of the tensors used in the graph.
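
A minimal sketch of the fixed-offset idea described above; the struct and function names are hypothetical and this is not the actual ggml-alloc code, just an illustration of "same offset regardless of size, never larger than during measure":

// Hypothetical sketch of the fixed-offset scheme; not the ggml-alloc API.
#include <assert.h>
#include <stddef.h>

struct fixed_slot {
    size_t offset;    // offset assigned to this tensor during the measure pass
    size_t max_size;  // size the tensor had during the measure pass
};

// Measure pass: remember where the tensor was placed and how big it was.
static void measure_tensor(struct fixed_slot *slot, size_t offset, size_t size) {
    slot->offset   = offset;
    slot->max_size = size;
}

// Later evaluations: reuse the measured offset unchanged. The only thing that
// can go wrong is the tensor growing beyond its measured size, so there is no
// fragmentation to worry about and no per-graph allocation work.
static size_t tensor_offset(const struct fixed_slot *slot, size_t size) {
    assert(size <= slot->max_size && "tensor larger than during measure");
    return slot->offset;
}

With a scheme like this, the per-evaluation cost is a lookup plus a bounds check, which is why the allocation process could effectively be skipped after measure.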

@ggerganov (Owner)

Yeah, it seems like a very tricky problem. Surprisingly, the strategy of greedily assigning tensors to the best-fit block (i.e. the one with the least amount of bytes wasted) does not always lead to a solution, because AFAICT it depends on whether a block will be reused later on or not.

Somehow it feels like there should be an elegant solution, but I can't see it.

Would a 5-10% increase suffice? In this case it's more like 35% that's needed.
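
For context, a greedy best-fit pass over a free list looks roughly like the sketch below (hypothetical types and names, not the ggml-alloc internals). The choice is purely local: it picks the block that wastes the fewest bytes right now, without knowing whether that block will be needed for a larger tensor later, which is why the buffer can still end up too fragmented.

// Hypothetical sketch of greedy best-fit block selection; not ggml-alloc code.
#include <stddef.h>

struct free_block { size_t offset; size_t size; };

// Return the index of the free block that wastes the fewest bytes for `size`,
// or -1 if no block is large enough. Nothing here considers whether a block
// will be needed again for a bigger tensor later in the graph.
static int best_fit(const struct free_block *blocks, int n, size_t size) {
    int best = -1;
    size_t best_waste = (size_t)-1;
    for (int i = 0; i < n; ++i) {
        if (blocks[i].size >= size && blocks[i].size - size < best_waste) {
            best_waste = blocks[i].size - size;
            best       = i;
        }
    }
    return best;
}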

@slaren (Collaborator) commented Jan 26, 2024

Would a 5-10% increase suffice? In this case it's more like 35% that's needed.

Hard to tell if it will be enough in every case, but here 10% seems to be enough. The change in size may affect previous allocations, which results in less fragmentation (closer to the fragmentation observed during measure).

diff --git a/ggml-alloc.c b/ggml-alloc.c
index 60141a34..c9dc20b7 100644
--- a/ggml-alloc.c
+++ b/ggml-alloc.c
@@ -335,7 +335,7 @@ bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) {
 }

 size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) {
-    return alloc->max_size;
+    return alloc->max_size + alloc->max_size/10;
 }

 // graph allocator
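
As a sanity check on the 35% figure: the original report needed 1835008 bytes for ffn_up-0 while the largest free block was 1359872 bytes. A quick calculation on those logged numbers (plain C, nothing to do with the patch itself):

// Arithmetic on the numbers from the original report (needed vs largest block).
#include <stdio.h>

int main(void) {
    double needed    = 1835008.0;  // ffn_up-0 allocation that failed
    double available = 1359872.0;  // largest free block at that point
    printf("shortfall: %.1f%%\n", 100.0 * (needed - available) / available);  // prints 34.9%
    return 0;
}

The 10% margin can still be enough here because, as noted above, the larger buffer shifts where earlier tensors land and reduces fragmentation, so the largest free block ends up big enough even though the total buffer only grew by 10%.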

@ggerganov (Owner)

In the few cases that I tried, it does solve the problem. Let's push this for now and make a note to resolve it properly in the future.

@slaren (Collaborator) commented Jan 26, 2024

Ok, do you want me to open a PR?

@ggerganov (Owner)

Yes, go ahead

@slaren (Collaborator) commented Jan 26, 2024

It's also possible that expanding the graphs in a different topological order, to minimize the number of tensors allocated at any moment, could reduce fragmentation and alleviate this issue. I experimented with this in ggerganov/ggml#462 (comment). Ultimately, I think the only way to provide a strong guarantee that allocations will never fail is to add the restrictions mentioned earlier (same graph topology, and every tensor never bigger than in the measure graph). We cannot implement very complicated logic in either the graph expansion or the tensor allocation, because this needs to be very fast; it runs on every evaluation.
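
To make the "number of tensors allocated at any moment" part concrete, here is a toy sketch (made-up data, unrelated to the real graph expansion code) that computes the peak number of simultaneously live tensors for a given evaluation order. A different topological order of the same graph changes the last-use pattern and hence the peak, which in turn affects fragmentation.

// Toy sketch (not ggml code): peak number of tensors live at once for a given
// evaluation order. last_use[i] is the step at which tensor i is read last.
#include <stdio.h>

static int peak_live(const int *last_use, int n) {
    int live = 0, peak = 0;
    for (int i = 0; i < n; ++i) {
        ++live;                          // tensor i is produced at step i
        if (live > peak) peak = live;
        for (int j = 0; j <= i; ++j) {   // free tensors last used at step i
            if (last_use[j] == i) --live;
        }
    }
    return peak;
}

int main(void) {
    // Made-up 5-tensor example; reordering independent branches of the same
    // graph would give different last_use values and a different peak.
    int last_use[] = {2, 2, 4, 4, 4};
    printf("peak live tensors: %d\n", peak_live(last_use, 5));  // prints 3
    return 0;
}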

@Artefact2 (Collaborator)

I can still hit the issue using e76627b.

% ./llama-bench -m ../models/Chronomaid-Storytelling-13b-Q4_K_S.gguf -b 128    
ggml_init_cublas: GGML_CUDA_FORCE_MMQ:   no
ggml_init_cublas: CUDA_USE_TENSOR_CORES: yes
ggml_init_cublas: found 1 ROCm devices:
  Device 0: AMD Radeon RX 6750 XT, compute capability 10.3, VMM: no
| model                          |       size |     params | backend    | ngl |    n_batch | test       |              t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ---------: | ---------- | ---------------: |
ggml_tallocr_alloc: not enough space in the buffer to allocate ffn_up-0 (needed 7077888, largest block available 6619187)
GGML_ASSERT: ggml-alloc.c:114: !"not enough space in the buffer"

@RodolfoCastanheira

The temporary fix from #5149 was removed by PR #5452, and I am getting this bug. Reapplying the extra 10% solved my problem.

@slaren (Collaborator) commented Feb 14, 2024

@RodolfoCastanheira that's simply not possible; the error conditions referenced here no longer exist. Please open a new issue and explain what you are observing in detail.

@RodolfoCastanheira

You're right, I tested with an older build, my bad.

@github-actions (bot)

This issue is stale because it has been open for 30 days with no activity.

github-actions bot added the stale label on Mar 18, 2024
slaren closed this as completed on Mar 20, 2024