Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions csrc/radix_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,45 @@
#include <memory>
#include <type_traits>
#include <algorithm>
#include <optional>

#include "cache_utils.h"
#include "radix_tree.h"

namespace flexkv {

// Helper function matching Python's get_hash with boundary check and _has_hashes branch
// Returns std::nullopt if block_id is out of bounds (like Python returning None)
// If has_hashes is true, reads from block_hashes_ptr; otherwise computes from token_ids
static std::optional<HashType> get_hash_safe(
int64_t *block_hashes_ptr,
int64_t *token_ids_ptr, // Can be nullptr if has_hashes is true
int block_id,
int num_blocks,
bool has_hashes,
int tokens_per_block) {
if (block_id >= num_blocks) {
return std::nullopt; // Out of bounds, return None (similar to Python)
}

if (has_hashes) {
// Read from pre-computed block_hashes (matching Python: if self._has_hashes)
return HashType(block_hashes_ptr[block_id]);
} else {
// Compute hash from token_ids (matching Python: hash_array(self.token_ids[...]))
if (token_ids_ptr == nullptr) {
// Cannot compute without token_ids, return nullopt
return std::nullopt;
}
// Compute hash for tokens up to (block_id+1)*tokens_per_block
// Matching Python: hash_array(self.token_ids[:(block_id+1)*self.tokens_per_block])
Hasher hasher;
hasher.reset(); // Reset hasher (matching Python's _HASHER.reset())
hasher.update(token_ids_ptr, (block_id + 1) * tokens_per_block * sizeof(int64_t));
return hasher.digest();
}
}

CRadixNode::CRadixNode(CRadixTreeIndex *index, bool ready, int lock_cnt) {
assert(index != nullptr);

Expand Down Expand Up @@ -202,6 +235,7 @@ int CRadixTreeIndex::evict(torch::Tensor &evicted_blocks, int num_evicted) {
return has_evicted;
}


std::shared_ptr<CMatchResult> CRadixTreeIndex::match_prefix(
torch::Tensor &block_hashes, int num_blocks, bool update_cache_info) {
auto current_node = root;
Expand All @@ -212,14 +246,27 @@ std::shared_ptr<CMatchResult> CRadixTreeIndex::match_prefix(
auto physical_blocks = new std::vector<int64_t>();
auto block_hashes_ptr = block_hashes.data_ptr<int64_t>();
HashType child_hash;

// In C++ version, block_hashes is always pre-computed (has_hashes = true)
// token_ids_ptr is nullptr since we don't have token_ids in this function signature
bool has_hashes = true;
int64_t *token_ids_ptr = nullptr;

while (prefix_blocks_num < num_blocks) {
if (update_cache_info) {
current_node->update_time();
}

child_hash = HashType(block_hashes_ptr[prefix_blocks_num + current_node->size()]);
if (current_node->lookup_child(child_hash)) {
// Use get_hash_safe (matching Python's get_hash with boundary check and _has_hashes branch)
auto child_hash_opt = get_hash_safe(
block_hashes_ptr,
token_ids_ptr,
prefix_blocks_num + current_node->size(),
num_blocks,
has_hashes,
tokens_per_block);
if (child_hash_opt.has_value() && current_node->lookup_child(child_hash_opt.value())) {
child_hash = child_hash_opt.value();
if (current_node->is_ready()) {
last_ready_node = current_node;
ready_prefix_blocks_num += current_node->size();
Expand All @@ -237,7 +284,15 @@ std::shared_ptr<CMatchResult> CRadixTreeIndex::match_prefix(

while (left < right) {
auto mid = (left + right) / 2;
if (current_node->get_hash(mid) == HashType(block_hashes_ptr[prefix_blocks_num+mid])) {
// Use get_hash_safe for boundary check (matching Python's get_hash with _has_hashes branch)
auto hash_opt = get_hash_safe(
block_hashes_ptr,
token_ids_ptr,
prefix_blocks_num + mid,
num_blocks,
has_hashes,
tokens_per_block);
if (hash_opt.has_value() && current_node->get_hash(mid) == hash_opt.value()) {
left = mid + 1;
} else {
right = mid;
Expand Down