Skip to content

Commit b4a4193

Browse files
committed
update xdnn and pack to support fp8 gemm in prefill
1 parent 603bb72 commit b4a4193

File tree

4 files changed

+13
-9
lines changed

4 files changed

+13
-9
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ Please install libnuma package:
153153
git checkout <latest-tag>
154154
# Please make sure torch is installed when run python example
155155
mkdir build && cd build
156+
# Notice: use gcc-13 or higher
156157
cmake ..
158+
# If you see the error "numa.h: No such file or directory", install libnuma first, then build with "CPATH=$CONDA_PATH/include/:$CPATH make -j".
157159
make -j
158160
```
159161
- Using `python setup.py`

README_CN.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ docker run -it \
154154
git checkout <latest-tag>
155155
# 如果使用python示例,请确保已经安装torch。
156156
mkdir build && cd build
157+
# 注意使用gcc-13及以上版本
157158
cmake ..
159+
# 若遇到错误 "numa.h: No such file or directory",需要先安装numa包,然后使用 "CPATH=$CONDA_PATH/include/:$CPATH make -j"完成编译
158160
make -j
159161
```
160162
- 使用 `python setup.py`

cmake/xdnn.cmake

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ include(ExternalProject)
2626

2727
# cmake-format: off
2828
ExternalProject_Add(xdnn_lib
29-
URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.7.tar.gz
30-
URL_HASH MD5=6cad71df05ef120e058bce28a0a478a8
29+
URL https://github.com/intel/xFasterTransformer/releases/download/IntrinsicGemm/xdnn_v1.5.9.tar.gz
30+
URL_HASH MD5=3aa9cd15df3eb2a7a1c178f3edcf9d37
3131
TIMEOUT 120
3232
SOURCE_DIR ${CMAKE_SOURCE_DIR}/3rdparty/xdnn
3333
CONFIGURE_COMMAND ""

src/utils/matmul_helper.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -524,12 +524,12 @@ class MMHelper {
524524

525525
// E4M3
526526
else if constexpr (std::is_same_v<WeiT, e4m3_t>) {
527-
int amx_rows = (int)((K + 15) / 16) * 16;
528-
int amx_cols = (int)((N + 63) / 64) * 64;
529-
if (!weight.isShadow()) weight.Resize(amx_rows, amx_cols);
530-
memset(weight.Data(), 0, sizeof(e4m3_t) * amx_rows * amx_cols);
527+
int blockSize = 32;
528+
size_t pack_size = xdnn_small_amx_sgemm_bf16f8bf16_packb_size(K, N, blockSize);
529+
if (!weight.isShadow()) weight.Resize((pack_size + N - 1) / N, N);
530+
memset(weight.Data(), 0, sizeof(e4m3_t) * pack_size);
531531
xdnn_small_amx_sgemm_bf16f8bf16_packb(trans, N, K, (const XDNN_E4M3 *)src.Data(), src.Stride(),
532-
(XDNN_E4M3 *)weight.Data(), 64);
532+
(XDNN_E4M3 *)weight.Data(), blockSize);
533533
}
534534
}
535535

@@ -691,7 +691,7 @@ class MMHelper {
691691

692692
// E4M3
693693
else if constexpr (std::is_same_v<WeiT, e4m3_t>) {
694-
if (M <= 16) {
694+
if (true) {
695695
assert(blockSize == 128);
696696
if (lds == -1) lds = (K + 127) / 128;
697697
GEMMVERBOSE("xdnn_gemm_bf16f8bf16_compute",
@@ -1509,7 +1509,7 @@ class MMHelper {
15091509

15101510
// E4M3
15111511
else if constexpr (std::is_same_v<WeiT, e4m3_t>) {
1512-
if (M <= 16) {
1512+
if (true) {
15131513
assert(blockSize == 128);
15141514
if (lds == -1) lds = (K + 127) / 128;
15151515
GEMMVERBOSE("xdnn_gemm_bf16f8bf16_compute_residential",

0 commit comments

Comments
 (0)