|
15 | 15 |
|
16 | 16 | namespace cache::task { |
17 | 17 |
|
18 | | -enum State { Initial = 0, Working = 1, Finished = 2, Aborted = 3 }; |
| 18 | +enum State { Initial, Working, Finished, Aborted }; |
19 | 19 |
|
20 | | -enum Mode { Write = 1, Read = 2 }; |
| 20 | +enum Mode { Write, Read }; |
21 | 21 |
|
22 | 22 | class CacheTask; |
23 | 23 |
|
| 24 | +/** |
| 25 | + * @brief Represents a single block of data within a cache task |
| 26 | + * |
| 27 | + * Each CacheBlock corresponds to a fixed-size chunk of KV cache data that can be |
| 28 | + * independently read from or written to storage. Blocks are processed asynchronously |
| 29 | + * by worker threads and track their own state throughout the operation lifecycle. |
| 30 | + */ |
24 | 31 | class CacheBlock { |
25 | 32 | public: |
26 | | - /// @brief Constructor, calculates the SHA-256 hash of the data. |
27 | | - /// @param hash Data hash. |
28 | | - /// @param m Mode: Read or Write |
29 | | - /// @param task Corresponding CacheTask |
30 | 33 | CacheBlock(std::string hash_, const int64_t block_idx_, CacheTask *task_) |
31 | 34 | : hash(std::move(hash_)), task(task_), block_idx(block_idx_) {} |
32 | 35 |
|
33 | | - /// @brief 禁止拷贝和移动 |
34 | | - CacheBlock(CacheBlock &&other) = delete; |
35 | | - CacheBlock &operator=(CacheBlock &&other) = delete; |
36 | | - |
37 | 36 | bool ready() const { return state == State::Finished; } |
38 | 37 |
|
39 | 38 | int64_t block_idx; |
40 | | - CacheTask *task; ///< Corresponding Task |
41 | | - std::string hash; ///< Hash of the block. |
42 | | - State state{}; ///< Read/write state of the block. |
| 39 | + CacheTask *task; |
| 40 | + std::string hash; |
| 41 | + State state{}; |
43 | 42 | }; |
44 | 43 |
|
45 | | -/// @brief Task cache class, storing only cache blocks. |
| 44 | +/** |
| 45 | + * @brief Manages a collection of cache blocks for a single read or write operation |
| 46 | + * |
| 47 | + * CacheTask represents a complete cache operation request from Python code, containing |
| 48 | + * multiple blocks that are processed in parallel by worker threads. It tracks completion |
| 49 | + * status, provides thread-safe access to shared state, and manages the lifecycle of all |
| 50 | + * associated blocks through RAII-managed unique_ptr ownership. |
| 51 | + * |
| 52 | + * Thread Safety: Atomic members (num_finished_blocks, num_data_ready_blocks, completion_notified) |
| 53 | + * are lock-free for high-frequency access. The state_mutex protects page_already_list updates. |
| 54 | + */ |
46 | 55 | class CacheTask { |
47 | 56 | public: |
48 | | - CacheTask() = delete; |
49 | | - |
50 | | - /// @brief Constructor, determines mode from user input ('r' or 'w'). |
51 | | - /// @param hashs Hash sequence. |
52 | | - /// @param mode_str Mode string: "r" for Read, "w" for Write |
53 | 57 | CacheTask(const std::vector<std::string> &hashs, torch::Tensor kv_page_indexer, const std::string &mode_str) |
54 | | - : num_finished_blocks(0), num_data_ready_blocks(0), page_indexer(std::move(kv_page_indexer)), |
55 | | - completion_notified(false) { |
| 58 | + : num_finished_blocks(0) |
| 59 | + , num_data_ready_blocks(0) |
| 60 | + , page_indexer(std::move(kv_page_indexer)) |
| 61 | + , completion_notified(false) { |
56 | 62 |
|
57 | 63 | if (mode_str == "r") { |
58 | | - mode = Mode::Read; |
| 64 | + operation_mode = Mode::Read; |
59 | 65 | } else if (mode_str == "w") { |
60 | | - mode = Mode::Write; |
| 66 | + operation_mode = Mode::Write; |
61 | 67 | } else { |
62 | 68 | throw std::invalid_argument("Invalid mode string. Use 'r' for Read or 'w' for Write."); |
63 | 69 | } |
64 | 70 |
|
65 | 71 | blocks.reserve(hashs.size()); |
66 | | - for (int64_t idx = 0; idx < hashs.size(); ++idx) { |
67 | | - blocks.emplace_back(new CacheBlock(hashs[idx], idx, this)); |
| 72 | + int64_t idx = 0; |
| 73 | + for (const auto& hash : hashs) { |
| 74 | + blocks.emplace_back(std::make_unique<CacheBlock>(hash, idx++, this)); |
68 | 75 | } |
69 | 76 | } |
70 | 77 |
|
71 | | - ~CacheTask() { |
72 | | - for (auto block : blocks) { |
73 | | - delete block; |
74 | | - } |
75 | | - } |
76 | | - |
77 | | - /// @brief 禁止拷贝和移动 |
78 | | - CacheTask(CacheTask &&other) = delete; |
79 | | - CacheTask &operator=(CacheTask &&other) = delete; |
80 | | - |
81 | 78 | bool ready() const { return num_finished_blocks.load(std::memory_order_acquire) == blocks.size(); } |
82 | 79 |
|
83 | | - /// @brief Check if data is safe to release pages (for write mode) |
84 | | - /// For write mode: returns true when data has been copied from KV cache |
85 | | - /// For read mode: equivalent to ready() |
86 | 80 | bool data_safe() const { |
87 | | - if (mode == Mode::Write) { |
| 81 | + if (operation_mode == Mode::Write) { |
88 | 82 | return num_data_ready_blocks.load(std::memory_order_acquire) >= static_cast<int64_t>(blocks.size()); |
89 | 83 | } |
90 | | - return ready(); // For read mode, data_safe is same as ready |
| 84 | + return ready(); |
91 | 85 | } |
92 | 86 |
|
93 | 87 | bool mark_completion_notified() { return !completion_notified.exchange(true, std::memory_order_acq_rel); } |
94 | 88 |
|
95 | | - std::vector<State> state() { |
96 | | - auto ret = std::vector<State>(blocks.size()); |
97 | | - for (int32_t i = 0; i < blocks.size(); ++i) { |
98 | | - ret[i] = blocks[i]->state; |
| 89 | + std::vector<State> state() const { |
| 90 | + std::vector<State> ret; |
| 91 | + ret.reserve(blocks.size()); |
| 92 | + for (const auto& block : blocks) { |
| 93 | + ret.push_back(block->state); |
99 | 94 | } |
100 | 95 | return ret; |
101 | 96 | } |
102 | 97 |
|
103 | 98 | std::vector<int32_t> get_page_already_list() const { |
104 | | - std::lock_guard<std::mutex> lock_guard(const_cast<std::mutex &>(lock)); |
| 99 | + std::lock_guard<std::mutex> lock_guard(state_mutex); |
105 | 100 | return page_already_list; |
106 | 101 | } |
107 | 102 |
|
108 | | - // 这个指针是用来标记 task 中的数据存取位置的 |
109 | 103 | torch::Tensor page_indexer; |
110 | | - |
111 | | - std::mutex lock; ///< Task state lock |
112 | | - std::vector<CacheBlock *> blocks; ///< Blocks stored as shared_ptr |
113 | | - std::atomic<int64_t> num_finished_blocks; ///< Number of finished blocks (atomic for thread-safe reading) |
114 | | - std::atomic<int64_t> num_data_ready_blocks; ///< Number of blocks with data copied (for write mode) |
115 | | - Mode mode; ///< Read/write mode of the task. |
| 104 | + mutable std::mutex state_mutex; |
| 105 | + std::vector<std::unique_ptr<CacheBlock>> blocks; |
| 106 | + std::atomic<int64_t> num_finished_blocks; |
| 107 | + std::atomic<int64_t> num_data_ready_blocks; |
| 108 | + Mode operation_mode; |
116 | 109 | std::atomic<bool> completion_notified; |
117 | | - std::vector<int32_t> page_already_list; ///< List of page indices already persisted to disk (for write mode) |
| 110 | + std::vector<int32_t> page_already_list; |
118 | 111 | }; |
119 | 112 |
|
120 | 113 | } // namespace cache::task |
0 commit comments