server: add test for token probs #7347

Merged

Conversation

JohannesGaessler
Collaborator

I'm trying to figure out how to enable deterministic results for >1 slots. To this end I wrote a test that checks whether the probabilities produced by the server are bit-for-bit identical. Notably, when using 4 slots and a single thread with the CPU backend, the token probabilities beyond the first are not equal and I don't understand why. The explanation by @ggerganov in ggerganov/whisper.cpp#1941 (comment) doesn't make sense to me because, to my understanding, the same values should still be added in the same order with the CPU backend (except for the masked values, which are 0.0f and whose addition should not affect the rounding error).
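
For illustration, a minimal sketch of the kind of check involved, assuming a server listening on http://localhost:8080. The request/response field names follow the server's /completion API, but the prompt, helper name, and exact comparison are assumptions for this sketch; this is not the test added by the PR.

# Sketch only: request the same prompt several times with n_probs enabled and
# check that the returned token probabilities parse to exactly equal values,
# as a proxy for bit-for-bit identical results.
import json
import urllib.request

def get_probs(prompt: str, n_predict: int = 8) -> list:
    body = json.dumps({
        "prompt": prompt,
        "n_predict": n_predict,
        "n_probs": 10,        # return top-10 token probabilities per position
        "temperature": 0.0,
        "seed": 42,
    }).encode()
    req = urllib.request.Request("http://localhost:8080/completion", data=body,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())["completion_probabilities"]

ref = get_probs("The quick brown fox")
for _ in range(4):  # repeat and compare against the reference run
    assert get_probs("The quick brown fox") == ref, "token probs differ between runs"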

@ggerganov
Owner

I think the only source of variation on the CPU was from ggml_vec_dot_f16 used during attention. Disabling SIMD should make the results deterministic:

diff --git a/ggml.c b/ggml.c
index 55152bce..fe516a7c 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1726,7 +1726,7 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
 
     ggml_float sumf = 0.0;
 
-#if defined(GGML_SIMD)
+#if defined(GGML_SIMD_XXX)
     const int np = (n & ~(GGML_F16_STEP - 1));
 
     GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };

The SIMD-ified softmax is now likely another source of non-determinism, so you might also have to apply:

diff --git a/ggml.c b/ggml.c
index 55152bce..4bbabea0 100644
--- a/ggml.c
+++ b/ggml.c
@@ -1726,7 +1726,7 @@ static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t *
 
     ggml_float sumf = 0.0;
 
-#if defined(GGML_SIMD)
+#if defined(GGML_SIMD_XXX)
     const int np = (n & ~(GGML_F16_STEP - 1));
 
     GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
@@ -2301,48 +2301,48 @@ static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
 static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
     int i = 0;
     ggml_float sum = 0;
-#if defined(__AVX512F__) && defined(__AVX512DQ__)
-    for (; i + 15 < n; i += 16) {
-        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
-                                               _mm512_set1_ps(max)));
-        _mm512_storeu_ps(y + i, val);
-        sum += (ggml_float)_mm512_reduce_add_ps(val);
-    }
-#elif defined(__AVX2__) && defined(__FMA__)
-    for (; i + 7 < n; i += 8) {
-        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
-                                               _mm256_set1_ps(max)));
-        _mm256_storeu_ps(y + i, val);
-        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
-                                 _mm256_castps256_ps128(val));
-        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
-        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
-        sum += (ggml_float)_mm_cvtss_f32(val2);
-    }
-#elif defined(__SSE2__)
-    for (; i + 3 < n; i += 4) {
-        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
-                                            _mm_set1_ps(max)));
-        _mm_storeu_ps(y + i, val);
-#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
-        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
-        val = _mm_add_ss(val, _mm_movehdup_ps(val));
-#else
-        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
-        val = _mm_add_ps(val, tmp);
-        tmp = _mm_movehl_ps(tmp, val);
-        val = _mm_add_ss(val, tmp);
-#endif
-        sum += (ggml_float)_mm_cvtss_f32(val);
-    }
-#elif defined(__ARM_NEON)
-    for (; i + 3 < n; i += 4) {
-        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
-                                                vdupq_n_f32(max)));
-        vst1q_f32(y + i, val);
-        sum += (ggml_float)vaddvq_f32(val);
-    }
-#endif
+//#if defined(__AVX512F__) && defined(__AVX512DQ__)
+//    for (; i + 15 < n; i += 16) {
+//        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
+//                                               _mm512_set1_ps(max)));
+//        _mm512_storeu_ps(y + i, val);
+//        sum += (ggml_float)_mm512_reduce_add_ps(val);
+//    }
+//#elif defined(__AVX2__) && defined(__FMA__)
+//    for (; i + 7 < n; i += 8) {
+//        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
+//                                               _mm256_set1_ps(max)));
+//        _mm256_storeu_ps(y + i, val);
+//        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
+//                                 _mm256_castps256_ps128(val));
+//        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
+//        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
+//        sum += (ggml_float)_mm_cvtss_f32(val2);
+//    }
+//#elif defined(__SSE2__)
+//    for (; i + 3 < n; i += 4) {
+//        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
+//                                            _mm_set1_ps(max)));
+//        _mm_storeu_ps(y + i, val);
+//#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
+//        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
+//        val = _mm_add_ss(val, _mm_movehdup_ps(val));
+//#else
+//        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
+//        val = _mm_add_ps(val, tmp);
+//        tmp = _mm_movehl_ps(tmp, val);
+//        val = _mm_add_ss(val, tmp);
+//#endif
+//        sum += (ggml_float)_mm_cvtss_f32(val);
+//    }
+//#elif defined(__ARM_NEON)
+//    for (; i + 3 < n; i += 4) {
+//        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
+//                                                vdupq_n_f32(max)));
+//        vst1q_f32(y + i, val);
+//        sum += (ggml_float)vaddvq_f32(val);
+//    }
+//#endif
     for (; i < n; ++i) {
         float val = expf(x[i] - max);
         sum += (ggml_float)val;

Do these changes yield identical results in your test?

@JohannesGaessler
Collaborator Author

This PR was based on a commit prior to #7154. If I rebase it to include that PR, then the test case with a single parallel prompt also fails. Applying both suggested patches fixes the test case with a single parallel prompt, but not the test case with 4 parallel prompts that was failing before.

@mofosyne added the testing, Review Complexity : Medium, and enhancement labels on May 18, 2024
@ggerganov
Owner

I guess llamafile's sgemm routines are another source of variance in the attention due to the F16 SIMD used there. Try building with LLAMA_NO_LLAMAFILE=1

@JohannesGaessler
Collaborator Author

With the suggested patches and LLAMA_NO_LLAMAFILE=1 I get bit-for-bit identical results. I think the way to go forward with this PR is to just comment out the tests that currently fail on master and to link the discussion in this PR; my main goal is to assert that there isn't a bug somewhere that only causes slightly incorrect results.

For actual production use I'm not sure how to proceed. With CUDA at least, you cannot guarantee bit-for-bit identical results with the current unified KV cache: unlike with the CPU backend, the insertion of masked values between the relevant ones changes the rounding error, because you have 32 threads calculating partial sums in parallel. I think this could be fixed if the unified KV cache were to allocate memory to sequences in chunks of at least 64 values (64 because the optimal memory access pattern for the 32 threads in a CUDA warp is to read 64 FP16 values at once). I don't think a larger chunk size would be needed for e.g. a q4_0 KV cache, since my understanding is that the only part relevant for the differences in rounding error is the KQ matrix. One downside of doing this, though, would be that the memory locality would likely be worse.
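
As a self-contained illustration of the rounding argument (not llama.cpp code): summing the same float32 values with a different grouping, as parallel partial sums effectively do when masked entries shift the partition boundaries, can change the last bits of the result.

# Illustration only: the same float32 values summed with a different grouping
# (as 32 parallel "thread" partial sums would do) can round differently, even
# though adding the masked 0.0f entries themselves changes nothing mathematically.
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(4096).astype(np.float32)

serial = np.float32(0.0)
for v in x:                      # strictly sequential accumulation, fixed order
    serial += v

partials = x.reshape(32, -1).sum(axis=1, dtype=np.float32)   # 32 partial sums
grouped = partials.sum(dtype=np.float32)

print(serial, grouped, serial == grouped)   # typically differ in the last bits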

@JohannesGaessler
Collaborator Author

I think the server tests also have issues with race conditions; if I increase the number of tokens from 10 to 100 the test still fails, and I think the reason is that with 10 tokens the completions finish too quickly to actually run in parallel.

@ggerganov
Owner

> For actual production use I'm not sure how to proceed. With CUDA at least, you cannot guarantee bit-for-bit identical results with the current unified KV cache: unlike with the CPU backend, the insertion of masked values between the relevant ones changes the rounding error, because you have 32 threads calculating partial sums in parallel. I think this could be fixed if the unified KV cache were to allocate memory to sequences in chunks of at least 64 values (64 because the optimal memory access pattern for the 32 threads in a CUDA warp is to read 64 FP16 values at once). I don't think a larger chunk size would be needed for e.g. a q4_0 KV cache, since my understanding is that the only part relevant for the differences in rounding error is the KQ matrix. One downside of doing this, though, would be that the memory locality would likely be worse.

Yes, this is likely a viable solution. We would have to maintain "active" chunks per sequence and, when searching for free slots, first look into those chunks. The main goal is to guarantee that every KV cache interval [k*64, (k+1)*64) always contains tokens from at most 1 sequence. If we achieve that, then the results will be deterministic across all backends.
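
A hypothetical sketch of that invariant (illustrative Python, not llama.cpp's actual KV cache code): cells are handed out in chunks of 64 and a partially used chunk is only ever extended by the sequence that owns it, so each interval [k*64, (k+1)*64) holds tokens from at most one sequence.

# Sketch of chunk-aligned KV cache slot allocation; names and structure are
# invented for illustration.
CHUNK = 64

class ChunkedKVCache:
    def __init__(self, n_cells: int):
        assert n_cells % CHUNK == 0
        self.owner = [None] * (n_cells // CHUNK)   # seq_id owning each chunk
        self.used  = [0] * (n_cells // CHUNK)      # cells used within each chunk

    def alloc_cell(self, seq_id: int) -> int:
        # 1) prefer an "active" chunk of this sequence that still has room
        for k, (o, u) in enumerate(zip(self.owner, self.used)):
            if o == seq_id and u < CHUNK:
                self.used[k] += 1
                return k * CHUNK + u
        # 2) otherwise claim a completely free chunk for this sequence
        for k, o in enumerate(self.owner):
            if o is None:
                self.owner[k] = seq_id
                self.used[k] = 1
                return k * CHUNK
        raise RuntimeError("KV cache full")

cache = ChunkedKVCache(n_cells=256)
print([cache.alloc_cell(seq_id=s) for s in (0, 1, 0, 1)])
# -> [0, 64, 1, 65]: tokens of different sequences never share a 64-cell chunk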

@JohannesGaessler merged commit 1b01f06 into ggerganov:master on May 19, 2024
25 checks passed
Labels
enhancement, examples, python, Review Complexity : Medium, server, testing