
@bssrdf commented Sep 18, 2025

I added support for fp16. Please review. Thanks.

@Green-Sky

[INFO ] ggml_extend.hpp:1648 - vae offload params ( 94.47 MB, 140 tensors) to runtime backend (CUDA0), taking 0.01s
[DEBUG] ggml_extend.hpp:1550 - vae compute buffer size: 1920.19 MB(VRAM)
[ERROR] ggml_extend.hpp:71   - CUDA error: an illegal memory access was encountered
[ERROR] ggml_extend.hpp:71   -   current device: 0, in function ggml_backend_cuda_synchronize at /build/8ic4ww2ssjfv047iarl54v0dprp5d07b-source/ggml/src/ggml-cuda/ggml-cuda.cu:2628
[ERROR] ggml_extend.hpp:71   -   cudaStreamSynchronize(cuda_ctx->stream())
/build/8ic4ww2ssjfv047iarl54v0dprp5d07b-source/ggml/src/ggml-cuda/ggml-cuda.cu:88: CUDA error

@bssrdf (Author) commented Sep 19, 2025

@Green-Sky, thanks for testing. It's weird that it failed in sd.cpp; this PR passed all tests in test_backend_op.

Could you try @etasnadi's branch https://github.com/etasnadi/llama.cppxx/tree/conv2d-cuda-opt? Mine is on top of his and I hope I didn't break anything.

@Green-Sky

Well, without the f16 support, it falls back to the slow naive impl, as expected. (working as intended)

@Green-Sky

Your changes look alright too.

The only thing that I think might be wrong somewhere is that we might not be accounting for the halving of the kernel's size (f16 elements are 2 bytes, not 4) somewhere.
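
For reference, a hedged illustration of where such a halving could be missed; `kernel` stands for the conv kernel tensor, and only the ggml helpers used here are real API:

```cpp
#include "ggml.h"

// a sketch, not code from this PR: a size derived from the tensor's own type
// tracks the f16 halving automatically; a hand-rolled one may not
size_t right_bytes(const struct ggml_tensor * kernel) {
    return ggml_nbytes(kernel);                       // 2 bytes/elem for F16
}
size_t wrong_bytes(const struct ggml_tensor * kernel) {
    return ggml_nelements(kernel) * sizeof(float);    // assumes 4 bytes/elem: 2x too big for F16
}
```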

@Green-Sky

Looks like it is not your change that causes the issue. I will report in the PR in a sec.

@etasnadi (Owner)

Don't forget to update this if you want to use f16:

if (!getenv("GGML_CUDA_USE_LEGACY_CONV") &&

@bssrdf (Author) commented Sep 19, 2025

> Don't forget to update this if you want to use f16:
>
> if (!getenv("GGML_CUDA_USE_LEGACY_CONV") &&

In my tests, I don't have GGML_CUDA_USE_LEGACY_CONV set, so `!getenv("GGML_CUDA_USE_LEGACY_CONV") == true`.
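
For context, a rough sketch of the kind of gate being discussed; the function name and surrounding structure are hypothetical, only `getenv` and the ggml type enums are real:

```cpp
#include <cstdlib>
#include "ggml.h"

// hypothetical dispatch sketch: take the optimized conv2d path unless the
// legacy kernel is explicitly requested, and let both F32 and F16 kernels through
static bool use_optimized_conv2d(const struct ggml_tensor * kernel) {
    const bool legacy  = std::getenv("GGML_CUDA_USE_LEGACY_CONV") != nullptr;
    const bool type_ok = kernel->type == GGML_TYPE_F32 || kernel->type == GGML_TYPE_F16;
    return !legacy && type_ok;
}
```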

@bssrdf (Author) commented Sep 19, 2025

I am also puzzled: no matter what test cases I throw at it in test_backend_op, they all pass; in a customized test I created for my implicit PR, it always failed with an "illegal memory access was encountered" error.

Edit: I fixed my test case and got the benchmark comparison with im2col+gemm.

| (IC, OC, IW, IH) | im2col+GEMM time | im2col+GEMM VRAM | implicit time | implicit VRAM |
| --- | --- | --- | --- | --- |
| (64, 64, 48, 64) | 0.06 ms | 4.12 MB | 0.03 ms | 0.75 MB |
| (320, 320, 104, 152) | 0.91 ms | 106.13 MB | 1.53 ms | 19.30 MB |
| (640, 640, 52, 76) | 0.46 ms | 53.07 MB | 2.01 ms | 9.65 MB |
| (640, 640, 104, 152) | 2.17 ms | 212.27 MB | 5.56 ms | 38.59 MB |
| (960, 320, 104, 152) | 2.01 ms | 279.80 MB | 4.19 ms | 19.30 MB |
| (1280, 1280, 26, 38) | 0.25 ms | 26.53 MB | 1.74 ms | 4.82 MB |
| (1280, 640, 52, 76) | 0.86 ms | 96.48 MB | 3.55 ms | 9.65 MB |
| (1920, 1280, 26, 38) | 0.34 ms | 37.39 MB | 2.60 ms | 4.82 MB |
| (2560, 1280, 26, 38) | 0.50 ms | 48.24 MB | 3.66 ms | 4.82 MB |
| (512, 512, 104, 152) | 1.25 ms | 169.81 MB | 2.89 ms | 30.88 MB |
| (512, 512, 208, 304) | 4.85 ms | 679.25 MB | 10.97 ms | 123.50 MB |
| (512, 256, 416, 608) | 15.65 ms | 2470.00 MB | 21.08 ms | 247.00 MB |
| (256, 128, 832, 1216) | 33.37 ms | 4940.00 MB | 28.77 ms | 494.00 MB |
| (256, 256, 832, 1216) | 39.64 ms | 5434.00 MB | 40.23 ms | 988.00 MB |
| (320, 256, 1024, 1920) | 74.01 ms | 12720.00 MB | 112.98 ms | 1920.00 MB |

@etasnadi (Owner)

> I am also puzzled: no matter what test cases I throw at it in test_backend_op, they all pass; in a customized test I created for my implicit PR, it always failed with an "illegal memory access was encountered" error.
>
> Edit: I fixed my test case and got the benchmark comparison with im2col+gemm.

Is this the f16 or the f32 perf? The perf is not looking too good. Can you test what the case is with the Vulkan backend?

@mnehete32

> @Green-Sky, thanks for testing. It's weird that it failed in sd.cpp; this PR passed all tests in test_backend_op.
>
> Could you try @etasnadi's branch https://github.com/etasnadi/llama.cppxx/tree/conv2d-cuda-opt? Mine is on top of his and I hope I didn't break anything.

It is the same for me: the tensor core kernel test cases passed, but when I tried it in sd.cpp, it didn't work.

@bssrdf (Author) commented Sep 22, 2025

> Is this the f16 or the f32 perf? The perf is not looking too good. Can you test what the case is with the Vulkan backend?

These are fp16 results. I don't have a Vulkan dev env, sorry.

@etasnadi (Owner)

> These are fp16 results. I don't have a Vulkan dev env, sorry.

Ok, I thought these were sd.cpp results; I guess it is test-backend-ops with custom shapes. Can you share the test cases you added to test-backend-ops? Are these numbers for this PR or for your previous implicit GEMM implementation?

Thanks.

@bssrdf (Author) commented Sep 22, 2025

> Ok, I thought these were sd.cpp results; I guess it is test-backend-ops with custom shapes. Can you share the test cases you added to test-backend-ops? Are these numbers for this PR or for your previous implicit GEMM implementation?

Please see https://github.com/bssrdf/llama.cpp/tree/add-conv2d-test-case for a test I added. The above numbers are for this PR. Thanks.

@etasnadi (Owner)

> Please see https://github.com/bssrdf/llama.cpp/tree/add-conv2d-test-case for a test I added. The above numbers are for this PR. Thanks.

That's reasonable. The CUDA backend uses cuBLAS by default for the matrix multiplication, and it is highly optimized for each shape class on each single arch. I knew this was the case, which is the reason I did not share the CUDA kernel in the first place. To beat the im2col+gemm implementation we would also need to optimize the code for each device, which seemed almost impossible to me. I know that many GitHub repos claim to outperform cuBLAS, but I doubt they can keep their advantage on every device and every shape.

Nevertheless, the memory saving is huge, so it might be worth adding the direct conv2d code even though it is considerably slower in several cases.

Now I'll try to reproduce the crash with stable diffusion, and once it is fixed, we are ready.

@bssrdf (Author) commented Sep 23, 2025

@etasnadi, I agree with everything you said. However, @Green-Sky showed that Vulkan direct conv2d can be faster than CUDA im2col. I looked at the Vulkan code and found nothing specially optimized. Maybe the Vulkan compiler does a magic job of optimizing the hell out of it.

@Green-Sky

> @etasnadi, I agree with everything you said. However, @Green-Sky showed that Vulkan direct conv2d can be faster than CUDA im2col. I looked at the Vulkan code and found nothing specially optimized. Maybe the Vulkan compiler does a magic job of optimizing the hell out of it.

I think im2col is somewhat memory bound, and the fact that the Vulkan direct implementation uses less memory helps (with cm2).

@etasnadi (Owner) commented Sep 24, 2025

I figured out what the problem is. I mistakenly forgot to use the proper stream for copying to the symbols. @bssrdf, can you update the code and test whether it works on your device if you use the async alternative of cudaMemcpyToSymbol with ctx.stream()?

This is not an issue with test-backend-ops because it probably uses the same (default) stream for both the kernel call and the copy to symbols.
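
A minimal sketch of the fix being described, assuming a `Params` struct like the kernel's (the fields here are placeholders, not the PR's actual layout):

```cpp
#include <cuda_runtime.h>

struct Params { int IW, IH, KW, KH; };  // placeholder fields

__constant__ Params dp;  // constant-memory destination read by the kernel

static void upload_params(const Params & p, cudaStream_t stream) {
    // enqueue the copy on the backend's stream so it is ordered before the
    // kernel launch; a plain cudaMemcpyToSymbol goes through the default
    // stream and can race with work queued on other streams
    cudaMemcpyToSymbolAsync(dp, &p, sizeof(Params), 0, cudaMemcpyHostToDevice, stream);
}
```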

@etasnadi (Owner)

> @etasnadi, I agree with everything you said. However, @Green-Sky showed that Vulkan direct conv2d can be faster than CUDA im2col. I looked at the Vulkan code and found nothing specially optimized. Maybe the Vulkan compiler does a magic job of optimizing the hell out of it.

Vulkan already has matrix core support; I guess that's why it is faster. The scalar CUDA kernel has very similar performance to Vulkan on my device, and it is faster with the bank-conflict fix, since the Vulkan code has not received it.
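
For readers unfamiliar with the fix mentioned here: a common form of it is padding a shared-memory tile by one element so column-wise accesses land in different banks. A generic transpose-style sketch of the technique, not this PR's kernel (assumes a 32x32 block):

```cpp
constexpr int TILE = 32;

__global__ void transpose_tile(const float * in, float * out) {
    // the +1 column makes the row stride 33 floats, so the column-wise reads
    // below hit 32 distinct banks instead of conflicting on a single one
    __shared__ float tile[TILE][TILE + 1];
    tile[threadIdx.y][threadIdx.x] = in[threadIdx.y * TILE + threadIdx.x];
    __syncthreads();
    out[threadIdx.y * TILE + threadIdx.x] = tile[threadIdx.x][threadIdx.y];
}
```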

@bssrdf (Author) commented Sep 24, 2025

@etasnadi, I tried replacing `cudaMemcpyToSymbol(dp, &p, sizeof(Params));` with `cudaMemcpyToSymbolAsync(dp, &p, sizeof(Params), 0, cudaMemcpyHostToDevice, stream);`, but it now fails even in test_backend_op. With the non-async version, test_backend_op works.

@Green-Sky

-    cudaMemcpyToSymbol(dp, &p, sizeof(Params));
+    cudaMemcpyToSymbolAsync(dp, &p, sizeof(Params), 0, cudaMemcpyHostToDevice, ctx.stream());

actually works for sd.cpp

@Green-Sky commented Sep 24, 2025

768x1024 sd1 fp16 vae:

| method | time | memory |
| --- | --- | --- |
| CUDA im2col+mul | ~1.68s | 4992.19 MB |
| CUDA direct (master) | ~35.35s | 1920.19 MB |
| CUDA direct (6049576) (mnehete32 PR) | ~5.05s | 1920.19 MB |
| CUDA implicitgemm (2ec76aa) | ~2.20s | 1920.19 MB |
| CUDA direct (etasnadi PR + fp16) | ~2.13s | 1920.19 MB |
| VULKAN im2col+mul | OOM | ~4992 MB |
| VULKAN direct | ~1.17s | 1920.19 MB |

edit: it is within measurement error of @bssrdf's implicitgemm PR, for thermal reasons.

@etasnadi (Owner)

Interesting. For me, Vulkan direct and indirect have the same perf, and CUDA is considerably slower in SD. I tested on a 2060. Meanwhile, in test-backend-ops CUDA is faster than Vulkan.

What command line are you using exactly?

@Green-Sky

$ result/bin/sd -m models/CyberRealistic_V9_FP16.safetensors --sampling-method dpm++2m  --scheduler karras --cfg-scale 5 -W 768 -H 1024 --diffusion-fa --steps 20 -b 3 -v -n "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry" -p "a lovely cat" --vae-conv-direct --offload-to-cpu

should be https://huggingface.co/cyberdelia/latest_s15_models/blob/main/CyberRealistic_V9_FP16.safetensors

@Green-Sky

I probably need to rerun everything, to make sure I used the same resolution everywhere (likely) and because sd.cpp/ggml got updated in between.

@bssrdf (Author) commented Sep 25, 2025

For some reason, cudaMemcpyToSymbolAsync() didn't work for me: both test-backend-op and my customized case failed with an illegal memory access error, while cudaMemcpyToSymbol() had no problem.

Since @Green-Sky confirmed cudaMemcpyToSymbolAsync() fixed his problem in sd.cpp, I am going to commit the async code.

@etasnadi (Owner) commented Sep 26, 2025

I intercepted the conv2d calls while executing your command line and added them as test cases to test-backend-ops -- so we now simulate sd perfectly in the tests. The first 5 cases are the most important because they are each called 5-10 times (see the code); the rest are called only once. This PR is an order of magnitude faster than implicit GEMM on most test cases, so I am puzzled why they have similar perf with sd on your device. The Vulkan kernel is indeed much faster on one test case (there might be a bug in kernel selection), but otherwise they are really similar in perf on the frequent conv2d calls.

CUDA - "implicit gemm" (bssrdf/conv2d-implicit53a2ccbe129472e66a05cd87eee2ed6b3d42a73a):
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                       6 runs - 179734.00 us/run -  57.98 GFLOP/run - 322.56 GFLOPS
  CONV_2D(ne_input=[192,256,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      2 runs - 747178.00 us/run - 231.90 GFLOP/run - 310.37 GFLOPS
  CONV_2D(ne_input=[384,512,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      2 runs - 836917.00 us/run - 231.88 GFLOP/run - 277.06 GFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     2 runs - 855596.50 us/run - 231.83 GFLOP/run - 270.95 GFLOPS
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[1,1,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      32 runs - 61183.12 us/run -   6.44 GFLOP/run - 105.20 GFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[1,1,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                106496 runs -     9.45 us/run - 344.06 kFLOP/run -  36.42 GFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[3,3,4,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                 672 runs -  1717.28 us/run - 446.69 MFLOP/run - 260.12 GFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      1 runs - 3504339.00 us/run - 927.61 GFLOP/run - 264.70 GFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      1 runs - 1760432.00 us/run - 463.81 GFLOP/run - 263.46 GFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[1,1,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                      4 runs - 492934.00 us/run -  51.49 GFLOP/run - 104.45 GFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     1 runs - 1763091.00 us/run - 463.76 GFLOP/run - 263.04 GFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[1,1,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     4 runs - 493539.75 us/run -  51.44 GFLOP/run - 104.22 GFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     1 runs - 3532391.00 us/run - 927.51 GFLOP/run - 262.57 GFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,3],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      57 runs - 20324.12 us/run -   5.43 GFLOP/run - 267.34 GFLOPS

CUDA - this:
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      84 runs - 12012.67 us/run -  57.98 GFLOP/run -   4.83 TFLOPS
  CONV_2D(ne_input=[192,256,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     23 runs - 44975.83 us/run - 231.90 GFLOP/run -   5.16 TFLOPS
  CONV_2D(ne_input=[384,512,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     23 runs - 45175.78 us/run - 231.88 GFLOP/run -   5.13 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    22 runs - 46494.91 us/run - 231.83 GFLOP/run -   4.99 TFLOPS
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[1,1,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     672 runs -  1498.36 us/run -   6.44 GFLOP/run -   4.30 TFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[1,1,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 98304 runs -    10.54 us/run - 344.06 kFLOP/run -  32.63 GFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[3,3,4,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                2688 runs -   384.32 us/run - 446.69 MFLOP/run -   1.16 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      6 runs - 177570.83 us/run - 927.61 GFLOP/run -   5.22 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     12 runs - 88908.50 us/run - 463.81 GFLOP/run -   5.22 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[1,1,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     86 runs - 11684.58 us/run -  51.49 GFLOP/run -   4.41 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    12 runs - 90303.75 us/run - 463.76 GFLOP/run -   5.14 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[1,1,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    74 runs - 13515.31 us/run -  51.44 GFLOP/run -   3.81 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     6 runs - 181905.67 us/run - 927.51 GFLOP/run -   5.10 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,3],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     114 runs -  8950.66 us/run -   5.43 GFLOP/run - 607.05 GFLOPS

Vulkan (master):
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      82 runs - 12219.60 us/run -  57.98 GFLOP/run -   4.74 TFLOPS
  CONV_2D(ne_input=[192,256,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     23 runs - 45087.22 us/run - 231.90 GFLOP/run -   5.14 TFLOPS
  CONV_2D(ne_input=[384,512,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     23 runs - 45303.70 us/run - 231.88 GFLOP/run -   5.12 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    22 runs - 46252.59 us/run - 231.83 GFLOP/run -   5.01 TFLOPS
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[1,1,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     688 runs -  1487.66 us/run -   6.44 GFLOP/run -   4.33 TFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[1,1,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                131072 runs -     8.01 us/run - 344.06 kFLOP/run -  42.93 GFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[3,3,4,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                3584 runs -   296.46 us/run - 446.69 MFLOP/run -   1.51 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      6 runs - 179047.67 us/run - 927.61 GFLOP/run -   5.18 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     12 runs - 89600.33 us/run - 463.81 GFLOP/run -   5.18 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[1,1,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     92 runs - 11109.16 us/run -  51.49 GFLOP/run -   4.63 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    12 runs - 90773.58 us/run - 463.76 GFLOP/run -   5.11 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[1,1,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    84 runs - 12156.93 us/run -  51.44 GFLOP/run -   4.23 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     6 runs - 180964.00 us/run - 927.51 GFLOP/run -   5.13 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,3],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     133 runs -  8582.44 us/run -   5.43 GFLOP/run - 633.09 GFLOPS

The Vulkan kernel is indeed faster in a few less important cases but in the first 5 the CUDA kernel is somewhat faster.

The additional test cases I added:

    // Stable-diffusion layers
    std::map<std::string, uint32_t> idx_sd{
        { "iw",   0 },
        { "ih",   1 },
        { "kw",   2 },
        { "kh",   3 },
        { "Cout", 4 },
        { "Cin",  5 },
        { "B",    6 },
    };

    // Input image size
    uint32_t w = 768;
    uint32_t h = 1024;

    // Number of filters (base)
    uint32_t Cout_b = 128;
    uint32_t Cin_b  = 128;

    std::vector<std::array<uint32_t, 7>> cases_sd = {
        { w / 8, h / 8, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x10 (called 10 times)
        { w / 4, h / 4, 3, 3, Cout_b * 4, Cin_b * 4, 1 }, // x7
        { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 2, 1 }, // x5
        { w,     h,     3, 3, Cout_b,     Cin_b,     1 }, // x5
        { w / 8, h / 8, 1, 1, Cout_b * 4, Cin_b * 4, 1 }, // x4
        { w / 8, h / 8, 1, 1, 4,          4,         1 },
        { w / 8, h / 8, 3, 3, Cout_b * 4, 4,         1 },

        { w / 2, h / 2, 3, 3, Cout_b * 4, Cin_b * 4, 1 },
        { w / 2, h / 2, 3, 3, Cout_b * 2, Cin_b * 4, 1 },
        { w / 2, h / 2, 1, 1, Cout_b * 2, Cin_b * 4, 1 },

        { w,     h,     3, 3, Cout_b,     Cin_b * 2, 1 },
        { w,     h,     1, 1, Cout_b,     Cin_b * 2, 1 },
        { w,     h,     3, 3, Cout_b * 2, Cin_b * 2, 1 },

        { w,     h,     3, 3, 3,          Cin_b,     1 },
    };

    for (auto act_case : cases_sd) {
        GGML_ASSERT(act_case[idx_sd["kw"]] == 3 || act_case[idx_sd["kw"]] == 1);
        GGML_ASSERT(act_case[idx_sd["kh"]] == 3 || act_case[idx_sd["kh"]] == 1);

        uint32_t p0 = act_case[idx_sd["kw"]] == 3 ? 1 : 0;
        uint32_t p1 = act_case[idx_sd["kh"]] == 3 ? 1 : 0;

        test_cases.emplace_back(new test_conv_2d(
            { act_case[idx_sd["iw"]], act_case[idx_sd["ih"]], act_case[idx_sd["Cin"]], act_case[idx_sd["B"]] },
            { act_case[idx_sd["kw"]], act_case[idx_sd["kh"]], act_case[idx_sd["Cin"]], act_case[idx_sd["Cout"]] },
            GGML_TYPE_F16, 1, 1, p0, p1, 1, 1, false));
    }

@etasnadi (Owner)

I think you all measured the execution time wrong. I manually merged @bssrdf's PR, and it created a different op, GGML_OP_CONV_2D_IMPLICIT, so it was never called when testing; only @mnehete32's inefficient code was used. So I created the tests properly, and it turns out that the implicit convolution is faster than the Vulkan implementation for many large matrices!

I was aware of this because such algs (cuDNN and https://zhuanlan.zhihu.com/p/661879423, which the implicit code is forked from) usually treat the channel dim as contiguous, and that is more efficient in most cases.

It is still slower in some cases, but it can be similarly optimized with the tricks used in the Vulkan implementation.
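
To make the layout point concrete, a sketch of how the two orderings address element (n, c, y, x); this is purely illustrative index math under those assumptions, not code from either kernel:

```cpp
// ggml's default WHCN layout: the pixels of one channel are contiguous, so
// the IC channels of a single pixel are IW*IH elements apart
inline int idx_whcn(int n, int c, int y, int x, int IW, int IH, int IC) {
    return ((n * IC + c) * IH + y) * IW + x;
}

// channels-contiguous (cuDNN-style NHWC) layout: all IC channels of a pixel
// sit next to each other, which implicit-GEMM kernels can load coalesced
inline int idx_nhwc(int n, int c, int y, int x, int IW, int IH, int IC) {
    return ((n * IH + y) * IW + x) * IC + c;
}
```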

implicit conv:

  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     42 runs - 24344.79 us/run - 137.42 GFLOP/run -   5.64 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               29172 runs -    34.65 us/run - 133.69 MFLOP/run -   3.86 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               16951 runs -    59.88 us/run - 135.78 MFLOP/run -   2.27 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 57344 runs -    20.17 us/run - 642.82 kFLOP/run -  31.87 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 9572 runs -   113.43 us/run -  20.90 MFLOP/run - 184.22 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                16384 runs -    63.70 us/run -   2.78 MFLOP/run -  43.72 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4489 runs -   444.22 us/run -  22.28 MFLOP/run -  50.15 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               16473 runs -    61.54 us/run - 115.40 MFLOP/run -   1.88 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3052 runs -   333.21 us/run - 923.24 MFLOP/run -   2.77 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     2310 runs -   433.73 us/run -   1.85 GFLOP/run -   4.26 TFLOPS
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     44 runs - 23199.36 us/run - 137.42 GFLOP/run -   5.92 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               29172 runs -    34.72 us/run - 133.69 MFLOP/run -   3.85 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               16951 runs -    59.68 us/run - 135.78 MFLOP/run -   2.28 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 57344 runs -    20.23 us/run - 642.82 kFLOP/run -  31.77 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 9572 runs -   113.17 us/run -  20.90 MFLOP/run - 184.65 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                16384 runs -    63.37 us/run -   2.78 MFLOP/run -  43.95 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 4489 runs -   437.00 us/run -  22.28 MFLOP/run -  50.98 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               16473 runs -    61.38 us/run - 115.40 MFLOP/run -   1.88 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3052 runs -   332.96 us/run - 923.24 MFLOP/run -   2.77 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     2365 runs -   427.70 us/run -   1.85 GFLOP/run -   4.32 TFLOPS
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     100 runs - 10044.22 us/run -  57.98 GFLOP/run -   5.77 TFLOPS
  CONV_2D(ne_input=[192,256,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     26 runs - 38876.23 us/run - 231.90 GFLOP/run -   5.97 TFLOPS
  CONV_2D(ne_input=[384,512,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     26 runs - 38843.77 us/run - 231.88 GFLOP/run -   5.97 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    26 runs - 39048.42 us/run - 231.83 GFLOP/run -   5.94 TFLOPS
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[1,1,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     832 runs -  1207.22 us/run -   6.44 GFLOP/run -   5.33 TFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[1,1,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 49152 runs -    21.86 us/run - 344.06 kFLOP/run -  15.74 GFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[3,3,4,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                6944 runs -   144.86 us/run - 446.69 MFLOP/run -   3.08 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      7 runs - 153899.86 us/run - 927.61 GFLOP/run -   6.03 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     13 runs - 77167.85 us/run - 463.81 GFLOP/run -   6.01 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[1,1,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    110 runs -  9198.27 us/run -  51.49 GFLOP/run -   5.60 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    13 runs - 77301.08 us/run - 463.76 GFLOP/run -   6.00 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[1,1,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                   104 runs -  9665.38 us/run -  51.44 GFLOP/run -   5.32 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     7 runs - 154487.43 us/run - 927.51 GFLOP/run -   6.00 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,3],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      38 runs - 36557.32 us/run -   5.43 GFLOP/run - 148.63 GFLOPS

Vulkan translation:

  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     40 runs - 25415.75 us/run - 137.42 GFLOP/run -   5.41 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               17204 runs -    59.07 us/run - 133.69 MFLOP/run -   2.26 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               13266 runs -    79.34 us/run - 135.78 MFLOP/run -   1.71 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 73728 runs -    14.34 us/run - 642.82 kFLOP/run -  44.83 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                33502 runs -    30.10 us/run -  20.90 MFLOP/run - 694.26 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                49152 runs -    21.95 us/run -   2.78 MFLOP/run - 126.89 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    93.85 us/run -  22.28 MFLOP/run - 237.39 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               19941 runs -    51.74 us/run - 115.40 MFLOP/run -   2.23 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3161 runs -   320.55 us/run - 923.24 MFLOP/run -   2.88 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f32,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     2145 runs -   473.24 us/run -   1.85 GFLOP/run -   3.91 TFLOPS
  CONV_2D(ne_input=[19,19,256,16],ne_kernel=[4,4,256,4096],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     37 runs - 27553.95 us/run - 137.42 GFLOP/run -   4.99 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               17204 runs -    59.35 us/run - 133.69 MFLOP/run -   2.25 TFLOPS
  CONV_2D(ne_input=[19,19,8,16],ne_kernel=[4,4,8,130],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               12529 runs -    84.34 us/run - 135.78 MFLOP/run -   1.61 TFLOPS
  CONV_2D(ne_input=[19,19,4,16],ne_kernel=[2,2,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 73728 runs -    14.60 us/run - 642.82 kFLOP/run -  44.03 GFLOPS
  CONV_2D(ne_input=[224,224,3,1],ne_kernel=[3,3,3,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                33502 runs -    31.19 us/run -  20.90 MFLOP/run - 670.03 GFLOPS
  CONV_2D(ne_input=[224,224,1,1],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                49152 runs -    22.30 us/run -   2.78 MFLOP/run - 124.87 GFLOPS
  CONV_2D(ne_input=[224,224,1,8],ne_kernel=[2,2,1,8],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                13467 runs -    98.59 us/run -  22.28 MFLOP/run - 225.96 GFLOPS
  CONV_2D(ne_input=[58,58,32,1],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):               19941 runs -    51.18 us/run - 115.40 MFLOP/run -   2.25 TFLOPS
  CONV_2D(ne_input=[58,58,32,8],ne_kernel=[3,3,32,64],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                3161 runs -   319.89 us/run - 923.24 MFLOP/run -   2.89 TFLOPS
  CONV_2D(ne_input=[16,16,128,8],ne_kernel=[3,3,128,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     1980 runs -   517.53 us/run -   1.85 GFLOP/run -   3.57 TFLOPS
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      86 runs - 11770.58 us/run -  57.98 GFLOP/run -   4.93 TFLOPS
  CONV_2D(ne_input=[192,256,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     22 runs - 45848.36 us/run - 231.90 GFLOP/run -   5.06 TFLOPS
  CONV_2D(ne_input=[384,512,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     22 runs - 45951.27 us/run - 231.88 GFLOP/run -   5.05 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    22 runs - 47405.59 us/run - 231.83 GFLOP/run -   4.89 TFLOPS
  CONV_2D(ne_input=[96,128,512,1],ne_kernel=[1,1,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     656 runs -  1532.04 us/run -   6.44 GFLOP/run -   4.20 TFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[1,1,4,4],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                 57344 runs -    19.31 us/run - 344.06 kFLOP/run -  17.82 GFLOPS
  CONV_2D(ne_input=[96,128,4,1],ne_kernel=[3,3,4,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                2688 runs -   398.35 us/run - 446.69 MFLOP/run -   1.12 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,512],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                      6 runs - 180574.00 us/run - 927.61 GFLOP/run -   5.14 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[3,3,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     12 runs - 90381.83 us/run - 463.81 GFLOP/run -   5.13 TFLOPS
  CONV_2D(ne_input=[384,512,512,1],ne_kernel=[1,1,512,256],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                     86 runs - 11879.56 us/run -  51.49 GFLOP/run -   4.33 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                    11 runs - 91580.82 us/run - 463.76 GFLOP/run -   5.06 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[1,1,256,128],type_kernel=f16,stride0=1,stride1=1,padding0=0,padding1=0,dilation0=1,dilation1=1,cwhn=0):                    74 runs - 13755.91 us/run -  51.44 GFLOP/run -   3.74 TFLOPS
  CONV_2D(ne_input=[768,1024,256,1],ne_kernel=[3,3,256,256],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     6 runs - 184542.33 us/run - 927.51 GFLOP/run -   5.03 TFLOPS
  CONV_2D(ne_input=[768,1024,128,1],ne_kernel=[3,3,128,3],type_kernel=f16,stride0=1,stride1=1,padding0=1,padding1=1,dilation0=1,dilation1=1,cwhn=0):                     114 runs -  9053.92 us/run -   5.43 GFLOP/run - 600.12 GFLOPS

Stable-diffusion decoding is also faster with the implicit conv (but it is still far from im2col+gemm):

CUDA, direct conv (this PR):
[DEBUG] ggml_extend.hpp:1550 - vae compute buffer size: 1920.19 MB(VRAM)
[DEBUG] stable-diffusion.cpp:1561 - computing vae decode graph completed, taking 1.86s
[INFO ] stable-diffusion.cpp:2195 - latent 3 decoded, taking 1.86s
[INFO ] stable-diffusion.cpp:2199 - decode_first_stage completed, taking 5.57s

CUDA, implicit conv:
[DEBUG] ggml_extend.hpp:1550 - vae compute buffer size: 1920.19 MB(VRAM)
[DEBUG] stable-diffusion.cpp:1561 - computing vae decode graph completed, taking 1.65s
[INFO ] stable-diffusion.cpp:2195 - latent 3 decoded, taking 1.65s
[INFO ] stable-diffusion.cpp:2199 - decode_first_stage completed, taking 4.93s

CUDA (im2col+gemm):
[DEBUG] ggml_extend.hpp:1550 - vae compute buffer size: 4992.19 MB(VRAM)
[DEBUG] stable-diffusion.cpp:1561 - computing vae decode graph completed, taking 1.29s
[INFO ] stable-diffusion.cpp:2195 - latent 3 decoded, taking 1.29s
[INFO ] stable-diffusion.cpp:2199 - decode_first_stage completed, taking 3.91s

@Green-Sky

> I think you all measured the execution time wrong.

Pretty sure my sd.cpp VAE tables are correct. They are just very flaky (I need to repaste the device; it drops by maybe 20% perf over the first ~5 min or so).

@etasnadi (Owner) commented Sep 28, 2025

> Pretty sure my sd.cpp VAE tables are correct. They are just very flaky (I need to repaste the device; it drops by maybe 20% perf over the first ~5 min or so).

The problem is not how you tested but how the implicit alg was plugged into ggml. There was a different op added for implicit conv2d, so by default sd.cpp did not use the implicit alg.

If you explicitly modified the code, then it is surprising to me that the implicit alg is not significantly faster than the Vulkan translation, because it uses warptiling and double buffering, and both could have a significant positive effect on perf.

I created a branch in my repo where both algs are added (https://github.com/etasnadi/llama.cppxx/tree/conv2d-implicit), and you can switch between them by setting GGML_CUDA_USE_IMPLICIT_CONV=1 or GGML_CUDA_USE_DIRECT_CONV=1, so we can make sure that we actually test the alg we want.
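
Presumably the switch on the C++ side looks something like the following; this is a sketch, and only the two environment variable names come from the comment above:

```cpp
#include <cstdlib>

// sketch of the algorithm switch described above
enum class conv2d_algo { direct, implicit };

static conv2d_algo pick_conv2d_algo() {
    if (std::getenv("GGML_CUDA_USE_IMPLICIT_CONV")) return conv2d_algo::implicit;
    if (std::getenv("GGML_CUDA_USE_DIRECT_CONV"))   return conv2d_algo::direct;
    return conv2d_algo::direct;  // assumed default
}
```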
