File tree: 2 files changed, +8 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -37,11 +37,10 @@ def swap_blocks(
37
37
) -> None :
38
38
src_k_cache , src_v_cache = src_kv_cache
39
39
dst_k_cache , dst_v_cache = dst_kv_cache
40
+ src_indices , dst_indices = src_to_dst
41
+ device = dst_k_cache .device
40
42
torch .ops .xla .dynamo_set_buffer_donor_ (dst_k_cache , True )
41
43
torch .ops .xla .dynamo_set_buffer_donor_ (dst_v_cache , True )
42
-
43
- device = dst_k_cache .device
44
- src_indices , dst_indices = src_to_dst
45
44
dst_k_cache [:, dst_indices ] = src_k_cache [:, src_indices ].to (device )
46
45
dst_v_cache [:, dst_indices ] = src_v_cache [:, src_indices ].to (device )
47
46
Original file line number | Diff line number | Diff line change
@@ -156,14 +156,18 @@ def initialize_cache(
156
156
self .tpu_cache = []
157
157
tpu_cache_shape = self .model_runner .attn_backend .get_kv_cache_shape (
158
158
num_gpu_blocks , self .block_size , num_kv_heads , head_size )
159
+ cpu_cache_shape = self .model_runner .attn_backend .get_kv_cache_shape (
160
+ num_cpu_blocks , self .block_size , num_kv_heads , head_size )
159
161
for _ in range (num_layers ):
160
162
tpu_k_cache = torch .zeros (tpu_cache_shape ,
161
163
dtype = dtype ,
162
164
device = self .device )
163
165
tpu_v_cache = torch .zeros_like (tpu_k_cache )
164
166
self .tpu_cache .append ((tpu_k_cache , tpu_v_cache ))
165
- cpu_k_cache = torch .zeros_like (tpu_k_cache , device = "cpu" )
166
- cpu_v_cache = torch .zeros_like (tpu_v_cache , device = "cpu" )
167
+ cpu_k_cache = torch .zeros (cpu_cache_shape ,
168
+ dtype = dtype ,
169
+ device = "cpu" )
170
+ cpu_v_cache = torch .zeros_like (cpu_k_cache )
167
171
self .cpu_cache .append ((cpu_k_cache , cpu_v_cache ))
168
172
self ._warmup_model ()
169
173
You can’t perform that action at this time.
0 commit comments