
cuda : use amd wave sharing intrinsics for warp_reduce functions #6522


Open · wants to merge 3 commits into base: master

Conversation

Engininja2
Contributor

The warp_reduce functions use __shfl_xor_sync(), which on AMD HIP is lowered to the ds_bpermute instruction. ds_bpermute needs a few setup instructions and consumes local data share (LDS) bandwidth and latency. AMD GPUs also support Data-Parallel Primitive (DPP) instructions, which can perform certain forms of sharing within a warp for free or nearly free.

Using the DPP intrinsic in this PR, my GPU gets about 1% faster token generation, but oddly 0-1% slower prompt processing. Maybe it has to do with there being four fewer waits on LDS memory per reduction and how the GPU schedules warps, but I really have no idea. Other GPU models might behave differently.

Details on the instructions available for sharing within a warp are on AMD's website and in each ISA's documentation.
One thing to note: if a warp_reduce function is called with some threads inactive, the results will differ between GFX9 GPUs and GFX10+. I don't think this matters, since according to the CUDA documentation calling __shfl_xor_sync() with inactive threads is undefined on NVIDIA GPUs anyway.
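For illustration only, here is a minimal sketch of the general idea, not the code in this PR: a DPP-based warp_reduce_sum where the helper dpp_xchg_f32 and the exact sequence of DPP controls are my own assumptions. The quad_perm/row_half_mirror/row_mirror controls handle all exchanges within a row of 16 lanes, and the final cross-row step falls back to a regular shuffle.

```cpp
#include <hip/hip_runtime.h>

// Hypothetical helper: move a float between lanes with a compile-time DPP control.
// row_mask = bank_mask = 0xf enables all lanes; bound_ctrl = true returns 0 for
// invalid source lanes, which never happens for the controls used below.
template <int dpp_ctrl>
static __device__ __forceinline__ float dpp_xchg_f32(const float x) {
    return __int_as_float(__builtin_amdgcn_update_dpp(
        0, __float_as_int(x), dpp_ctrl, 0xf, 0xf, true));
}

// Hypothetical DPP-based warp sum for a 32-lane warp (all lanes active).
static __device__ __forceinline__ float warp_reduce_sum_dpp(float x) {
    x += dpp_xchg_f32<0xB1>(x);   // quad_perm:[1,0,3,2]: combine lane pairs
    x += dpp_xchg_f32<0x4E>(x);   // quad_perm:[2,3,0,1]: combine pairs into quads
    x += dpp_xchg_f32<0x141>(x);  // row_half_mirror:     combine quads into groups of 8
    x += dpp_xchg_f32<0x140>(x);  // row_mirror:          combine groups of 8 into rows of 16
    x += __shfl_xor(x, 16, 32);   // DPP cannot cross rows of 16 on GFX10+; shuffle instead
    return x;
}
```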

Here are my llama-bench results:

| GPU | Model | Model Size [GiB] | Test | t/s master | t/s PR | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RX 5700 XT | llama 7B Q4_K_S | 3.59 | pp512 | 300.78 | 298.79 | 0.99 |
| RX 5700 XT | llama 7B Q4_K_S | 3.59 | tg128 | 46.63 | 47.18 | 1.01 |
| RX 5700 XT | llama 7B Q8_0 | 6.67 | pp512 | 345.31 | 341.53 | 0.99 |
| RX 5700 XT | llama 7B Q8_0 | 6.67 | tg128 | 38.99 | 39.38 | 1.01 |
| RX 5700 XT | mpt 7B Q4_0 | 3.64 | pp512 | 339.65 | 339.53 | 1.00 |
| RX 5700 XT | mpt 7B Q4_0 | 3.64 | tg128 | 63.27 | 63.90 | 1.01 |
| RX 5700 XT | orion 14B Q3_K_S | 5.96 | pp512 | 84.05 | 83.88 | 1.00 |
| RX 5700 XT | orion 14B Q3_K_S | 5.96 | tg128 | 15.91 | 16.10 | 1.01 |
| RX 5700 XT | phi2 3B Q8_0 | 2.75 | pp512 | 783.94 | 783.35 | 1.00 |
| RX 5700 XT | phi2 3B Q8_0 | 2.75 | tg128 | 76.15 | 76.67 | 1.01 |
| RX 5700 XT | starcoder2 7B Q4_K_M | 4.22 | pp512 | 280.40 | 279.28 | 1.00 |
| RX 5700 XT | starcoder2 7B Q4_K_M | 4.22 | tg128 | 39.39 | 39.67 | 1.01 |

I also tried testing SOFT_MAX and saw a 0-8% speedup, except for cases with shapes [1024,1024,1,1] or [1023,1023,1,1] and a max bias of 8, which were ~5% slower. Again, I'm not sure why.

@JohannesGaessler
Collaborator

In a quick test I am essentially getting the same results on my RX 6800.

@JohannesGaessler
Collaborator

Sorry, what I said was misleading. What I meant is that I also observe only a small speedup, but no performance regression. These are the specific numbers:

| GPU | Model | Model Size [GiB] | Test | t/s master | t/s a37d885 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RX 6800 | llama 7B Q4_0 | 3.83 | pp512 | 715.17 | 715.80 | 1.00 |
| RX 6800 | llama 7B Q4_0 | 3.83 | tg128 | 61.00 | 61.11 | 1.00 |
| RX 6800 | llama 7B Q4_K_S | 3.86 | pp512 | 586.22 | 586.72 | 1.00 |
| RX 6800 | llama 7B Q4_K_S | 3.86 | tg128 | 43.33 | 43.74 | 1.01 |
| RX 6800 | llama 7B Q8_0 | 7.17 | pp512 | 726.42 | 726.25 | 1.00 |
| RX 6800 | llama 7B Q8_0 | 7.17 | tg128 | 41.46 | 41.70 | 1.01 |

@mofosyne added the labels Review Complexity : High (generally requires in-depth knowledge of LLMs or GPUs) and performance (speed related topics) on May 10, 2024
@sorasoras

What is the reason for not merging this PR?

@JohannesGaessler
Collaborator

@Engininja2 reported an inconsistent and minimal change in performance and there have been no further updates from him since. As such I don't think it makes sense to merge this PR in its current state, especially when I don't have a good understanding of what this code actually does.

@sorasoras

sorasoras commented May 11, 2024

@Engininja2 it seems like the recent merge fails to compile on my 7900 XTX on Windows.
I got some decent performance gains for larger batch inference.

```
168 warnings and 4 errors generated when compiling for gfx1100.
[28/36] Building CXX object CMakeFiles/ggml.dir/ggml-cuda.cu.obj
FAILED: CMakeFiles/ggml.dir/ggml-cuda.cu.obj
ccache C:\PROGRA~1\AMD\ROCm\5.7\bin\CLANG_~1.EXE -DGGML_CUDA_DMMV_X=512 -DGGML_CUDA_FORCE_DMMV -DGGML_CUDA_MMV_Y=32 -DGGML_SCHED_MAX_COPIES=4 -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_USE_LLAMAFILE -DK_QUANTS_PER_ITERATION=2 -D_CRT_SECURE_NO_WARNINGS -D_XOPEN_SOURCE=600 -D__HIP_PLATFORM_AMD__=1 -D__HIP_PLATFORM_HCC__=1 -IW:/git/test/engininja2/llama.cpp/. -isystem "C:/Program Files/AMD/ROCm/5.7/include" -O3 -DNDEBUG -D_DLL -D_MT -Xclang --dependent-lib=msvcrt -std=gnu++14 -Wmissing-declarations -Wmissing-noreturn -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wunreachable-code-break -Wunreachable-code-return -Wmissing-prototypes -Wextra-semi -march=native -mllvm -amdgpu-early-inline-all=true -mllvm -amdgpu-function-calls=false -x hip --offload-arch=gfx1100 -MD -MT CMakeFiles/ggml.dir/ggml-cuda.cu.obj -MF CMakeFiles\ggml.dir\ggml-cuda.cu.obj.d -o CMakeFiles/ggml.dir/ggml-cuda.cu.obj -c W:/git/test/engininja2/llama.cpp/ggml-cuda.cu
In file included from W:/git/test/engininja2/llama.cpp/ggml-cuda.cu:5:
In file included from W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:15:
W:/git/test/engininja2/llama.cpp/.\ggml-common.h:154:9: warning: anonymous types declared in an anonymous union are an extension [-Wnested-anon-types]
        struct {
        ^
W:/git/test/engininja2/llama.cpp/.\ggml-common.h:175:9: warning: anonymous types declared in an anonymous union are an extension [-Wnested-anon-types]
        struct {
        ^
W:/git/test/engininja2/llama.cpp/.\ggml-common.h:196:9: warning: anonymous types declared in an anonymous union are an extension [-Wnested-anon-types]
        struct {
        ^
W:/git/test/engininja2/llama.cpp/.\ggml-common.h:218:9: warning: anonymous types declared in an anonymous union are an extension [-Wnested-anon-types]
        struct {
        ^
W:/git/test/engininja2/llama.cpp/.\ggml-common.h:263:9: warning: anonymous types declared in an anonymous union are an extension [-Wnested-anon-types]
        struct {
        ^
W:/git/test/engininja2/llama.cpp/.\ggml-common.h:290:9: warning: anonymous types declared in an anonymous union are an extension [-Wnested-anon-types]
        struct {
        ^
In file included from W:/git/test/engininja2/llama.cpp/ggml-cuda.cu:5:
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:252:1: warning: function declared 'noreturn' should not return [-Winvalid-noreturn]
}
^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:469:24: error: redefinition of 'no_device_code'
static __device__ void no_device_code(
                       ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:238:24: note: previous definition is here
static __device__ void no_device_code(
                       ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:486:9: warning: 'NO_DEVICE_CODE' macro redefined [-Wmacro-redefined]
#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:309:9: note: previous definition is here
#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:491:41: error: redefinition of 'warp_reduce_sum'
static __device__ __forceinline__ float warp_reduce_sum(float x) {
                                        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:312:41: note: previous definition is here
static __device__ __forceinline__ float warp_reduce_sum(float x) {
                                        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:499:42: error: redefinition of 'warp_reduce_sum'
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
                                         ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:324:42: note: previous definition is here
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
                                         ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:509:5: warning: macro expansion producing 'defined' has undefined behavior [-Wexpansion-to-defined]
#if FP16_AVAILABLE
    ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:460:25: note: expanded from macro 'FP16_AVAILABLE'
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
                        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:509:5: warning: macro expansion producing 'defined' has undefined behavior [-Wexpansion-to-defined]
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:460:54: note: expanded from macro 'FP16_AVAILABLE'
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
                                                     ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:533:41: error: redefinition of 'warp_reduce_max'
static __device__ __forceinline__ float warp_reduce_max(float x) {
                                        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:352:41: note: previous definition is here
static __device__ __forceinline__ float warp_reduce_max(float x) {
                                        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:542:5: warning: macro expansion producing 'defined' has undefined behavior [-Wexpansion-to-defined]
#if FP16_AVAILABLE
    ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:460:25: note: expanded from macro 'FP16_AVAILABLE'
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
                        ^
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:542:5: warning: macro expansion producing 'defined' has undefined behavior [-Wexpansion-to-defined]
W:/git/test/engininja2/llama.cpp/./ggml-cuda/common.cuh:460:54: note: expanded from macro 'FP16_AVAILABLE'
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
                                                     ^
W:/git/test/engininja2/llama.cpp/ggml-cuda.cu:1873:10: warning: variable 'any_pascal_with_slow_fp16' set but not used [-Wunused-but-set-variable]
    bool any_pascal_with_slow_fp16 = false;
         ^
W:/git/test/engininja2/llama.cpp/ggml-cuda.cu:2982:62: warning: unused parameter 'buffer' [-Wunused-parameter]
GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
                                                             ^
W:/git/test/engininja2/llama.cpp/ggml-cuda.cu:2982:77: warning: unused parameter 'size' [-Wunused-parameter]
GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
                                                                            ^
15 warnings and 4 errors generated when compiling for gfx1100.
ninja: build stopped: subcommand failed.
```

@Engininja2
Contributor Author

I rebased the PR on current master. llama.cpp doesn't spend much time synchronizing between GPU threads so a tiny performance increase probably isn't worth the maintenance burden of additional code.

Would this be better converted to a draft PR, or closed?

@JohannesGaessler
Collaborator

Check performance for

```sh
./llama-bench --model models/opt/${model_name}-${quantization}.gguf -r 1 -fa 1 -n 0 -pg 0,0 -p 4096 -b 1,512
```

with an actual path to a model. The new FlashAttention kernels (I think) are more sensitive to the changes in this PR.

@JohannesGaessler
Collaborator

These are the results I get:

| GPU | Model | Batch size | Test | t/s fed0108 | t/s 9e6f2e2 | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RX 6800 | llama 8B Q8_0 | 1 | pp4096 | 27.25 | 28.16 | 1.03 |
| RX 6800 | llama 8B Q8_0 | 512 | pp4096 | 380.49 | 374.30 | 0.98 |

@Engininja2
Contributor Author

Unfortunately I get a large performance regression for flash attention.

| GPU | Model | Batch size | Test | t/s master | t/s amd-warp-reduce | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| RX 5700 XT | llama 7B Q4_K_S | 1 | pp4096 | 23.62 | 23.55 | 1.00 |
| RX 5700 XT | llama 7B Q4_K_S | 512 | pp4096 | 211.17 | 163.15 | 0.77 |
| RX 5700 XT | starcoder2 7B Q4_K_M | 1 | pp4096 | 21.12 | 21.47 | 1.02 |
| RX 5700 XT | starcoder2 7B Q4_K_M | 512 | pp4096 | 193.14 | 149.18 | 0.77 |

For gfx1010, flash_attn_vec_ext_f16<128, 8, 1>, the instance llama-bench is using here, spills 26 VGPRs on master and 185 on this PR. Applying launch_bounds(D, 1) cuts the spilling to 55 VGPRs, but the kernel then uses twice as many VGPRs, giving 177.58 t/s, 16% slower than master.
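As a rough illustration of what that launch_bounds change does (a sketch only; the kernel name, parameters, and body here are placeholders, not the real flash_attn_vec_ext_f16 signature):

```cpp
// __launch_bounds__(D, 1) caps the block at D threads and asks for at least one
// resident block per CU, which raises the register budget the compiler targets,
// trading fewer spills for higher per-thread VGPR use (and lower occupancy).
template <int D, int ncols, int parallel_blocks> // illustrative template parameters
__global__ void __launch_bounds__(D, 1) flash_attn_vec_sketch(const float * Q, float * dst) {
    // kernel body omitted
    (void) Q; (void) dst;
}
```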

I tried replacing the direct half2 additions with separate additions of the x and y components and got this result, with 75 VGPRs spilled:

| model | n_batch | test | t/s |
| --- | --- | --- | --- |
| llama 7B Q4_K - Small | 1 | pp4096 | 23.72 ± 0.00 |
| llama 7B Q4_K - Small | 512 | pp4096 | 201.41 ± 0.00 |

It's better, but still a 5% slowdown, and it feels like a bad idea to hope that the compiler's heuristics happen to go the right way.
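For clarity, a hypothetical sketch of that workaround (the helper name is made up; this is not the code actually tried): split the packed half2 addition into two scalar half additions and repack.

```cpp
#include <cuda_fp16.h> // or <hip/hip_fp16.h> when building with HIP

// Hypothetical replacement for a direct `a + b` on half2 values: add the low and
// high halves separately, then repack. Same result, but it can steer the compiler
// toward a different register allocation.
static __device__ __forceinline__ half2 add_half2_split(const half2 a, const half2 b) {
    const half lo = __hadd(__low2half(a),  __low2half(b));
    const half hi = __hadd(__high2half(a), __high2half(b));
    return __halves2half2(lo, hi);
}
```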

@JohannesGaessler
Collaborator

Generally speaking, the current FlashAttention kernels are bad for large batch sizes anyway, so the real question is how performance will change for kernels that are specifically written for large batch sizes. I think a PR that gives +3% performance in some cases but -23% in others should not be merged.

@sorasoras

It does provide a decent speedup for token generation, about 2-3%, which is useful.

@jeroen-mostert
Contributor

jeroen-mostert commented Aug 29, 2024

FWIW, compared to the current master I can't replicate the regression on my RX 6800 XT, and in fact for large batches with flash attention the speedup increases. I did not observe any change in token generation speed (positive or negative).

| GPU | Model | Batch size | FlashAttention | Test | t/s 1d1ccce | t/s dpp_wave | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| RX 6800 XT | llama 7B Q4_0 | 1 | No | pp512 | 79.08 | 79.52 | 1.01 |
| RX 6800 XT | llama 7B Q4_0 | 1 | No | pp4096 | 44.76 | 44.85 | 1.00 |
| RX 6800 XT | llama 7B Q4_0 | 1 | Yes | pp512 | 80.56 | 81.17 | 1.01 |
| RX 6800 XT | llama 7B Q4_0 | 1 | Yes | pp4096 | 59.32 | 60.13 | 1.01 |
| RX 6800 XT | llama 7B Q4_0 | 512 | No | pp512 | 1685.17 | 1694.99 | 1.01 |
| RX 6800 XT | llama 7B Q4_0 | 512 | No | pp4096 | 1123.16 | 1124.75 | 1.00 |
| RX 6800 XT | llama 7B Q4_0 | 512 | Yes | pp512 | 1553.67 | 1628.39 | 1.05 |
| RX 6800 XT | llama 7B Q4_0 | 512 | Yes | pp4096 | 773.28 | 869.77 | 1.12 |
| RX 6800 XT | llama 7B Q4_0 | 4096 | No | pp512 | 1684.82 | 1687.73 | 1.00 |
| RX 6800 XT | llama 7B Q4_0 | 4096 | No | pp4096 | 1111.66 | 1112.76 | 1.00 |
| RX 6800 XT | llama 7B Q4_0 | 4096 | Yes | pp512 | 1551.82 | 1627.11 | 1.05 |
| RX 6800 XT | llama 7B Q4_0 | 4096 | Yes | pp4096 | 767.88 | 863.89 | 1.13 |

So this may be worth revisiting and/or gating behind specific model checks. A 12% boost is nothing to sneeze at.
