Skip to content

Commit

Permalink
metal : matrix-matrix multiplication kernel (ggerganov#2615)
Browse files Browse the repository at this point in the history
* metal: matrix-matrix multiplication kernel

This commit removes MPS and uses custom matrix-matrix multiplication
kernels for all quantization types. This commit also adds grouped-query
attention to support llama2 70B.

* metal: fix performance degradation from gqa

Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.

* metal: fix bugs for GQA and perplexity test.

I mixed up ne02 and nb02 in previous commit.
  • Loading branch information
lshzh-ww committed Aug 16, 2023
1 parent b5ffb28 commit bf83bff
Show file tree
Hide file tree
Showing 6 changed files with 528 additions and 636 deletions.
2 changes: 0 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,6 @@ if (LLAMA_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

set(GGML_SOURCES_METAL ggml-metal.m ggml-metal.h)

Expand All @@ -313,7 +312,6 @@ if (LLAMA_METAL)
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()

Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ endif # LLAMA_CLBLAST
ifdef LLAMA_METAL
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
CXXFLAGS += -DGGML_USE_METAL
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
OBJS += ggml-metal.o
endif # LLAMA_METAL

Expand Down
2 changes: 0 additions & 2 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
with pkgs.darwin.apple_sdk_11_0.frameworks; [
Accelerate
MetalKit
MetalPerformanceShaders
MetalPerformanceShadersGraph
]
else if isAarch32 && isDarwin then
with pkgs.darwin.apple_sdk.frameworks; [
Expand Down
Loading

0 comments on commit bf83bff

Please sign in to comment.