# llama_inference_offload.py
# Forked from qwopqwop200/GPTQ-for-LLaMa.
import time
import torch
import torch.nn as nn
from gptq import *
from modelutils import *
from quant import *
from transformers import AutoTokenizer
DEV = torch.device('cuda:0')
import copy
from transformers.models.llama.modeling_llama import LlamaModel,LlamaConfig
from transformers.modeling_outputs import BaseModelOutputWithPast
from typing import List, Optional, Tuple, Union
import time
class Offload_LlamaModel(LlamaModel):
    """LLaMA decoder stack that runs part of its layers from CPU memory.

    The first ``self.preload`` decoder layers are treated as resident: each is
    run on the device stored in its ``target_device`` attribute. Every later
    layer is moved to ``DEV`` right before it runs and moved back to the CPU
    right after, so only one such layer occupies GPU memory at a time.

    ``preload`` and the per-layer ``target_device`` attributes are assigned by
    ``load_quant`` after construction.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Number of leading layers considered resident on an accelerator.
        # Defaults to 0 (every layer is streamed through ``DEV``) so that
        # ``forward`` works even before ``load_quant`` overrides the value.
        self.preload = 0

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""Run the decoder stack, streaming non-resident layers through ``DEV``.

        Args:
            input_ids: Token indices of shape `(batch_size, sequence_length)`.
                Mutually exclusive with `inputs_embeds`.
            attention_mask: Optional mask of shape
                `(batch_size, sequence_length)`; when omitted, an all-ones mask
                covering past and current positions is created.
            position_ids: Optional position indices; when omitted they are
                generated as a range starting at the cached sequence length.
            past_key_values: Optional per-layer cached key/value tensors. The
                cached length is read from `past_key_values[0][0].shape[2]`.
            inputs_embeds: Optional pre-computed embeddings of shape
                `(batch_size, sequence_length, hidden_size)`, used instead of
                looking up `input_ids`.
            use_cache: Whether to collect and return per-layer key/value
                states. Defaults to `config.use_cache`.
            output_attentions: Whether to return the attention tensors of every
                layer. Defaults to `config.output_attentions`.
            output_hidden_states: Whether to return the hidden states entering
                every layer plus the final normalised states. Defaults to
                `config.output_hidden_states`.
            return_dict: Whether to return a `BaseModelOutputWithPast` instead
                of a plain tuple. Defaults to `config.use_return_dict`.

        Returns:
            A `BaseModelOutputWithPast`, or (when `return_dict` is false) a
            tuple holding the non-`None` members of
            `(last_hidden_state, past_key_values, hidden_states, attentions)`.

        Raises:
            ValueError: If both or neither of `input_ids` and `inputs_embeds`
                are provided.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Build the causal/padding mask over past + current positions.
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # Imported here because this module defines no ``logger`` of
                # its own; referencing a bare ``logger`` would raise NameError.
                from transformers.utils import logging as hf_logging

                hf_logging.get_logger(__name__).warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            resident = idx < self.preload
            if resident:
                target_device = decoder_layer.target_device
            else:
                # ``nn.Module.to`` moves the module in place, so the entry in
                # ``self.layers`` is the object that now lives on ``DEV``.
                decoder_layer.to(DEV)
                target_device = DEV

            # Inputs must sit on the same device as the layer about to run.
            hidden_states = hidden_states.to(target_device)
            attention_mask = attention_mask.to(target_device)
            position_ids = position_ids.to(target_device)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if not resident:
                # Send the streamed layer back to host memory and release the
                # cached GPU blocks before the next layer is brought in.
                decoder_layer.cpu()
                torch.cuda.empty_cache()

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # The final norm is placed on ``DEV`` by ``load_quant``.
        hidden_states = hidden_states.to(DEV)
        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
def load_quant(model, checkpoint, wbits, groupsize, pre_layer):
    """Build a quantized LLaMA model and distribute its layers across devices.

    Args:
        model: Name or path passed to ``LlamaConfig.from_pretrained``.
        checkpoint: Path of the quantized weights. Files ending in
            ``.safetensors`` are read with safetensors, anything else with
            ``torch.load``.
        wbits: Bit width forwarded to ``make_quant``.
        groupsize: Group size forwarded to ``make_quant``.
        pre_layer: Either an int or a sequence of ints. Entry ``k`` is the
            (exclusive) layer index up to which layers are placed on
            ``cuda:k``, starting from where the previous entry stopped. Layers
            past the last boundary are tagged for the CPU. Boundaries are
            clamped to the valid range ``[previous boundary, number of layers]``.

    Returns:
        The ``LlamaForCausalLM`` instance in eval mode, with ``seqlen`` set to
        2048, ``model.model.preload`` set to the number of GPU-resident layers
        and every decoder layer carrying a ``target_device`` attribute.
    """
    # Imported locally: this file otherwise only gets the ``transformers``
    # name through star-imports of sibling modules, which is not guaranteed.
    import transformers

    # Must be patched before ``LlamaForCausalLM`` is instantiated so that the
    # causal-LM wrapper builds the offloading model.
    transformers.models.llama.modeling_llama.LlamaModel = Offload_LlamaModel
    from transformers import LlamaConfig, LlamaForCausalLM

    config = LlamaConfig.from_pretrained(model)

    def noop(*args, **kwargs):
        pass

    # Weights are overwritten by the checkpoint, so random init is skipped.
    # NOTE: these replacements stay in effect for the rest of the process.
    torch.nn.init.kaiming_uniform_ = noop
    torch.nn.init.uniform_ = noop
    torch.nn.init.normal_ = noop
    transformers.modeling_utils._init_weights = False

    # Create parameters in fp16, and always restore the default dtype even
    # when model construction fails.
    torch.set_default_dtype(torch.half)
    try:
        model = LlamaForCausalLM(config)
    finally:
        torch.set_default_dtype(torch.float)
    model = model.eval()

    layers = find_layers(model)
    # The output head is kept unquantized.
    layers.pop('lm_head', None)
    make_quant(model, layers, wbits, groupsize)

    print('Loading model ...')
    if checkpoint.endswith('.safetensors'):
        from safetensors.torch import load_file as safe_load
        model.load_state_dict(safe_load(checkpoint))
    else:
        # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
        model.load_state_dict(torch.load(checkpoint))
    model.seqlen = 2048

    if isinstance(pre_layer, int):
        pre_layer = [pre_layer]

    decoder_layers = model.model.layers
    num_layers = len(decoder_layers)
    last_layer = 0
    for index, dev_layer in enumerate(pre_layer):
        target = torch.device(f'cuda:{index}')
        # Clamp so that a boundary larger than the model (or smaller than the
        # previous boundary) cannot index out of range or un-assign layers.
        dev_layer = max(last_layer, min(dev_layer, num_layers))
        for i in range(last_layer, dev_layer):
            decoder_layers[i].to(target)
            decoder_layers[i].target_device = target
        last_layer = dev_layer

    # Remaining layers stay in host memory and are streamed at inference time.
    for i in range(last_layer, num_layers):
        decoder_layers[i].target_device = torch.device('cpu')

    model.model.embed_tokens.to(DEV)
    model.model.norm.to(DEV)
    model.lm_head.to(DEV)
    model.model.preload = last_layer
    print('Done.')
    return model
# Command-line entry point: load a quantized LLaMA checkpoint with layer
# offloading and sample a continuation of the given prompt.
if __name__ == '__main__':
    import argparse
    from datautils import *

    parser = argparse.ArgumentParser()

    parser.add_argument(
        'model', type=str,
        help='llama model to load'
    )
    parser.add_argument(
        '--wbits', type=int, default=4, choices=[2, 3, 4, 8],
        help='#bits to use for quantization'
    )
    parser.add_argument(
        '--groupsize', type=int, default=-1,
        help='Groupsize to use for quantization; default uses full row.'
    )
    parser.add_argument(
        '--load', type=str, default='',
        help='Load quantized model.'
    )
    # The prompt is mandatory: without it the tokenizer would receive None
    # only after the whole model had already been loaded.
    parser.add_argument(
        '--text', type=str, required=True,
        help='input text'
    )
    parser.add_argument(
        '--min_length', type=int, default=10,
        help='The minimum length of the sequence to be generated.'
    )
    parser.add_argument(
        '--max_length', type=int, default=50,
        help='The maximum length of the sequence to be generated.'
    )
    parser.add_argument(
        '--top_p', type=float , default=0.95,
        help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.'
    )
    parser.add_argument(
        '--temperature', type=float, default=0.8,
        help='The value used to module the next token probabilities.'
    )
    parser.add_argument(
        '--pre_layer', type=int, default=50,
        help='The number of layers to preload'
    )

    args = parser.parse_args()

    # This script always loads a quantized checkpoint, so fail fast when no
    # path was given instead of failing inside the loader.
    if not args.load:
        parser.error('--load is required: path to the quantized checkpoint')

    model = load_quant(args.model, args.load, args.wbits, args.groupsize, args.pre_layer)
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV)

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            do_sample=True,
            min_length=args.min_length,
            max_length=args.max_length,
            top_p=args.top_p,
            temperature=args.temperature,
        )
    print(tokenizer.decode([el.item() for el in generated_ids[0]]))