Skip to content

Commit 073055c

Browse files
committed
[layers] fix assert bug when concat gate&up
1 parent a87a55b commit 073055c

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,9 @@ add_definitions(-DAVX512_FP32_WEIGHT_ONLY_NF4=true)
183183
# add_definitions(-DSTEP_BY_STEP_ATTN=true)
184184
add_definitions(-DUSE_SHM=true)
185185
option(XFT_BUILD_TESTS "Build xfastertransformer unit tests" OFF)
186+
if(XFT_BUILD_TESTS)
187+
add_definitions(-DUNDEBUG=true)
188+
endif()
186189

187190
# timeline event
188191
option(WITH_TIMELINE "Build with timeline event support" OFF)

src/layers/decoder_layer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
#include <immintrin.h>
1818
#include <omp.h>
1919

20+
#ifdef UNDEBUG
21+
#undef NDEBUG
22+
#endif
2023
#include <cassert>
24+
2125
#include <cmath>
2226
#include <cstdarg>
2327
#include <cstdio>

src/layers/mlp_llama.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,10 +248,8 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
248248
TimeLine t("DownProj");
249249

250250
assert(input.Rows() == output.Rows());
251-
if (!enableCATMLP())
252-
assert(input.Cols() == downWeight.Rows());
253-
else
254-
assert(input.Cols() == 2 * downWeight.Rows());
251+
assert(input.Cols() == downWeight.Rows());
252+
assert(input.Cols() == 2 * downWeight.Rows());
255253
assert(downWeight.Cols() == output.Cols());
256254

257255
int M = input.Rows(), N = output.Cols(), K = downWeight.Rows();

0 commit comments

Comments
 (0)