
Move page cache via mbind to prevent cross-NUMA access #13731

Open
wants to merge 1 commit into master
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -83,6 +83,7 @@ option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})

# 3rd party libs
option(LLAMA_CURL "llama: use libcurl to download model from an URL" ON)
option(LLAMA_NUMA "llama: use libnuma to move the model page cache to the local NUMA node" ON)
option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF)

# Required for relocatable CMake package
13 changes: 13 additions & 0 deletions src/CMakeLists.txt
@@ -39,6 +39,19 @@ target_compile_features (llama PRIVATE cxx_std_17) # don't bump

target_link_libraries(llama PUBLIC ggml)

if (LLAMA_NUMA)
find_library(NUMA_LIB numa)
check_include_file_cxx("numa.h" HAVE_NUMA_HEADERS)
check_include_file_cxx("numaif.h" HAVE_NUMAIF_HEADERS)
if (HAVE_NUMA_HEADERS AND HAVE_NUMAIF_HEADERS AND NUMA_LIB)
target_compile_definitions(llama PRIVATE USE_LIBNUMA)
target_link_libraries(llama PRIVATE numa)
message(STATUS "libnuma found, page cache will be moved to the local node using mbind() syscall. Disable with LLAMA_NUMA=OFF")
else()
message(STATUS "Warning: NUMA headers not found - consider disabling this Warning with LLAMA_NUMA=OFF")
endif()
endif()

if (BUILD_SHARED_LIBS)
set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(llama PRIVATE LLAMA_BUILD)
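The check above only establishes that libnuma and its headers were available at build time. At runtime, the numa(3) man page asks callers to invoke numa_available() before any other libnuma function and to back off if it returns -1. The short C++ sketch below is not part of this patch; the llama_numa_usable() helper name is made up here, and it only illustrates how code compiled under USE_LIBNUMA might guard the page-moving path.

#ifdef USE_LIBNUMA
#include <numa.h>

// Hypothetical helper, not taken from the PR: caches the result of
// numa_available(), which returns -1 when the kernel lacks NUMA policy
// support, so the mbind()/libnuma path can be skipped on such systems.
static bool llama_numa_usable() {
    static const bool ok = (numa_available() != -1);
    return ok;
}
#endif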
38 changes: 38 additions & 0 deletions src/llama-mmap.cpp
@@ -10,6 +10,12 @@
#include <cerrno>
#include <algorithm>

#ifdef USE_LIBNUMA
#include <numa.h>
#include <numaif.h>
#include <sched.h>
#endif

#ifdef __has_include
#if __has_include(<unistd.h>)
#include <unistd.h>
@@ -273,6 +279,27 @@ struct llama_mmap::impl {
#ifdef _POSIX_MAPPED_FILES
std::vector<std::pair<size_t, size_t>> mapped_fragments;

#ifdef USE_LIBNUMA
static void move_pages(void *addr, size_t size) {
int cpu, ret;
struct bitmask *nodemask = numa_allocate_nodemask();

/* Get memory policy of the calling thread. */
ret = get_mempolicy(nullptr, nodemask->maskp, nodemask->size, nullptr, 0);
if (ret || numa_bitmask_weight(nodemask) == 0) {
/* No usable policy returned: fall back to the NUMA node of the CPU this thread is running on. */
cpu = sched_getcpu();
if (cpu >= 0) {
numa_bitmask_clearall(nodemask);
numa_bitmask_setbit(nodemask, numa_node_of_cpu(cpu));
}
}
/* Bind the mapping to that single node and migrate pages that already reside elsewhere. */
if (numa_bitmask_weight(nodemask) == 1) {
mbind(addr, size, MPOL_BIND, nodemask->maskp, nodemask->size, MPOL_MF_MOVE);
}
numa_free_nodemask(nodemask);
}
#endif

impl(struct llama_file * file, size_t prefetch, bool numa) {
size = file->size();
int fd = file->file_id();
@@ -291,6 +318,17 @@ struct llama_mmap::impl {
}

if (prefetch > 0) {
#ifdef USE_LIBNUMA
/*
* Given that we already pre-fault all memory when prefetch > 0, it is
* necessary to move any page cache pages that might have been
* instantiated during previous runs on different NUMA nodes. This call
* to move_pages() ensures that all memory-mapped pages are relocated
* according to the calling thread's memory policy or the CPU on which
* it is running.
*/
move_pages(addr, file->size());
#endif
if (posix_madvise(addr, std::min(file->size(), prefetch), POSIX_MADV_WILLNEED)) {
LLAMA_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_WILLNEED) failed: %s\n",
strerror(errno));
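As a side note, one way to confirm that the mbind() call relocated the mapping is to query the backing node of a few pages after loading a model. The standalone sketch below is not from the PR; the file path, page count, and output format are illustrative. It relies on numa_move_pages() in query mode (nodes == NULL), which migrates nothing and instead reports in status[] the node each page currently resides on.

// Standalone verification sketch (not part of the PR): map a file read-only,
// fault in its first few pages, and print the NUMA node backing each of them.
// Build with something like: g++ -O2 check_nodes.cpp -lnuma
#include <cstdio>
#include <vector>

#include <fcntl.h>
#include <numa.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <model-file>\n", argv[0]);
        return 1;
    }
    if (numa_available() == -1) {
        fprintf(stderr, "NUMA is not available on this system\n");
        return 1;
    }

    int fd = open(argv[1], O_RDONLY);
    struct stat st{};
    if (fd == -1 || fstat(fd, &st) == -1) {
        perror("open/fstat");
        return 1;
    }

    void * addr = mmap(nullptr, (size_t) st.st_size, PROT_READ, MAP_SHARED, fd, 0);
    if (addr == MAP_FAILED) {
        perror("mmap");
        return 1;
    }

    const long page_size = sysconf(_SC_PAGESIZE);
    unsigned long count = (unsigned long) (st.st_size / page_size);
    if (count > 4) {
        count = 4; // a handful of pages is enough to see where the cache lives
    }

    std::vector<void *> pages(count);
    std::vector<int>    status(count, -1);
    for (unsigned long i = 0; i < count; i++) {
        volatile char c = ((const char *) addr)[i * page_size]; // fault the page in
        (void) c;
        pages[i] = (char *) addr + i * page_size;
    }

    // With nodes == NULL, numa_move_pages() moves nothing; it fills status[]
    // with the node each page resides on (or a negative errno value).
    if (numa_move_pages(0, count, pages.data(), nullptr, status.data(), 0) == 0) {
        for (unsigned long i = 0; i < count; i++) {
            printf("page %lu -> node %d\n", i, status[i]);
        }
    } else {
        perror("numa_move_pages");
    }

    munmap(addr, (size_t) st.st_size);
    close(fd);
    return 0;
}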