Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/kf/ThreadPool.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ namespace kf
}

private:
Thread m_threads[kMaxCount];
Thread m_threads[kMaxCount]{};
int m_count;
};
}
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ wdk_add_driver(kf-test WINVER NTDDI_WIN10 STL
AutoSpinLockTest.cpp
EResourceSharedLockTest.cpp
RecursiveAutoSpinLockTest.cpp
ThreadPoolTest.cpp
)

target_link_libraries(kf-test kf::kf kmtest::kmtest)
Expand Down
120 changes: 120 additions & 0 deletions test/ThreadPoolTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#include "pch.h"
#include <kf/ThreadPool.h>

SCENARIO("kf::ThreadPool")
{
struct TestObject
{
NTSTATUS run()
{
InterlockedIncrement(&value);
return STATUS_SUCCESS;
}

LONG value = 0;
};

constexpr auto fn = [](void* context) {
LARGE_INTEGER interval;
interval.QuadPart = -10'000;
KeDelayExecutionThread(KernelMode, FALSE, &interval);
auto p = static_cast<LONG*>(context);
InterlockedIncrement(p);
};
Copy link

Copilot AI Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic number -10'000 should be documented or replaced with a named constant. This appears to be a 1ms delay in 100ns units, but this is not immediately clear.

Suggested change
};
// 1 millisecond delay in 100ns units (negative for relative time)
constexpr LONGLONG ONE_MILLISECOND_DELAY_100NS_UNITS = -10'000;
constexpr auto fn = [](void* context) {
LARGE_INTEGER interval;
interval.QuadPart = ONE_MILLISECOND_DELAY_100NS_UNITS;
KeDelayExecutionThread(KernelMode, FALSE, &interval);
auto p = static_cast<LONG*>(context);
InterlockedIncrement(p);
};

Copilot uses AI. Check for mistakes.

GIVEN("A ThreadPool with count <= kMaxCount")
{
kf::ThreadPool pool(4);

WHEN("Starting threads with a lambda")
{
LONG value = 0;
NTSTATUS status = pool.start(fn, &value);

THEN("Status is successful")
{
REQUIRE(NT_SUCCESS(status));
}

pool.join();

THEN("All threads incremented the value")
{
REQUIRE(value == 4);
}
}

WHEN("Starting threads with a member function")
{
TestObject obj;
NTSTATUS status = pool.start<&TestObject::run>(&obj);

THEN("Status is successful")
{
REQUIRE(NT_SUCCESS(status));
}

pool.join();

THEN("All threads executed the member routine")
{
REQUIRE(obj.value == 4);
}
}
}

GIVEN("A ThreadPool with count > kMaxCount")
{
kf::ThreadPool pool(100);
LONG value = 0;

WHEN("Starting threads")
{
REQUIRE_NT_SUCCESS(pool.start(fn, &value));

pool.join();

THEN("ThreadPool started up to kMaxCount threads")
{
REQUIRE(value == 64);
Copy link

Copilot AI Aug 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The magic number 64 should be replaced with a reference to kMaxCount to make the test more maintainable. If kMaxCount changes, this test would fail unexpectedly.

Suggested change
REQUIRE(value == 64);
REQUIRE(value == kf::ThreadPool::kMaxCount);

Copilot uses AI. Check for mistakes.
}
}
}

GIVEN("A ThreadPool with 4 threads")
{
LONG value = 0;

WHEN("ThreadPool with started threads goes out of scope")
{
{
kf::ThreadPool pool(4);
REQUIRE_NT_SUCCESS(pool.start(fn, &value));
}

THEN("All threads are complete successfully")
{
REQUIRE(value == 4);
}
}
}

GIVEN("ThreadPool with 4 started threads")
{
LONG value = 0;
kf::ThreadPool pool1(4);

REQUIRE_NT_SUCCESS(pool1.start(fn, &value));

WHEN("ThreadPool moved into another pool")
{
kf::ThreadPool pool2(std::move(pool1));
pool2.join();

THEN("Threads still run and complete successfully")
{
REQUIRE(value == 4);
}
}
}
}
36 changes: 36 additions & 0 deletions test/pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,42 @@ extern "C" inline int _CrtDbgReport(
KeBugCheckEx(KERNEL_SECURITY_CHECK_FAILURE, 0, 0, 0, 0);
}

inline void __ehvec_dtor(
void* ptr,
unsigned __int64 size,
unsigned __int64 count,
void(__cdecl* dtor)(void*)
)
{
UNREFERENCED_PARAMETER(ptr);
UNREFERENCED_PARAMETER(size);
UNREFERENCED_PARAMETER(count);
UNREFERENCED_PARAMETER(dtor);
}

inline void __cdecl __ehvec_copy_ctor(
void* dst,
void* src,
unsigned __int64 size,
unsigned __int64 count,
void(__cdecl* copy_ctor)(void*, void*),
void(__cdecl* dtor)(void*)
)
{
UNREFERENCED_PARAMETER(dtor);

auto d = static_cast<unsigned char*>(dst);
auto s = static_cast<unsigned char*>(src);

for (unsigned __int64 i = 0; i < count; ++i)
{
copy_ctor(d, s);
d += size;
s += size;
}
}


namespace std
{
[[noreturn]] inline void __cdecl _Xinvalid_argument(_In_z_ const char* /*What*/)
Expand Down