Skip to content

Commit fd4dee6

Browse files
author
niushengxiao
committed
fix cache_task.h
1 parent 85db7e5 commit fd4dee6

File tree

5 files changed

+110
-94
lines changed

5 files changed

+110
-94
lines changed

src/core/cache_task.h

Lines changed: 46 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -15,106 +15,99 @@
1515

1616
namespace cache::task {
1717

18-
enum State { Initial = 0, Working = 1, Finished = 2, Aborted = 3 };
18+
enum State { Initial, Working, Finished, Aborted };
1919

20-
enum Mode { Write = 1, Read = 2 };
20+
enum Mode { Write, Read };
2121

2222
class CacheTask;
2323

24+
/**
25+
* @brief Represents a single block of data within a cache task
26+
*
27+
* Each CacheBlock corresponds to a fixed-size chunk of KV cache data that can be
28+
* independently read from or written to storage. Blocks are processed asynchronously
29+
* by worker threads and track their own state throughout the operation lifecycle.
30+
*/
2431
class CacheBlock {
2532
public:
26-
/// @brief Constructor, calculates the SHA-256 hash of the data.
27-
/// @param hash Data hash.
28-
/// @param m Mode: Read or Write
29-
/// @param task Corresponding CacheTask
3033
CacheBlock(std::string hash_, const int64_t block_idx_, CacheTask *task_)
3134
: hash(std::move(hash_)), task(task_), block_idx(block_idx_) {}
3235

33-
/// @brief 禁止拷贝和移动
34-
CacheBlock(CacheBlock &&other) = delete;
35-
CacheBlock &operator=(CacheBlock &&other) = delete;
36-
3736
bool ready() const { return state == State::Finished; }
3837

3938
int64_t block_idx;
40-
CacheTask *task; ///< Corresponding Task
41-
std::string hash; ///< Hash of the block.
42-
State state{}; ///< Read/write state of the block.
39+
CacheTask *task;
40+
std::string hash;
41+
State state{};
4342
};
4443

45-
/// @brief Task cache class, storing only cache blocks.
44+
/**
45+
* @brief Manages a collection of cache blocks for a single read or write operation
46+
*
47+
* CacheTask represents a complete cache operation request from Python code, containing
48+
* multiple blocks that are processed in parallel by worker threads. It tracks completion
49+
* status, provides thread-safe access to shared state, and manages the lifecycle of all
50+
* associated blocks through RAII-managed unique_ptr ownership.
51+
*
52+
* Thread Safety: Atomic members (num_finished_blocks, num_data_ready_blocks, completion_notified)
53+
* are lock-free for high-frequency access. The state_mutex guards block state transitions and access to page_already_list.
54+
*/
4655
class CacheTask {
4756
public:
48-
CacheTask() = delete;
49-
50-
/// @brief Constructor, determines mode from user input ('r' or 'w').
51-
/// @param hashs Hash sequence.
52-
/// @param mode_str Mode string: "r" for Read, "w" for Write
5357
CacheTask(const std::vector<std::string> &hashs, torch::Tensor kv_page_indexer, const std::string &mode_str)
54-
: num_finished_blocks(0), num_data_ready_blocks(0), page_indexer(std::move(kv_page_indexer)),
55-
completion_notified(false) {
58+
: num_finished_blocks(0)
59+
, num_data_ready_blocks(0)
60+
, page_indexer(std::move(kv_page_indexer))
61+
, completion_notified(false) {
5662

5763
if (mode_str == "r") {
58-
mode = Mode::Read;
64+
operation_mode = Mode::Read;
5965
} else if (mode_str == "w") {
60-
mode = Mode::Write;
66+
operation_mode = Mode::Write;
6167
} else {
6268
throw std::invalid_argument("Invalid mode string. Use 'r' for Read or 'w' for Write.");
6369
}
6470

6571
blocks.reserve(hashs.size());
66-
for (int64_t idx = 0; idx < hashs.size(); ++idx) {
67-
blocks.emplace_back(new CacheBlock(hashs[idx], idx, this));
72+
int64_t idx = 0;
73+
for (const auto& hash : hashs) {
74+
blocks.emplace_back(std::make_unique<CacheBlock>(hash, idx++, this));
6875
}
6976
}
7077

71-
~CacheTask() {
72-
for (auto block : blocks) {
73-
delete block;
74-
}
75-
}
76-
77-
/// @brief 禁止拷贝和移动
78-
CacheTask(CacheTask &&other) = delete;
79-
CacheTask &operator=(CacheTask &&other) = delete;
80-
8178
bool ready() const { return num_finished_blocks.load(std::memory_order_acquire) == blocks.size(); }
8279

83-
/// @brief Check if data is safe to release pages (for write mode)
84-
/// For write mode: returns true when data has been copied from KV cache
85-
/// For read mode: equivalent to ready()
8680
bool data_safe() const {
87-
if (mode == Mode::Write) {
81+
if (operation_mode == Mode::Write) {
8882
return num_data_ready_blocks.load(std::memory_order_acquire) >= static_cast<int64_t>(blocks.size());
8983
}
90-
return ready(); // For read mode, data_safe is same as ready
84+
return ready();
9185
}
9286

9387
bool mark_completion_notified() { return !completion_notified.exchange(true, std::memory_order_acq_rel); }
9488

95-
std::vector<State> state() {
96-
auto ret = std::vector<State>(blocks.size());
97-
for (int32_t i = 0; i < blocks.size(); ++i) {
98-
ret[i] = blocks[i]->state;
89+
std::vector<State> state() const {
90+
std::vector<State> ret;
91+
ret.reserve(blocks.size());
92+
for (const auto& block : blocks) {
93+
ret.push_back(block->state);
9994
}
10095
return ret;
10196
}
10297

10398
std::vector<int32_t> get_page_already_list() const {
104-
std::lock_guard<std::mutex> lock_guard(const_cast<std::mutex &>(lock));
99+
std::lock_guard<std::mutex> lock_guard(state_mutex);
105100
return page_already_list;
106101
}
107102

108-
// 这个指针是用来标记 task 中的数据存取位置的
109103
torch::Tensor page_indexer;
110-
111-
std::mutex lock; ///< Task state lock
112-
std::vector<CacheBlock *> blocks; ///< Blocks stored as shared_ptr
113-
std::atomic<int64_t> num_finished_blocks; ///< Number of finished blocks (atomic for thread-safe reading)
114-
std::atomic<int64_t> num_data_ready_blocks; ///< Number of blocks with data copied (for write mode)
115-
Mode mode; ///< Read/write mode of the task.
104+
mutable std::mutex state_mutex;
105+
std::vector<std::unique_ptr<CacheBlock>> blocks;
106+
std::atomic<int64_t> num_finished_blocks;
107+
std::atomic<int64_t> num_data_ready_blocks;
108+
Mode operation_mode;
116109
std::atomic<bool> completion_notified;
117-
std::vector<int32_t> page_already_list; ///< List of page indices already persisted to disk (for write mode)
110+
std::vector<int32_t> page_already_list;
118111
};
119112

120113
} // namespace cache::task

src/core/task_queue.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ class TaskQueue {
2525
TaskQueue &operator=(TaskQueue &&other) = delete;
2626

2727
cache::error::LMError_t submit(const std::shared_ptr<cache::task::CacheTask> &task) {
28-
for (cache::task::CacheBlock *block : task->blocks) {
29-
blocks_.push(block);
28+
for (const auto &block : task->blocks) {
29+
blocks_.push(block.get());
3030
}
3131
return cache::error::LM_SUCCESS;
3232
}

src/service/cache_service.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,16 +269,16 @@ class CacheService {
269269
* read/write operations on the kvcache.
270270
*/
271271
void abort_task(const std::shared_ptr<cache::task::CacheTask> &task) {
272-
for (cache::task::CacheBlock *block : task->blocks) {
273-
abort(block);
272+
for (const auto &block : task->blocks) {
273+
abort(block.get());
274274
}
275275
}
276276

277277
/**
278278
* Notify the system to immediately abandon the subsequent execution of a Block.
279279
*/
280280
void abort(cache::task::CacheBlock *block) {
281-
std::lock_guard<std::mutex> lock(block->task->lock);
281+
std::lock_guard<std::mutex> lock(block->task->state_mutex);
282282
if (block->state == cache::task::State::Initial || block->state == cache::task::State::Working) {
283283
block->state = cache::task::State::Aborted;
284284
block->task->num_finished_blocks.fetch_add(1, std::memory_order_release);
@@ -291,7 +291,7 @@ class CacheService {
291291
void deliver(cache::task::CacheBlock *block) {
292292
auto task = block->task;
293293
{
294-
std::lock_guard<std::mutex> lock(task->lock);
294+
std::lock_guard<std::mutex> lock(task->state_mutex);
295295

296296
if (block->state == cache::task::State::Working) {
297297
block->state = cache::task::State::Finished;
@@ -343,15 +343,23 @@ inline void CacheService::finalize_task(cache::task::CacheTask *task) {
343343
std::lock_guard<std::mutex> lock(lock_);
344344
for (auto it = taskpool_.begin(); it != taskpool_.end(); ++it) {
345345
if (it->get() == task) {
346+
// IMPORTANT: After this erase, the shared_ptr may be destroyed if Python side has released it.
347+
// This is safe because:
348+
// 1. All blocks have finished (checked by task->ready())
349+
// 2. No worker threads should be accessing this task anymore
350+
// 3. The caller is expected to keep 'task' alive for the rest of this call: it is still read below (task->operation_mode) and passed to on_task_finalized
346351
taskpool_.erase(it);
347352
break;
348353
}
349354
}
350355
}
351356

352357
std::atomic<int64_t> *active_counter =
353-
(task->mode == cache::task::Mode::Read) ? &active_read_creates_ : &active_write_creates_;
358+
(task->operation_mode == cache::task::Mode::Read) ? &active_read_creates_ : &active_write_creates_;
354359
active_counter->fetch_sub(1, std::memory_order_relaxed);
360+
361+
// Note: 'task' pointer may become invalid after this point if no other references exist
362+
// on_task_finalized overrides must not dereference 'task' unless a live reference is guaranteed (note: LocalCacheService::on_task_finalized reads task->operation_mode)
355363
on_task_finalized(task);
356364
}
357365

src/service/local_cache_service.h

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,18 @@ class LocalCacheService : public CacheService {
6262

6363
storage_ = make_unique<LocalStorageEngine>(file, storage_size, num_shard, block_size_);
6464

65-
r_cpu_buffers_.resize(num_workers_);
66-
w_cpu_buffers_.resize(num_workers_);
67-
68-
for (size_t i = 0; i < num_workers_; ++i) {
69-
r_cpu_buffers_[i] = new char[block_size_];
70-
w_cpu_buffers_[i] = new char[block_size_];
65+
// Use unique_ptr for exception safety - if any allocation fails, previous allocations are automatically cleaned up
66+
r_cpu_buffers_.reserve(num_workers_);
67+
w_cpu_buffers_.reserve(num_workers_);
68+
69+
try {
70+
for (size_t i = 0; i < num_workers_; ++i) {
71+
r_cpu_buffers_.emplace_back(new char[block_size_]);
72+
w_cpu_buffers_.emplace_back(new char[block_size_]);
73+
}
74+
} catch (...) {
75+
// unique_ptr will automatically clean up already allocated buffers
76+
throw;
7177
}
7278
}
7379

@@ -82,13 +88,7 @@ class LocalCacheService : public CacheService {
8288
}
8389
}
8490

85-
for (auto &buffer : r_cpu_buffers_) {
86-
delete[] buffer;
87-
}
88-
for (auto &buffer : w_cpu_buffers_) {
89-
delete[] buffer;
90-
}
91-
91+
// unique_ptr will automatically delete the buffers
9292
r_cpu_buffers_.clear();
9393
w_cpu_buffers_.clear();
9494
}
@@ -120,7 +120,7 @@ class LocalCacheService : public CacheService {
120120

121121
protected:
122122
void on_task_finalized(cache::task::CacheTask *task) override {
123-
if (task->mode == cache::task::Mode::Write) {
123+
if (task->operation_mode == cache::task::Mode::Write) {
124124
// Try to acquire the lock, skip logging if contention occurs
125125
std::unique_lock<std::mutex> guard(log_mutex_, std::try_to_lock);
126126
if (!guard.owns_lock()) {
@@ -174,7 +174,7 @@ class LocalCacheService : public CacheService {
174174
return;
175175
}
176176

177-
if (task->mode != cache::task::Mode::Read) {
177+
if (task->operation_mode != cache::task::Mode::Read) {
178178
return;
179179
}
180180
if (active_read_creates_.load(std::memory_order_relaxed) != 0) {
@@ -239,7 +239,7 @@ class LocalCacheService : public CacheService {
239239
if (auto block = this->queue_->claim()) {
240240
if (block != nullptr) {
241241
CacheTask *task = block->task;
242-
char *cpu_buffer = (task->mode == Mode::Read) ? r_cpu_buffers_[index] : w_cpu_buffers_[index];
242+
char *cpu_buffer = (task->operation_mode == Mode::Read) ? r_cpu_buffers_[index].get() : w_cpu_buffers_[index].get();
243243
processTask(block, cpu_buffer);
244244
}
245245
}
@@ -266,7 +266,7 @@ class LocalCacheService : public CacheService {
266266
}
267267

268268
bool success = false;
269-
if (task->mode == Mode::Read) {
269+
if (task->operation_mode == Mode::Read) {
270270
success = handleReadCpu(block, cpu_buffer, page_ptr, num_of_page);
271271
} else {
272272
success = handleWriteCpu(block, cpu_buffer, page_ptr, num_of_page);
@@ -291,7 +291,7 @@ class LocalCacheService : public CacheService {
291291
}
292292

293293
{
294-
std::lock_guard<std::mutex> lock(block->task->lock);
294+
std::lock_guard<std::mutex> lock(block->task->state_mutex);
295295
if (block->state != cache::task::State::Working) {
296296
return false;
297297
}
@@ -311,7 +311,7 @@ class LocalCacheService : public CacheService {
311311

312312
bool handleWriteCpu(CacheBlock *block, char *cpu_buffer, int32_t *page_ptr, int64_t num_of_page) {
313313
{
314-
std::lock_guard<std::mutex> lock(block->task->lock);
314+
std::lock_guard<std::mutex> lock(block->task->state_mutex);
315315
if (block->state != cache::task::State::Working) {
316316
return false;
317317
}
@@ -395,8 +395,8 @@ class LocalCacheService : public CacheService {
395395
vector<thread> workers_; ///< Worker threads
396396
bool stop_; ///< Thread stop flag
397397
size_t num_workers_; ///< Number of worker threads
398-
vector<char *> r_cpu_buffers_; ///< CPU buffers for read worker
399-
vector<char *> w_cpu_buffers_; ///< CPU buffers for write worker
398+
vector<unique_ptr<char[]>> r_cpu_buffers_; ///< CPU buffers for read worker (RAII managed)
399+
vector<unique_ptr<char[]>> w_cpu_buffers_; ///< CPU buffers for write worker (RAII managed)
400400
std::atomic<uint64_t> total_written_bytes_; ///< Total bytes written to disk
401401
std::atomic<int64_t> first_write_time_ticks_; ///< First write start time in steady clock ticks
402402
std::atomic<int64_t> last_write_time_ticks_; ///< Last write completion time in steady clock ticks

src/storage/local_storage_engine.h

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -164,23 +164,21 @@ class LocalStorageEngine : public StorageEngine {
164164
files_.resize(shard_);
165165
file_fds_.resize(shard_, -1);
166166

167-
for (size_t i = 0; i < shard_; i++) {
168-
caches_[i] = std::make_shared<LocalCacheIndex>(shard_capacity);
169-
locks_[i] = std::make_shared<std::mutex>();
167+
try {
168+
for (size_t i = 0; i < shard_; i++) {
169+
caches_[i] = std::make_shared<LocalCacheIndex>(shard_capacity);
170+
locks_[i] = std::make_shared<std::mutex>();
171+
}
172+
createOrOpenFiles(shard_storage_size);
173+
} catch (...) {
174+
// Clean up any partially opened files on exception
175+
cleanup();
176+
throw;
170177
}
171-
createOrOpenFiles(shard_storage_size);
172178
}
173179

174180
~LocalStorageEngine() override {
175-
for (size_t i = 0; i < shard_; i++) {
176-
if (files_[i].is_open()) {
177-
files_[i].close();
178-
}
179-
if (file_fds_[i] >= 0) {
180-
close(file_fds_[i]);
181-
file_fds_[i] = -1;
182-
}
183-
}
181+
cleanup();
184182
}
185183

186184
bool query(const std::string &hash) override {
@@ -251,6 +249,23 @@ class LocalStorageEngine : public StorageEngine {
251249
private:
252250
inline size_t getShard(const std::string &hash) { return std::hash<std::string>{}(hash) % shard_; }
253251

252+
// Helper function to clean up file resources
253+
void cleanup() {
254+
for (size_t i = 0; i < shard_; i++) {
255+
if (files_[i].is_open()) {
256+
try {
257+
files_[i].close();
258+
} catch (...) {
259+
// Ignore exceptions during cleanup
260+
}
261+
}
262+
if (file_fds_[i] >= 0) {
263+
close(file_fds_[i]);
264+
file_fds_[i] = -1;
265+
}
266+
}
267+
}
268+
254269
void createOrOpenFiles(size_t shard_storage_size) {
255270
for (size_t i = 0; i < shard_; i++) {
256271
std::stringstream ss;

0 commit comments

Comments
 (0)