Commit f77b986

SunMarc authored and ArthurZucker committed
fix multi-gpu with static cache (#32543)
1 parent 7e85948 commit f77b986

File tree

1 file changed (+2, -2 lines)


src/transformers/cache_utils.py

Lines changed: 2 additions & 2 deletions
@@ -1067,6 +1067,8 @@ def update(
             A tuple containing the updated key and value states.
         """
         cache_position = cache_kwargs.get("cache_position")
+        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
+        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
 
@@ -1078,8 +1080,6 @@ def update(
         # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
         # operation, that avoids copies and uses less memory.
         try:
-            # If using several devices (e.g.: multiple GPUs), we need to ensure everything is on the same one
-            cache_position.to(device=k_out.device)
             k_out.index_copy_(2, cache_position, key_states)
             v_out.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
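
Why this fixes multi-GPU: `torch.Tensor.to` is out-of-place, so the removed line `cache_position.to(device=k_out.device)` discarded its result and never actually moved anything. With a model sharded across devices, the pre-allocated static cache for a layer could therefore sit on a different GPU than that layer's incoming key/value states, and the in-place `index_copy_` would fail with a device-mismatch error. The commit instead moves the cache tensors themselves onto the device of the incoming states before writing. Below is a minimal sketch of the failure mode and the fix, not the library code; it assumes two visible CUDA devices, and the tensor shapes are illustrative.

import torch

# Pre-allocated static cache on GPU 0; a sharded layer produces states on GPU 1.
key_cache = torch.zeros(1, 8, 16, 64, device="cuda:0")
key_states = torch.randn(1, 8, 1, 64, device="cuda:1")
cache_position = torch.tensor([5], device="cuda:1")

# Old approach: Tensor.to returns a new tensor, so this line is a no-op; the
# result is discarded and nothing changes device. A subsequent
# key_cache.index_copy_(2, cache_position, key_states) would then raise a
# device-mismatch RuntimeError.
cache_position.to(device=key_cache.device)

# The fix: move the cache tensor onto the device of the incoming states, so
# the in-place write sees matching devices.
key_cache = key_cache.to(device=key_states.device)
key_cache.index_copy_(2, cache_position, key_states)

Since `.to()` returns the tensor unchanged when it is already on the target device, the added lines cost nothing in the common single-GPU case.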
