Skip to content

Commit c47ae2d

Browse files
committed
[layers] fix assert bug when concat gate&up
1 parent a87a55b commit c47ae2d

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -183,6 +183,9 @@ add_definitions(-DAVX512_FP32_WEIGHT_ONLY_NF4=true)
183183
# add_definitions(-DSTEP_BY_STEP_ATTN=true)
184184
add_definitions(-DUSE_SHM=true)
185185
option(XFT_BUILD_TESTS "Build xfastertransformer unit tests" OFF)
186+
if(XFT_BUILD_TESTS)
187+
add_definitions(-DUNDEBUG=true)
188+
endif()
186189

187190
# timeline event
188191
option(WITH_TIMELINE "Build with timeline event support" OFF)

src/layers/mlp_llama.h

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,11 @@
1313
// limitations under the License.
1414
// ============================================================================
1515
#pragma once
16+
17+
#ifdef UNDEBUG
18+
#undef NDEBUG
19+
#endif
20+
1621
#include "bert_util.h"
1722
#include "copy_util.h"
1823
#include "debugger.h"
@@ -248,10 +253,7 @@ class LlamaMLP : public SingletonBase<LlamaMLP<WeiT>> {
248253
TimeLine t("DownProj");
249254

250255
assert(input.Rows() == output.Rows());
251-
if (!enableCATMLP())
252-
assert(input.Cols() == downWeight.Rows());
253-
else
254-
assert(input.Cols() == 2 * downWeight.Rows());
256+
assert(input.Cols() == downWeight.Rows());
255257
assert(downWeight.Cols() == output.Cols());
256258

257259
int M = input.Rows(), N = output.Cols(), K = downWeight.Rows();

0 commit comments

Comments
 (0)