feat: add SYCL Backend Support for Intel GPUs (leejet#330)
* update ggml and add SYCL CMake option

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>

* hacky CMakeLists.txt for updating ggml in cpu backend

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>

* rebase and clean code

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>

* add sycl in README

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>

* rebase ggml commit

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>

* refine README

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>

* update ggml for supporting sycl tsembd op

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>

---------

Signed-off-by: zhentaoyu <zhentao.yu@intel.com>
zhentaoyu authored and SkutteOleg committed Aug 21, 2024
1 parent e91ce4f commit 2cbc146
Showing 8 changed files with 62 additions and 10 deletions.
11 changes: 9 additions & 2 deletions CMakeLists.txt
@@ -27,19 +27,20 @@ option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE})
option(SD_CUBLAS "sd: cuda backend" OFF)
option(SD_HIPBLAS "sd: rocm backend" OFF)
option(SD_METAL "sd: metal backend" OFF)
option(SD_SYCL "sd: sycl backend" OFF)
option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF)
option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF)
option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF)
#option(SD_BUILD_SERVER "sd: build server example" ON)

if(SD_CUBLAS)
message("Use CUBLAS as backend stable-diffusion")
message("Use CUBLAS as backend stable-diffusion")
set(GGML_CUDA ON)
add_definitions(-DSD_USE_CUBLAS)
endif()

if(SD_METAL)
message("Use Metal as backend stable-diffusion")
message("Use Metal as backend stable-diffusion")
set(GGML_METAL ON)
add_definitions(-DSD_USE_METAL)
endif()
@@ -53,6 +54,12 @@ if (SD_HIPBLAS)
endif()
endif ()

if(SD_SYCL)
message("Use SYCL as backend stable-diffusion")
set(GGML_SYCL ON)
add_definitions(-DSD_USE_SYCL)
endif()

if(SD_FLASH_ATTN)
message("Use Flash Attention for memory optimization")
add_definitions(-DSD_USE_FLASH_ATTENTION)
33 changes: 32 additions & 1 deletion README.md
@@ -20,7 +20,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in
- Accelerated memory-efficient CPU inference
- Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB.
- AVX, AVX2 and AVX512 support for x86 architectures
- Full CUDA and Metal backend for GPU acceleration.
- Full CUDA, Metal and SYCL backend for GPU acceleration.
- Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models
- No need to convert to `.ggml` or `.gguf` anymore!
- Flash Attention for memory usage optimization (only cpu for now)
@@ -142,6 +142,37 @@ cmake .. -DSD_METAL=ON
cmake --build . --config Release
```
##### Using SYCL
Using SYCL runs the computation on an Intel GPU. Make sure you have installed the required driver and the [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) before starting. For more details and setup steps, refer to the [llama.cpp SYCL backend documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/backend/SYCL.md#linux).
```
# Export relevant ENV variables
source /opt/intel/oneapi/setvars.sh

# Option 1: Use FP32 (recommended for better performance in most cases)
cmake .. -DSD_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx

# Option 2: Use FP16
cmake .. -DSD_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON

cmake --build . --config Release
```
Example of txt2img using the SYCL backend:
- download the `stable-diffusion` model weights; refer to [download-weight](#download-weights).
- run `./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors --cfg-scale 5 --steps 30 --sampling-method euler -H 512 -W 512 --seed 42 -p "fantasy medieval village world inside a glass sphere , high detail, fantasy, realistic, light effect, hyper detail, volumetric lighting, cinematic, macro, depth of field, blur, red light and clouds from the back, highly detailed epic cinematic concept art cg render made in maya, blender and photoshop, octane render, excellent composition, dynamic dramatic cinematic lighting, aesthetic, very inspirational, world inside a glass sphere by james gurney by artgerm with james jean, joe fenton and tristan eaton by ross tran, fine details, 4k resolution"`
<p align="center">
<img src="./assets/sycl_sd3_output.png" width="360x">
</p>
> [!NOTE]
> Try setting a smaller image height and width (for example, `-H 512 -W 512`) if you encounter `Provided range is out of integer limits. Pass '-fno-sycl-id-queries-fit-in-int' to disable range check.`
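
The SYCL backend is initialized on device 0 (`ggml_backend_sycl_init(0)` in `stable-diffusion.cpp`). If the runtime sees more than one SYCL device, here is a minimal sketch of how to inspect and pin a device with standard oneAPI tooling (the `sycl-ls` tool and the `ONEAPI_DEVICE_SELECTOR` environment variable are assumed to come from a stock oneAPI install; the prompt and model path below are placeholders):
```
# List the SYCL devices visible to the runtime (requires setvars.sh to be sourced)
sycl-ls

# Pin the run to the first Level Zero GPU before launching sd
ONEAPI_DEVICE_SELECTOR=level_zero:0 ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -p "a lovely cat" -H 512 -W 512
```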
##### Using Flash Attention
Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing.
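The build step itself is collapsed in this diff; a minimal configure sketch, assuming the `SD_FLASH_ATTN` option declared in `CMakeLists.txt` above and the same out-of-tree `build` directory used for the other backends:
```
cmake .. -DSD_FLASH_ATTN=ON
cmake --build . --config Release
```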
Binary file added assets/sycl_sd3_output.png
2 changes: 1 addition & 1 deletion docs/photo_maker.md
@@ -28,5 +28,5 @@ If on low memory GPUs (<= 8GB), recommend running with ```--vae-on-cpu``` option
Example:

```bash
bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png
bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png
```
2 changes: 1 addition & 1 deletion ggml
Submodule ggml updated 95 files
+2 −0 CMakeLists.txt
+15 −10 README.md
+6 −6 examples/gpt-2/README.md
+2 −9 examples/gpt-j/README.md
+2 −2 examples/magika/README.md
+3 −17 examples/mnist/README.md
+16 −1 examples/mnist/mnist-cnn.py
+2 −10 examples/sam/README.md
+43 −46 ggml/src/ggml-cann.cpp
+125 −0 include/ggml-cann.h
+3 −0 include/ggml-cuda.h
+2 −0 include/ggml-metal.h
+6 −4 include/ggml.h
+1 −0 requirements.txt
+7 −3 scripts/sync-llama-am.sh
+1 −1 scripts/sync-llama.last
+3 −2 scripts/sync-llama.sh
+14 −2 scripts/sync-whisper-am.sh
+6 −1 scripts/sync-whisper.sh
+56 −12 src/CMakeLists.txt
+16 −14 src/ggml-aarch64.c
+15 −10 src/ggml-backend.c
+2,020 −0 src/ggml-cann.cpp
+2,579 −0 src/ggml-cann/Doxyfile
+175 −0 src/ggml-cann/acl_tensor.cpp
+258 −0 src/ggml-cann/acl_tensor.h
+3,082 −0 src/ggml-cann/aclnn_ops.cpp
+592 −0 src/ggml-cann/aclnn_ops.h
+282 −0 src/ggml-cann/common.h
+33 −0 src/ggml-cann/kernels/CMakeLists.txt
+19 −0 src/ggml-cann/kernels/ascendc_kernels.h
+223 −0 src/ggml-cann/kernels/dup.cpp
+186 −0 src/ggml-cann/kernels/get_row_f16.cpp
+180 −0 src/ggml-cann/kernels/get_row_f32.cpp
+193 −0 src/ggml-cann/kernels/get_row_q4_0.cpp
+191 −0 src/ggml-cann/kernels/get_row_q8_0.cpp
+208 −0 src/ggml-cann/kernels/quantize_f16_q8_0.cpp
+206 −0 src/ggml-cann/kernels/quantize_f32_q8_0.cpp
+278 −0 src/ggml-cann/kernels/quantize_float_to_q4_0.cpp
+5 −1 src/ggml-common.h
+53 −32 src/ggml-cuda.cu
+12 −194 src/ggml-cuda/common.cuh
+15 −6 src/ggml-cuda/dmmv.cu
+2 −0 src/ggml-cuda/dmmv.cuh
+6 −3 src/ggml-cuda/norm.cu
+14 −0 src/ggml-cuda/vendors/cuda.h
+177 −0 src/ggml-cuda/vendors/hip.h
+171 −0 src/ggml-cuda/vendors/musa.h
+4 −6 src/ggml-impl.h
+54 −21 src/ggml-metal.m
+8 −8 src/ggml-quants.c
+4 −0 src/ggml-quants.h
+16 −0 src/ggml-sycl.cpp
+2 −0 src/ggml-sycl/backend.hpp
+99 −0 src/ggml-sycl/conv.cpp
+21 −0 src/ggml-sycl/conv.hpp
+16 −3 src/ggml-sycl/dpct/helper.hpp
+1 −1 src/ggml-sycl/mmvq.cpp
+6 −3 src/ggml-sycl/norm.cpp
+2 −0 src/ggml-sycl/presets.hpp
+71 −0 src/ggml-sycl/tsembd.cpp
+21 −0 src/ggml-sycl/tsembd.hpp
+605 −235 src/ggml-vulkan.cpp
+79 −11 src/ggml.c
+0 −3,149 src/ggml_vk_generate_shaders.py
+2 −0 src/vulkan-shaders/CMakeLists.txt
+4 −2 src/vulkan-shaders/add.comp
+5 −3 src/vulkan-shaders/clamp.comp
+35 −0 src/vulkan-shaders/concat.comp
+5 −3 src/vulkan-shaders/copy.comp
+4 −2 src/vulkan-shaders/div.comp
+1 −1 src/vulkan-shaders/gelu.comp
+23 −0 src/vulkan-shaders/gelu_quick.comp
+5 −1 src/vulkan-shaders/generic_binary_head.comp
+4 −0 src/vulkan-shaders/generic_unary_head.comp
+66 −0 src/vulkan-shaders/group_norm.comp
+57 −0 src/vulkan-shaders/im2col.comp
+22 −0 src/vulkan-shaders/leaky_relu.comp
+4 −2 src/vulkan-shaders/mul.comp
+10 −3 src/vulkan-shaders/mul_mat_vec.comp
+1 −1 src/vulkan-shaders/norm.comp
+26 −0 src/vulkan-shaders/pad.comp
+1 −1 src/vulkan-shaders/relu.comp
+1 −1 src/vulkan-shaders/rms_norm.comp
+4 −2 src/vulkan-shaders/scale.comp
+1 −1 src/vulkan-shaders/silu.comp
+1 −1 src/vulkan-shaders/soft_max.comp
+5 −3 src/vulkan-shaders/square.comp
+1 −1 src/vulkan-shaders/sum_rows.comp
+21 −0 src/vulkan-shaders/tanh.comp
+41 −0 src/vulkan-shaders/timestep_embedding.comp
+2 −2 src/vulkan-shaders/types.comp
+36 −0 src/vulkan-shaders/upscale.comp
+64 −10 src/vulkan-shaders/vulkan-shaders-gen.cpp
+13 −6 tests/test-backend-ops.cpp
14 changes: 10 additions & 4 deletions ggml_extend.hpp
@@ -32,6 +32,10 @@
#include "ggml-metal.h"
#endif

#ifdef SD_USE_SYCL
#include "ggml-sycl.h"
#endif

#include "rng.hpp"
#include "util.h"

@@ -537,7 +541,8 @@ __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const

__STATIC_INLINE__ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx,
struct ggml_tensor* a) {
return ggml_group_norm(ctx, a, 32);
const float eps = 1e-6f; // default eps parameter
return ggml_group_norm(ctx, a, 32, eps);
}

__STATIC_INLINE__ struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx,
@@ -650,7 +655,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx
struct ggml_tensor* k,
struct ggml_tensor* v,
bool mask = false) {
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL)
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_SYCL)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head]
#else
float d_head = (float)q->ne[0];
@@ -765,7 +770,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct
b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1);
}

x = ggml_group_norm(ctx, x, num_groups);
const float eps = 1e-6f; // default eps parameter
x = ggml_group_norm(ctx, x, num_groups, eps);
if (w != NULL && b != NULL) {
x = ggml_mul(ctx, x, w);
// b = ggml_repeat(ctx, b, x);
@@ -775,7 +781,7 @@ }
}

__STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) {
#ifdef SD_USE_CUBLAS
#if defined (SD_USE_CUBLAS) || defined (SD_USE_SYCL)
if (!ggml_backend_is_cpu(backend)) {
ggml_backend_tensor_get_async(backend, tensor, data, offset, size);
ggml_backend_synchronize(backend);
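
The hunk above shows only the changed guard; for context, a hedged sketch of the whole helper under that guard (the `else` and `#else` branches are assumed from the pre-existing CPU path and are not part of this diff):
```cpp
__STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) {
#if defined(SD_USE_CUBLAS) || defined(SD_USE_SYCL)
    if (!ggml_backend_is_cpu(backend)) {
        // On the CUDA/SYCL backends, queue an asynchronous device-to-host copy,
        // then block until it completes so `data` is valid on return.
        ggml_backend_tensor_get_async(backend, tensor, data, offset, size);
        ggml_backend_synchronize(backend);
    } else {
        ggml_backend_tensor_get(tensor, data, offset, size);
    }
#else
    ggml_backend_tensor_get(tensor, data, offset, size);
#endif
}
```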
6 changes: 5 additions & 1 deletion stable-diffusion.cpp
@@ -161,13 +161,17 @@ class StableDiffusionGGML {
ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr);
backend = ggml_backend_metal_init();
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("Using SYCL backend");
backend = ggml_backend_sycl_init(0);
#endif

if (!backend) {
LOG_DEBUG("Using CPU backend");
backend = ggml_backend_cpu_init();
}
#ifdef SD_USE_FLASH_ATTENTION
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL)
#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL)
LOG_WARN("Flash Attention not supported with GPU Backend");
#else
LOG_INFO("Flash Attention enabled");
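
Putting the hunk in context, a hedged sketch of the backend-selection chain after this change (the CUDA branch and the exact log strings outside this hunk are assumptions; only the Metal, SYCL, and CPU-fallback lines appear verbatim in the diff):
```cpp
    ggml_backend_t backend = NULL;
#ifdef SD_USE_CUBLAS
    LOG_DEBUG("Using CUDA backend");
    backend = ggml_backend_cuda_init(0);  // device index 0
#endif
#ifdef SD_USE_METAL
    LOG_DEBUG("Using Metal backend");
    ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr);
    backend = ggml_backend_metal_init();
#endif
#ifdef SD_USE_SYCL
    LOG_DEBUG("Using SYCL backend");
    backend = ggml_backend_sycl_init(0);  // first SYCL device
#endif

    if (!backend) {
        LOG_DEBUG("Using CPU backend");
        backend = ggml_backend_cpu_init();
    }
```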
4 changes: 4 additions & 0 deletions upscaler.cpp
@@ -24,6 +24,10 @@ struct UpscalerGGML {
ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr);
backend = ggml_backend_metal_init();
#endif
#ifdef SD_USE_SYCL
LOG_DEBUG("Using SYCL backend");
backend = ggml_backend_sycl_init(0);
#endif

if (!backend) {
LOG_DEBUG("Using CPU backend");
