-
Notifications
You must be signed in to change notification settings - Fork 3k
/
Copy pathmodeling_pp.py
469 lines (397 loc) Β· 18.3 KB
/
modeling_pp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (
LayerDesc,
PipelineLayer,
SharedLayerDesc,
)
from paddle.distributed.fleet.recompute.recompute import recompute
from paddlenlp.transformers.model_utils import PipelinePretrainedModel
from paddlenlp.transformers.refined_recompute import get_skip_recompute_ops
from paddlenlp.transformers.refined_recompute import recompute as rr_recompute
from paddlenlp.utils.tools import get_env_device
from ..dpo_criterion import DPOCriterion
from .modeling import (
LlamaConfig,
LlamaDecoderLayer,
LlamaLMHead,
LlamaModel,
LlamaPretrainedModel,
LlamaPretrainingCriterion,
LlamaRMSNorm,
build_alibi_tensor,
)
def __repr__(self):
return self.layer_func.__name__
# hack LayerDesc for showing to much config
LayerDesc.__repr__ = __repr__
__all__ = [
"LlamaForCausalLMPipe",
]
def parse_args(args):
if isinstance(args, tuple):
if len(args) == 5:
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = args
elif len(args) == 4:
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args
alibi = None
elif len(args) == 3:
hidden_states, attention_mask, attn_mask_startend_row_indices = args
position_ids = None
alibi = None
elif len(args) == 2:
hidden_states, attention_mask = args
attn_mask_startend_row_indices = None
position_ids = None
alibi = None
else:
hidden_states = args
attention_mask, attn_mask_startend_row_indices, position_ids, alibi = None, None, None, None
if position_ids is not None:
position_ids.stop_gradient = True
if attention_mask is not None:
attention_mask.stop_gradient = True
if attn_mask_startend_row_indices is not None:
attn_mask_startend_row_indices.stop_gradient = True
if alibi is not None:
alibi.stop_gradient = True
return hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi
def return_args(
hidden_states, attention_mask=None, attn_mask_startend_row_indices=None, position_ids=None, alibi=None
):
ret = (hidden_states,)
if attention_mask is not None:
ret += (attention_mask.clone(),)
if attn_mask_startend_row_indices is not None:
ret += (attn_mask_startend_row_indices.clone(),)
if position_ids is not None:
ret += (position_ids.clone(),)
if alibi is not None:
ret += (alibi.clone(),)
if len(ret) == 1:
ret = ret[0]
return ret
def get_attr(layer, name):
if getattr(layer, name, None) is not None:
return getattr(layer, name, None)
else:
return get_attr(layer._layer, name)
class LlamaEmbeddingPipe(nn.Layer):
"""Extends LlamaEmbeddings to forward attention_mask through the pipeline."""
def __init__(self, config):
super(LlamaEmbeddingPipe, self).__init__()
self.config = config
self.sequence_parallel = config.sequence_parallel
self.hidden_size = config.hidden_size
if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0:
self.embed_tokens = fleet.meta_parallel.VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()),
)
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
@property
def embedding_weight(self):
return get_attr(self.embed_tokens, "weight")
def forward(self, args):
"""_summary_
Args:
input (_type_): _description_
Returns:
_type_: _description_
"""
input_ids, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)
# we can't distinguish
if self.config.alibi and alibi is None and position_ids is None and attn_mask_startend_row_indices is not None:
# input_ids, attention_mask, alibi
alibi = attn_mask_startend_row_indices
position_ids = None
attn_mask_startend_row_indices = None
elif (
self.config.alibi
and alibi is None
and position_ids is not None
and attn_mask_startend_row_indices is not None
):
# input_ids, attention_mask, position_ids, alibi
alibi = position_ids
position_ids = attn_mask_startend_row_indices
attn_mask_startend_row_indices = None
elif not self.config.alibi:
if get_env_device() in ["gpu"]:
if attention_mask is not None and attention_mask.dtype == paddle.int32:
attention_mask, attn_mask_startend_row_indices, position_ids = (
None,
attention_mask,
attn_mask_startend_row_indices,
)
elif attention_mask is not None and attention_mask.dtype == paddle.int64:
attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
elif (
attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64
):
attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
elif position_ids is None and attn_mask_startend_row_indices is not None:
position_ids = attn_mask_startend_row_indices
attn_mask_startend_row_indices = None
input_embeds = self.embed_tokens(input_ids)
if self.sequence_parallel:
from paddlenlp.transformers import ScatterOp
# [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim]
bs, seq_len, hidden_size = input_embeds.shape
input_embeds = paddle.reshape_(input_embeds, [bs * seq_len, hidden_size])
# [seq_len * bs / n, num_head * head_dim] (n is mp parallelism)
input_embeds = ScatterOp.apply(input_embeds)
batch_size, seq_length = input_ids.shape
alibi = None
if self.config.alibi:
assert (
attn_mask_startend_row_indices is None
), "alibi and attn_mask_startend_row_indices can not be set at same time"
# embed positions
mask = (
attention_mask
if attention_mask is not None
else paddle.ones((batch_size, seq_length), dtype=paddle.bool)
)
alibi = build_alibi_tensor(mask, self.config.num_attention_heads, dtype=input_embeds.dtype)
if self.config.tensor_parallel_degree > 1:
block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree
alibi = alibi[
:,
self.config.tensor_parallel_rank
* block_size : (self.config.tensor_parallel_rank + 1)
* block_size,
]
alibi = alibi.reshape([batch_size * block_size, 1, seq_length])
else:
alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length])
alibi.stop_gradient = True
if attention_mask is not None:
assert (
attn_mask_startend_row_indices is None
), "attention_mask and attn_mask_startend_row_indices can not be set at same time"
attention_mask = LlamaModel._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), 0, input_embeds.dtype
)
attention_mask.stop_gradient = True
if get_env_device() == "npu":
attention_mask = attention_mask.astype("bool")
elif get_env_device() == "npu":
attention_mask = paddle.tril(paddle.ones((seq_length, seq_length), dtype="bool"))
attention_mask.stop_gradient = True
if self.config.alibi and attention_mask is None:
attention_mask = LlamaModel._prepare_decoder_attention_mask(
None, (batch_size, seq_length), 0, input_embeds.dtype
)
attention_mask.stop_gradient = True
return return_args(input_embeds, attention_mask, attn_mask_startend_row_indices, position_ids, alibi)
class LlamaDecoderLayerPipe(LlamaDecoderLayer):
def forward(self, args):
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)
# we can't distinguish
if self.config.alibi and alibi is None and position_ids is None and attn_mask_startend_row_indices is not None:
# hidden_states, attention_mask, alibi
alibi = attn_mask_startend_row_indices
position_ids = None
attn_mask_startend_row_indices = None
elif (
self.config.alibi
and alibi is None
and position_ids is not None
and attn_mask_startend_row_indices is not None
):
# hidden_states, attention_mask, position_ids, alibi
alibi = position_ids
position_ids = attn_mask_startend_row_indices
attn_mask_startend_row_indices = None
elif not self.config.alibi:
if get_env_device() in ["gpu"]:
if attention_mask is not None and attention_mask.dtype == paddle.int32:
attention_mask, attn_mask_startend_row_indices, position_ids = (
None,
attention_mask,
attn_mask_startend_row_indices,
)
elif attention_mask is not None and attention_mask.dtype == paddle.int64:
attention_mask, attn_mask_startend_row_indices, position_ids = None, None, attention_mask
elif (
attn_mask_startend_row_indices is not None and attn_mask_startend_row_indices.dtype == paddle.int64
):
attn_mask_startend_row_indices, position_ids = None, attn_mask_startend_row_indices
elif position_ids is None and attn_mask_startend_row_indices is not None:
position_ids = attn_mask_startend_row_indices
attn_mask_startend_row_indices = None
has_gradient = not hidden_states.stop_gradient
if (
self.enable_recompute
and self.layerwise_recompute
and self.config.recompute_granularity == "full"
and has_gradient
):
recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute
if attention_mask is not None or alibi is not None or attn_mask_startend_row_indices is not None:
hidden_states = recompute_fn(
super().forward,
hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
alibi=alibi,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
use_reentrant=False,
)
else:
# for pretrain
hidden_states = recompute_fn(
super().forward,
hidden_states,
position_ids=position_ids,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
use_reentrant=self.config.recompute_use_reentrant,
)
else:
hidden_states = super().forward(
hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
alibi=alibi,
attn_mask_startend_row_indices=attn_mask_startend_row_indices,
)
return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi)
class LlamaRMSNormPipe(nn.Layer):
def __init__(self, config):
super().__init__()
self.norm = LlamaRMSNorm(config)
def forward(self, args):
hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids, alibi = parse_args(args)
return self.norm(hidden_states)
class LlamaLMHeadPipe(LlamaLMHead):
def __init__(self, config, transpose_y=False):
super(LlamaLMHeadPipe, self).__init__(config, transpose_y=transpose_y)
@property
def embedding_weight(self):
return get_attr(self, "weight")
class LlamaForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
"""LlamaForPretraining adapted for pipeline parallelism.
The largest change is flattening the LlamaModel class so we can express it as a
sequence of layers including embedding, transformer layers, and output.
"""
config_class = LlamaConfig
_get_tensor_parallel_mappings = LlamaPretrainedModel._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = LlamaPretrainedModel._get_fuse_or_split_param_mappings
_init_weights = LlamaPretrainedModel._init_weights
_keys_to_ignore_on_load_unexpected = LlamaPretrainedModel._keys_to_ignore_on_load_unexpected
_get_model_flops = LlamaPretrainedModel._get_model_flops
_get_hardware_flops = LlamaPretrainedModel._get_hardware_flops
_tied_weights_keys = ["lm_head.weight"]
# DONOT Add base_model_prefix !!!!
@classmethod
def _prepare_pipeline_inputs_func(cls, inputs):
first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"]
last_stage_keys = ["labels"]
def get_expected_keys(inputs, keys):
ret = tuple([inputs.pop(k) if k in inputs else None for k in keys])
if len(ret) == 1:
ret = ret[0]
return ret
if type(inputs) is dict or type(inputs) is OrderedDict:
return [
get_expected_keys(inputs, first_stage_keys),
get_expected_keys(inputs, last_stage_keys),
]
keys = list(inputs[0].keys())
inputs_batch = {key: [data.pop(key) for data in inputs] for key in keys}
return [
get_expected_keys(inputs_batch, first_stage_keys),
get_expected_keys(inputs_batch, last_stage_keys),
]
def __init__(self, config):
self.config = config
self.recompute_granularity = self.config.recompute_granularity
self.pp_recompute_interval = self.config.pp_recompute_interval
self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else []
virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1)
def get_hcg():
return fleet.get_hybrid_communicate_group()
hcg = get_hcg()
tensor_parallel_degree = max(hcg.get_model_parallel_world_size(), 1)
tensor_parallel_rank = max(hcg.get_model_parallel_rank(), 0)
# TODO: fix tensor_parallel_degree rewrite in here
config.tensor_parallel_degree = tensor_parallel_degree
config.tensor_parallel_rank = tensor_parallel_rank
if config.tie_word_embeddings:
self.add_sequential_layer(
SharedLayerDesc(
"llama_shared_weight", LlamaEmbeddingPipe, shared_weight_attr="embedding_weight", config=config
),
"llama",
)
else:
self.add_sequential_layer(LayerDesc(LlamaEmbeddingPipe, config=config), "llama")
for i in range(config.num_hidden_layers):
self.add_sequential_layer(
LayerDesc(
LlamaDecoderLayerPipe,
config=config,
layerwise_recompute=i not in self.no_recompute_layers,
skip_recompute_ops=get_skip_recompute_ops(config, i),
),
f"llama.layers.{i}",
)
self.add_sequential_layer(LayerDesc(LlamaRMSNormPipe, config=config), "llama")
if config.tie_word_embeddings:
self.add_sequential_layer(
SharedLayerDesc(
"llama_shared_weight",
LlamaLMHeadPipe,
shared_weight_attr="embedding_weight",
config=config,
**{"transpose_y": True},
),
"lm_head",
)
else:
self.add_sequential_layer(LayerDesc(LlamaLMHeadPipe, config=config), "lm_head")
recompute_interval = 0
seg_method = "layer:LlamaDecoderLayer"
if config.num_hidden_layers % get_hcg().topology().get_dim_size("pipe") != 0:
seg_method = "uniform"
PipelineLayer.__init__(
self,
layers=self.get_sequential_layers(),
loss_fn=self.get_loss_fn(config),
topology=get_hcg().topology(),
seg_method=seg_method,
recompute_interval=recompute_interval,
recompute_ctx={
"mp_group": get_hcg().get_model_parallel_group(),
"offload": False,
"partition": False,
},
num_virtual_pipeline_stages=virtual_pp_degree,
)
# You should call init here, since there is a diamond inheritance problem
self.apply(self._init_weights)
# DON'T init PipelinePretrainedModel
# PipelinePretrainedModel.__init__(self.super(), config=config)
def get_loss_fn(self, config):
if config.dpo_config is not None:
return DPOCriterion(config, use_infohub=True)
else:
return LlamaPretrainingCriterion(config)