
Bug: CUDA illegal memory access related to KV/n_ctx padding and F16 DMMV #8798

Closed
@cebtenzzre

Description


I'm not sure why I can't reproduce this with llama-cli, but I can reproduce it with GPT4All after the merge of PR #7257, up to and including commit 398ede5 from today (the latest I've tried).

Edit: I can also reproduce this on commit 952d03d, from before the padding was increased, so the extra padding for FA seems to have been masking an older bug.

The diagnostic information below comes from a fork based on commit 398ede5. Line numbers in ggml-cuda.cu won't match upstream exactly because the fork adds extra code for device enumeration, which GPT4All requires.

cc @slaren @JohannesGaessler

Steps to reproduce

  1. Load llama-2-7b.Q4_0.gguf fully offloaded to a single Tesla P40 with n_ctx=2016 (a multiple of 32 but not of 256), n_batch=2048, and n_ubatch=512. Flash attention is disabled.
  2. In chunks of 128 (the max batch size GPT4All uses in practice), decode 1990 tokens of input.
  3. Sample a token.
  4. Decode the sampled token with n_past=1990. At this point, CUDA will hit an illegal memory access, which will be reported in the next synchronize call:
pos=1989 13 '
'
sampling token n_past=1990
decode(n_past=1990, n_eval=1):
pos=1990 13 '
'
sampling token n_past=1991
CUDA error: an illegal memory access was encountered
  current device: 0, in function ggml_backend_cuda_synchronize at /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda.cu:2408
  cudaStreamSynchronize(cuda_ctx->stream())
/home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda.cu:102: CUDA error
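
The configuration above puts the single-token decode right at the edge of the cache. Below is a minimal sketch of the KV-length arithmetic. It assumes, without checking it against the llama.cpp source, that with flash attention disabled the attention view covers the used KV cells rounded up to a multiple of 32 and capped at n_ctx. `ggml_pad` mirrors the rounding of the `GGML_PAD` macro, and `n_kv_for` is a hypothetical helper.

```python
def ggml_pad(x: int, n: int) -> int:
    """Round x up to the next multiple of n (same rounding as GGML_PAD)."""
    return (x + n - 1) // n * n

N_CTX = 2016   # from the repro: a multiple of 32 but not of 256
KV_PAD = 32    # assumed KV padding when flash attention is disabled

def n_kv_for(n_past: int, n_tokens: int, n_ctx: int = N_CTX, pad: int = KV_PAD) -> int:
    """Hypothetical model of the KV view length for a batch starting at n_past."""
    used = n_past + n_tokens
    return min(n_ctx, max(pad, ggml_pad(used, pad)))

if __name__ == "__main__":
    # The single-token decode at n_past=1990 sees a KV view spanning the whole cache.
    n_kv = n_kv_for(1990, 1)
    print(n_kv, n_kv == N_CTX)                # 2016 True
    # 2016 is a multiple of 32 but not of 64 (or 256).
    print(n_kv % 32, n_kv % 64, n_kv % 256)   # 0 32 224
```

Under this assumption, the KV view for the decode at n_past=1990 spans all 2016 cells, and 2016 is not a multiple of 64.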

This is the first error reported by Compute Sanitizer:

========= Invalid __global__ read of size 2 bytes
=========     at __half::operator float() const+0x358 in /opt/cuda/targets/x86_64-linux/include/cuda_fp16.hpp:136
=========     by thread (16,0,0) in block (127,0,0)
=========     Address 0x782d9d000000 is out of bounds
=========     and is 1 bytes after the nearest allocation at 0x782d5e000000 of size 1,056,964,608 bytes
=========     Device Frame:convert_f16(const void *, long, int, float2 &)+0x2c8 in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda/dmmv.cu:421
=========     Device Frame:void dequantize_mul_mat_vec<(ggml_type)1>(const void *, const float *, float *, int, int)+0x2c8 in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda/dmmv.cu:474
=========     Saved host backtrace up to driver entry point at kernel launch time
=========     Host Frame: [0x2c914f]
=========                in /usr/lib/libcuda.so.1
=========     Host Frame: [0x15803]
=========                in /opt/cuda/targets/x86_64-linux/lib/libcudart.so.12
=========     Host Frame:cudaLaunchKernel [0x75230]
=========                in /opt/cuda/targets/x86_64-linux/lib/libcudart.so.12
=========     Host Frame:cudaError cudaLaunchKernel<char>(char const*, dim3, dim3, void**, unsigned long, CUstream_st*) in /opt/cuda/targets/x86_64-linux/include/cuda_runtime.h:216 [0x439658]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:__device_stub__Z22dequantize_mul_mat_vecIL9ggml_type1EEvPKvPKfPfii(void const*, float const*, float*, int, int) in /tmp/tmpxft_000138b8_00000000-6_dmmv.compute_61.cudafe1.stub.c:29 [0x438ef3]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:void __wrapper__device_stub_dequantize_mul_mat_vec<(ggml_type)1>(void const* restrict&, float const* restrict&, float* restrict&, int const&, int const&) in /tmp/tmpxft_000138b8_00000000-6_dmmv.compute_61.cudafe1.stub.c:30 [0x438f55]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:void dequantize_mul_mat_vec<(ggml_type)1>(void const*, float const*, float*, int, int) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda/dmmv.cu:436 [0x439607]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:convert_mul_mat_vec_f16_cuda(void const*, float const*, float*, int, int, CUstream_st*) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda/dmmv.cu:595 [0x437683]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_cuda_op_dequantize_mul_mat_vec(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, char const*, float const*, char const*, float*, long, long, long, long, CUstream_st*) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda/dmmv.cu:662 [0x43793c]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_cuda_op_mul_mat(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, void (*)(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*, char const*, float const*, char const*, float*, long, long, long, long, CUstream_st*), void (*)(float const*, void*, long, long, long, long, ggml_type, CUstream_st*)) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda.cu:1618 [0x3f39da]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_cuda_mul_mat(ggml_backend_cuda_context&, ggml_tensor const*, ggml_tensor const*, ggml_tensor*) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda.cu:1949 [0x3f58cc]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_cuda_compute_forward(ggml_backend_cuda_context&, ggml_tensor*) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda.cu:2247 [0x3f6d5e]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_backend_cuda_graph_compute(ggml_backend*, ggml_cgraph*) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-cuda.cu:2608 [0x3f7b77]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_backend_graph_compute_async in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-backend.c:282 [0x3d539f]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_backend_sched_compute_splits in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-backend.c:1790 [0x3d8285]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:ggml_backend_sched_graph_compute_async in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/ggml/src/ggml-backend.c:1977 [0x3d847b]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:llama_graph_compute(llama_context&, ggml_cgraph*, int) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/src/llama.cpp:14504 [0x2c05c4]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:llama_decode_internal(llama_context&, llama_batch) in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/src/llama.cpp:14717 [0x2e0f0f]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:llama_decode in /home/jared/src/forks/gpt4all/gpt4all-backend/llama.cpp-mainline/src/llama.cpp:18429 [0x2e1668]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:LLamaModel::evalTokens(LLModel::PromptContext&, std::vector<int, std::allocator<int> > const&) const in /home/jared/src/forks/gpt4all/gpt4all-backend/llamamodel.cpp:608 [0x2916cf]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllamamodel-mainline-cuda.so
=========     Host Frame:LLModel::generateResponse(std::function<bool (int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)>, std::function<bool (bool)>, LLModel::PromptContext&) in /home/jared/src/forks/gpt4all/gpt4all-backend/llmodel_shared.cpp:263 [0xc3af8]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllmodel.so.0
=========     Host Frame:LLModel::prompt(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, std::function<bool (int)>, std::function<bool (int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)>, std::function<bool (bool)>, LLModel::PromptContext&, bool, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >*) in /home/jared/src/forks/gpt4all/gpt4all-backend/llmodel_shared.cpp:169 [0xc4c6a]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/libllmodel.so.0
=========     Host Frame:ChatLLM::promptInternal(QList<QString> const&, QString const&, QString const&, int, int, float, float, float, int, float, int) in /home/jared/src/forks/gpt4all/gpt4all-chat/chatllm.cpp:803 [0x88573]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/chat
=========     Host Frame:ChatLLM::prompt(QList<QString> const&, QString const&) in /home/jared/src/forks/gpt4all/gpt4all-chat/chatllm.cpp:745 [0x88988]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/chat
=========     Host Frame:QtPrivate::FunctorCall<QtPrivate::IndexesList<0, 1>, QtPrivate::List<QList<QString> const&, QString const&>, void, bool (ChatLLM::*)(QList<QString> const&, QString const&)>::call(bool (ChatLLM::*)(QList<QString> const&, QString const&), ChatLLM*, void**) [0x7e8c3]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/chat
=========     Host Frame:void QtPrivate::FunctionPointer<bool (ChatLLM::*)(QList<QString> const&, QString const&)>::call<QtPrivate::List<QList<QString> const&, QString const&>, void>(bool (ChatLLM::*)(QList<QString> const&, QString const&), ChatLLM*, void**) [0x7e8fa]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/chat
=========     Host Frame:QtPrivate::QCallableObject<bool (ChatLLM::*)(QList<QString> const&, QString const&), QtPrivate::List<QList<QString> const&, QString const&>, void>::impl(int, QtPrivate::QSlotObjectBase*, QObject*, void**, bool*) [0x7e951]
=========                in /home/jared/src/forks/gpt4all/gpt4all-chat/build/bin/chat
=========     Host Frame:QObject::event(QEvent*) in /usr/src/debug/qt6-base/qtbase/src/corelib/kernel/qobject.cpp:1452 [0x18c00e]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:QCoreApplication::notifyInternal2(QObject*, QEvent*) in /usr/src/debug/qt6-base/qtbase/src/corelib/kernel/qcoreapplication.cpp:1142 [0x144d27]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:QCoreApplicationPrivate::sendPostedEvents(QObject*, int, QThreadData*) in /usr/src/debug/qt6-base/qtbase/src/corelib/kernel/qcoreapplication.cpp:1940 [0x1450ea]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:postEventSourceDispatch in /usr/src/debug/qt6-base/qtbase/src/corelib/kernel/qeventdispatcher_glib.cpp:244 [0x3a49eb]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:g_main_dispatch in ../glib/glib/gmain.c:3344 [0x5ca88]
=========                in /usr/lib/libglib-2.0.so.0
=========     Host Frame:g_main_context_iterate_unlocked.isra.0 in ../glib/glib/gmain.c:4217 [0xbe9b6]
=========                in /usr/lib/libglib-2.0.so.0
=========     Host Frame:g_main_context_iteration in ../glib/glib/gmain.c:4282 [0x5bf94]
=========                in /usr/lib/libglib-2.0.so.0
=========     Host Frame:QEventDispatcherGlib::processEvents(QFlags<QEventLoop::ProcessEventsFlag>) in /usr/src/debug/qt6-base/qtbase/src/corelib/kernel/qeventdispatcher_glib.cpp:394 [0x3a2cbc]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:QEventLoop::exec(QFlags<QEventLoop::ProcessEventsFlag>) in /usr/src/debug/qt6-base/qtbase/src/corelib/kernel/qeventloop.cpp:182 [0x14f01d]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:QThread::run() in /usr/src/debug/qt6-base/qtbase/src/corelib/thread/qthread.cpp:707 [0x23a55f]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:QThreadPrivate::start(void*) in /usr/src/debug/qt6-base/qtbase/src/corelib/thread/qthread_unix.cpp:285 [0x2c9746]
=========                in /usr/lib/libQt6Core.so.6
=========     Host Frame:start_thread in /usr/src/debug/glibc/glibc/nptl/pthread_create.c:447 [0x92dec]
=========                in /usr/lib/libc.so.6
=========     Host Frame:clone3 in ../sysdeps/unix/sysv/linux/x86_64/clone3.S:78 [0x1160db]
=========                in /usr/lib/libc.so.6

The full report from Compute Sanitizer is here.
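
Two numbers in the report can be checked by hand. The first half of the sketch below assumes the default F16 KV cache, Llama-2-7B's shape (32 layers, 4096-wide K/V per layer, no grouped-query attention), and that K and V for all layers share one buffer. Under those assumptions, the 1,056,964,608-byte allocation is exactly the KV cache, and the faulting address is the first byte past its end. The second half is a simplified model of the F16 DMMV loop. It is assumed rather than copied from dmmv.cu: each warp handles one row, covers `2*GGML_CUDA_DMMV_X` = 64 columns per iteration with 2 halves per lane, and has no per-column bounds check. `kv_cache_bytes` and `out_of_row_lanes` are hypothetical helpers.

```python
# Simplified model; model hyperparameters and kernel structure are assumptions.

# Part 1: does the faulting buffer match the size of an F16 KV cache?
N_LAYER = 32      # assumed: Llama-2-7B layer count
N_EMBD_KV = 4096  # assumed: K/V width per layer (no GQA in Llama-2-7B)
N_CTX = 2016      # from the repro steps
F16_BYTES = 2     # assumed: default F16 cache type

def kv_cache_bytes(n_layer=N_LAYER, n_embd=N_EMBD_KV, n_ctx=N_CTX, elt=F16_BYTES):
    # K and V each hold n_layer * n_embd * n_ctx elements.
    return 2 * n_layer * n_embd * n_ctx * elt

ALLOC_BASE = 0x782D5E000000   # "nearest allocation" from the report
FAULT_ADDR = 0x782D9D000000   # out-of-bounds address from the report

# Part 2: which warp lanes read past the end of a row in the F16 DMMV loop?
WARP_SIZE = 32
DMMV_X = 32                               # assumed default GGML_CUDA_DMMV_X
ITER_STRIDE = 2 * DMMV_X                  # columns covered by one warp per iteration
VALS_PER_ITER = ITER_STRIDE // WARP_SIZE  # consecutive halves read by each lane

def out_of_row_lanes(ncols):
    """Lanes that touch a column >= ncols, given no per-column bounds check."""
    bad = set()
    for i in range(0, ncols, ITER_STRIDE):
        for tid in range(WARP_SIZE):
            col = i + VALS_PER_ITER * tid
            if col + VALS_PER_ITER > ncols:  # lane reads columns col .. col+VALS_PER_ITER-1
                bad.add(tid)
    return sorted(bad)

if __name__ == "__main__":
    print(kv_cache_bytes(), FAULT_ADDR - ALLOC_BASE)  # 1056964608 1056964608
    print(out_of_row_lanes(2016)[0])                  # 16
    print(out_of_row_lanes(2048))                     # []
```

With ncols = n_kv = 2016, the last iteration starts at column 1984, so lanes 16 through 31 read past the end of the row. Lane 16 is the lowest such lane, which is consistent with the sanitizer's thread (16,0,0). With a multiple of 64 such as 2048, no lane overruns. This is a sketch of one plausible mechanism, not a confirmed root cause.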

Metadata

Assignees: No one assigned

Labels: bug, high severity
