diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 3749592a0ec71..1d8d41e013b03 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -49,13 +49,12 @@ def test_copy_blocks( src_blocks = random.sample(range(num_blocks), num_mappings) remainig_blocks = list(set(range(num_blocks)) - set(src_blocks)) dst_blocks = random.sample(remainig_blocks, 2 * num_mappings) - copy_src = [] - copy_dst = [] + block_mapping = {} for i in range(num_mappings): - copy_src.append(src_blocks[i]) - copy_dst.append(dst_blocks[2 * i]) - copy_src.append(src_blocks[i]) - copy_dst.append(dst_blocks[2 * i + 1]) + src = src_blocks[i] + dst1 = dst_blocks[2 * i] + dst2 = dst_blocks[2 * i + 1] + block_mapping[src] = [dst1, dst2] # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, @@ -67,14 +66,15 @@ def test_copy_blocks( cloned_value_caches = [value_cache.clone() for value_cache in value_caches] # Call the copy blocks kernel. - cache_ops.copy_blocks(key_caches, value_caches, copy_src, copy_dst) + cache_ops.copy_blocks(key_caches, value_caches, block_mapping) # Run the reference implementation. - for src, dst in zip(copy_src, copy_dst): - for cloned_key_cache in cloned_key_caches: - cloned_key_cache[dst].copy_(cloned_key_cache[src]) - for cloned_value_cache in cloned_value_caches: - cloned_value_cache[dst].copy_(cloned_value_cache[src]) + for src, dsts in block_mapping.items(): + for dst in dsts: + for cloned_key_cache in cloned_key_caches: + cloned_key_cache[dst].copy_(cloned_key_cache[src]) + for cloned_value_cache in cloned_value_caches: + cloned_value_cache[dst].copy_(cloned_value_cache[src]) # Compare the results. for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):