Skip to content

Commit d8bb444

Browse files
ggerganovhodlen
authored andcommitted
ci : enable -Werror for CUDA builds (ggml-org#5579)
* cmake : pass -Werror through -Xcompiler ggml-ci * make, cmake : enable CUDA errors on warnings ggml-ci
1 parent d5d6597 commit d8bb444

File tree

3 files changed

+49
-39
lines changed

3 files changed

+49
-39
lines changed

CMakeLists.txt

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,6 @@ set(THREADS_PREFER_PTHREAD_FLAG ON)
145145
find_package(Threads REQUIRED)
146146
include(CheckCXXCompilerFlag)
147147

148-
if (LLAMA_FATAL_WARNINGS)
149-
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
150-
add_compile_options(-Werror)
151-
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
152-
add_compile_options(/WX)
153-
endif()
154-
endif()
155-
156148
# enable libstdc++ assertions for debug builds
157149
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
158150
add_compile_definitions($<$<CONFIG:Debug>:_GLIBCXX_ASSERTIONS>)
@@ -747,15 +739,24 @@ function(get_flags CCID CCVER)
747739
set(GF_CXX_FLAGS ${CXX_FLAGS} PARENT_SCOPE)
748740
endfunction()
749741
742+
if (LLAMA_FATAL_WARNINGS)
743+
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU" OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
744+
list(APPEND C_FLAGS -Werror)
745+
list(APPEND CXX_FLAGS -Werror)
746+
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
747+
add_compile_options(/WX)
748+
endif()
749+
endif()
750+
750751
if (LLAMA_ALL_WARNINGS)
751752
if (NOT MSVC)
752-
set(WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
753-
set(C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
754-
-Werror=implicit-int -Werror=implicit-function-declaration)
755-
set(CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
753+
list(APPEND WARNING_FLAGS -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function)
754+
list(APPEND C_FLAGS -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmissing-prototypes
755+
-Werror=implicit-int -Werror=implicit-function-declaration)
756+
list(APPEND CXX_FLAGS -Wmissing-declarations -Wmissing-noreturn)
756757
757-
set(C_FLAGS ${WARNING_FLAGS} ${C_FLAGS})
758-
set(CXX_FLAGS ${WARNING_FLAGS} ${CXX_FLAGS})
758+
list(APPEND C_FLAGS ${WARNING_FLAGS})
759+
list(APPEND CXX_FLAGS ${WARNING_FLAGS})
759760
760761
get_flags(${CMAKE_CXX_COMPILER_ID} ${CMAKE_CXX_COMPILER_VERSION})
761762
@@ -773,6 +774,10 @@ set(CUDA_CXX_FLAGS "")
773774
if (LLAMA_CUBLAS)
774775
set(CUDA_FLAGS -use_fast_math)
775776
777+
if (LLAMA_FATAL_WARNINGS)
778+
list(APPEND CUDA_FLAGS -Werror all-warnings)
779+
endif()
780+
776781
if (LLAMA_ALL_WARNINGS AND NOT MSVC)
777782
set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
778783
if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ MK_CFLAGS += $(WARN_FLAGS) -Wshadow -Wstrict-prototypes -Wpointer-arith -Wmis
217217
MK_CXXFLAGS += $(WARN_FLAGS) -Wmissing-declarations -Wmissing-noreturn
218218

219219
ifeq ($(LLAMA_FATAL_WARNINGS),1)
220-
MK_CFLAGS += -Werror
220+
MK_CFLAGS += -Werror
221221
MK_CXXFLAGS += -Werror
222222
endif
223223

@@ -385,6 +385,9 @@ ifdef LLAMA_CUBLAS
385385
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib -L/usr/lib/wsl/lib
386386
OBJS += ggml-cuda.o
387387
MK_NVCCFLAGS += -use_fast_math
388+
ifdef LLAMA_FATAL_WARNINGS
389+
MK_NVCCFLAGS += -Werror all-warnings
390+
endif # LLAMA_FATAL_WARNINGS
388391
ifndef JETSON_EOL_MODULE_DETECT
389392
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
390393
endif # JETSON_EOL_MODULE_DETECT

ggml-cuda.cu

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -651,18 +651,18 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
651651
return a;
652652
}
653653

654-
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
655-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
656-
#pragma unroll
657-
for (int mask = 16; mask > 0; mask >>= 1) {
658-
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
659-
}
660-
return a;
661-
#else
662-
(void) a;
663-
NO_DEVICE_CODE;
664-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
665-
}
654+
//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
655+
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
656+
//#pragma unroll
657+
// for (int mask = 16; mask > 0; mask >>= 1) {
658+
// a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
659+
// }
660+
// return a;
661+
//#else
662+
// (void) a;
663+
// NO_DEVICE_CODE;
664+
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
665+
//}
666666

667667
static __device__ __forceinline__ float warp_reduce_max(float x) {
668668
#pragma unroll
@@ -672,18 +672,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
672672
return x;
673673
}
674674

675-
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
676-
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
677-
#pragma unroll
678-
for (int mask = 16; mask > 0; mask >>= 1) {
679-
x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
680-
}
681-
return x;
682-
#else
683-
(void) x;
684-
NO_DEVICE_CODE;
685-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
686-
}
675+
//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
676+
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
677+
//#pragma unroll
678+
// for (int mask = 16; mask > 0; mask >>= 1) {
679+
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
680+
// }
681+
// return x;
682+
//#else
683+
// (void) x;
684+
// NO_DEVICE_CODE;
685+
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
686+
//}
687687

688688
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
689689
return b;
@@ -4641,10 +4641,12 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
46414641
const float d = (float)bq2->d * __low2float(bq8_1[ib32].ds) * 0.25f;
46424642
return d * ((0.5f + ls1) * sumi1 + (0.5f + ls2) * sumi2);
46434643
#else
4644+
(void) ksigns64;
46444645
assert(false);
46454646
return 0.f;
46464647
#endif
46474648
#else
4649+
(void) ksigns64;
46484650
assert(false);
46494651
return 0.f;
46504652
#endif

0 commit comments

Comments
 (0)