forked from mit-han-lab/streaming-llm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkv_cache.py
119 lines (106 loc) · 3.41 KB
/
kv_cache.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
def slice2d(x, start, end):
    """Return x sliced to [start, end) along dimension 2; other dims kept whole."""
    index = (slice(None), slice(None), slice(start, end), Ellipsis)
    return x[index]
def slice3d(x, start, end):
    """Return x sliced to [start, end) along dimension 3; other dims kept whole."""
    index = (slice(None), slice(None), slice(None), slice(start, end), Ellipsis)
    return x[index]
def slice1d(x, start, end):
    """Return x sliced to [start, end) along dimension 1; other dims kept whole."""
    index = (slice(None), slice(start, end), Ellipsis)
    return x[index]
# Maps the index of a tensor's sequence dimension to the helper that slices
# along that dimension, so callers can handle KV layouts whose sequence axis
# sits at dim 1, 2, or 3.
DIM_TO_SLICE = {
    1: slice1d,
    2: slice2d,
    3: slice3d,
}
class StartRecentKVCache:
    """KV-cache eviction policy that pins the first `start_size` positions
    and the `recent_size` most recent positions, discarding the middle.

    All methods accept `past_key_values` as an iterable of (key, value)
    tensor pairs and return a list of [key, value] lists (or the input
    unchanged when nothing needs evicting, or None for None input).
    """

    def __init__(
        self,
        start_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        print(f"StartRecentKVCache: {start_size}, {recent_size}")
        self.start_size = start_size
        self.recent_size = recent_size
        # Total retention budget: pinned prefix + recent window.
        self.cache_size = start_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        # Bind the slicing helper matching where each tensor's sequence dim lives.
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]

    def _keep_spans(self, k, v, prefix_end, tail_start, seq_len):
        """Keep [0, prefix_end) and [tail_start, seq_len) of both tensors,
        concatenated along their respective sequence dimensions."""
        new_k = torch.cat(
            [
                self.k_slice(k, 0, prefix_end),
                self.k_slice(k, tail_start, seq_len),
            ],
            dim=self.k_seq_dim,
        )
        new_v = torch.cat(
            [
                self.v_slice(v, 0, prefix_end),
                self.v_slice(v, tail_start, seq_len),
            ],
            dim=self.v_seq_dim,
        )
        return [new_k, new_v]

    def __call__(self, past_key_values):
        """Trim the cache to the start + recent budget once it is exceeded."""
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        # Within budget: return the input untouched.
        if seq_len <= self.cache_size:
            return past_key_values
        tail_start = seq_len - self.recent_size
        kept = []
        for k, v in past_key_values:
            kept.append(self._keep_spans(k, v, self.start_size, tail_start, seq_len))
        return kept

    def evict_for_space(self, past_key_values, num_coming):
        """Evict enough middle entries that `num_coming` new tokens will fit."""
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        # Incoming tokens still fit: nothing to evict.
        if seq_len + num_coming <= self.cache_size:
            return past_key_values
        # Shrink the recent window so the post-append length stays in budget.
        tail_start = seq_len - self.recent_size + num_coming
        kept = []
        for k, v in past_key_values:
            kept.append(self._keep_spans(k, v, self.start_size, tail_start, seq_len))
        return kept

    def evict_range(self, past_key_values, start, end):
        """Drop cache positions in [start, end), keeping everything outside it."""
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        assert start <= end and end <= seq_len
        kept = []
        for k, v in past_key_values:
            kept.append(self._keep_spans(k, v, start, end, seq_len))
        return kept