
server: fix incorrectly reported token probabilities #7125

Merged

Conversation

JohannesGaessler
Collaborator

Fixes #7093.

The problem on master is that the server unconditionally reports the token probabilities stored for the top tokens, even if those tokens were not actually considered for sampling. Among other things, this causes the sum of the reported token probabilities to exceed 100%. In total, this PR fixes the following issues:

  • Tokens failing the top_p/min_p check being reported as having a nonzero probability.
  • Setting a temperature <= 0 not actually resulting in the top token being reported with probability 1 and all other tokens with probability 0.
  • The returned "top tokens" being essentially undefined when n_probs > top_k.

To make this work, I am extending ctx_sampling with a property n_considered that stores how many of the top tokens were actually considered for sampling. This can then be used to determine whether (and how many) top tokens need to be fetched, and the index from which all token probabilities should be reported as zero.
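A minimal, self-contained sketch of the reporting rule this enables (hypothetical names such as candidate and report_top_probs; this is not the actual server code): candidates are sorted by probability, the sampler chain leaves only the first n_considered entries eligible, and any reported token past that cutoff gets probability 0.

// Illustrative sketch, not the actual llama.cpp code: only the first
// n_considered candidates were eligible for sampling, so any reported
// token past that index must have probability 0.
#include <cstdio>
#include <vector>

struct candidate { int id; float p; };

std::vector<candidate> report_top_probs(const std::vector<candidate> & sorted,
                                        size_t n_considered, size_t n_probs) {
    std::vector<candidate> out;
    for (size_t i = 0; i < n_probs && i < sorted.size(); i++) {
        // Tokens at index >= n_considered failed top_k/top_p/min_p and were
        // never sampling candidates; report them with probability 0.
        out.push_back({sorted[i].id, i < n_considered ? sorted[i].p : 0.0f});
    }
    return out;
}

int main() {
    // Say softmax produced 4 candidates but top_p kept only the first 2,
    // renormalized to 0.6 + 0.4 = 1.0; the last two hold stale probabilities.
    std::vector<candidate> cand = {{7, 0.6f}, {3, 0.4f}, {9, 0.25f}, {1, 0.15f}};
    for (const auto & c : report_top_probs(cand, 2, 4)) {
        std::printf("token %d: p=%.2f\n", c.id, c.p);
    }
    // On master all four probabilities would be reported, summing to 140%;
    // with the cutoff the reported probabilities sum to exactly 100%.
}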


github-actions bot commented May 7, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 547 iterations 🚀

Details (performance-related PRs only):
  • Concurrent users: 8, duration: 10m
  • HTTP request: avg=8601.74ms p(95)=20636.04ms fails=, finish reason: stop=485 truncated=62
  • Prompt processing (pp): avg=108.34tk/s p(95)=524.95tk/s
  • Token generation (tg): avg=34.2tk/s p(95)=48.03tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=sampling-fix-p-sum commit=ae8235dc0a7af9c169177bc19ccb38c3e55e1f82

[Benchmark charts: prompt_tokens_seconds, predicted_tokens_seconds, kv_cache_usage_ratio, and requests_processing for the 10m bench-server-baseline run on Standard_NC4as_T4_v3 (547 iterations); raw chart data omitted.]

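The review comment that follows is anchored on this hunk of the server's token-probability code: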
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
}

if (slot.sparams.temp <= 0.0f) {
Owner


Should this just check for equality?

Suggested change:

-if (slot.sparams.temp <= 0.0f) {
+if (slot.sparams.temp == 0.0f) {

For negative temperatures, we calculate the probabilities, though they are never used:

if (temp < 0.0) {
    // greedy sampling, with probs
    llama_sample_softmax(ctx_main, &cur_p);
    id = cur_p.data[0].id;
} else if (temp == 0.0) {
    // greedy sampling, no probs
    id = llama_sample_token_greedy(ctx_main, &cur_p);

I guess either way makes sense
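A minimal sketch of what the temperature-zero branch ends up reporting, using hypothetical names (greedy_report is illustrative, not a real llama.cpp function): greedy sampling computes no probabilities, so the server has to synthesize a degenerate distribution over the reported tokens.

// Illustrative sketch, not the actual server code: with temp == 0 the sampler
// picks greedily and computes no probabilities, so the reported distribution
// is synthesized as 1.0 for the sampled token and 0.0 for everything else.
#include <cstdio>
#include <vector>

struct token_prob { int id; float p; };

std::vector<token_prob> greedy_report(const std::vector<int> & top_ids, int sampled_id) {
    std::vector<token_prob> out;
    out.reserve(top_ids.size());
    for (int id : top_ids) {
        out.push_back({id, id == sampled_id ? 1.0f : 0.0f});
    }
    return out;
}

int main() {
    for (const auto & t : greedy_report({5, 2, 9}, 5)) {
        std::printf("token %d: p=%.1f\n", t.id, t.p); // 1.0, 0.0, 0.0
    }
}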

Collaborator Author


No, I think it's better to be consistent if at all possible. I changed it to an equality check and added a note to the documentation. (I also noticed and fixed a bug where the reported "top tokens" were wrong for temperature 0.0f.)

JohannesGaessler merged commit af0a5b6 into ggerganov:master on May 7, 2024
64 checks passed
Successfully merging this pull request may close these issues.

Completion probabilities exceed 100% with top_p < 1