|
7 | 7 | import torch |
8 | 8 |
|
9 | 9 | from .configuration_utils import PretrainedConfig |
10 | | -from .utils import is_hqq_available, is_quanto_available, logging |
| 10 | +from .utils import is_hqq_available, is_quanto_available, logging, is_torch_xla_available |
11 | 11 |
|
12 | 12 | if is_quanto_available(): |
13 | 13 | from quanto import QBitsTensor, qint2, qint4 |
@@ -791,6 +791,23 @@ def update( |
791 | 791 | k_out = self.key_cache[layer_idx] |
792 | 792 | v_out = self.value_cache[layer_idx] |
793 | 793 |
|
| 794 | + if is_torch_xla_available(): # If torch_xla is available, do out-of-place operation on KV_Cache and create a new list |
| 795 | + k_out = k_out.index_copy(2, cache_position, key_states) |
| 796 | + v_out = v_out.index_copy(2, cache_position, value_states) |
| 797 | + |
| 798 | + updated_key_cache = [ |
| 799 | + k_out if i == layer_idx else self.key_cache[i] for i in range(len(self.key_cache)) |
| 800 | + ] |
| 801 | + |
| 802 | + updated_value_cache = [ |
| 803 | + v_out if i == layer_idx else self.value_cache[i] for i in range(len(self.value_cache)) |
| 804 | + ] |
| 805 | + |
| 806 | + self.key_cache = updated_key_cache |
| 807 | + self.value_cache = updated_value_cache |
| 808 | + |
| 809 | + return k_out, v_out |
| 810 | + |
794 | 811 | k_out.index_copy_(2, cache_position, key_states) |
795 | 812 | v_out.index_copy_(2, cache_position, value_states) |
796 | 813 |
|
|
0 commit comments