8 | 8 | from flagai.model.layers.embeddings import VocabParallelEmbedding |
9 | 9 | from flagai.model.utils import normal_init_method |
10 | 10 | from flagai.model.base_model import BaseModel |
| 11 | +import torch.nn.functional as F |
11 | 12 |
|
12 | 13 | if os.getenv('ENV_TYPE') == 'deepspeed+mpu': |
13 | 14 | from flagai.mpu import get_model_parallel_world_size |
14 | | - from flagai.mpu import gather_from_model_parallel_region |
15 | 15 | from flagai.mpu import get_cuda_rng_tracker |
16 | 16 | from flagai.mpu.utils import divide |
17 | 17 | if os.getenv('ENV_TYPE') == 'deepspeed+mpu': |
18 | | - from flagai.mpu import copy_to_model_parallel_region |
19 | 18 | from flagai.mpu.random import checkpoint |
| 19 | + from flagai.mpu import copy_to_model_parallel_region, gather_from_model_parallel_region |
| 20 | + from flagai.mpu.cross_entropy import vocab_parallel_cross_entropy |
| 21 | + |
20 | 22 | elif os.getenv('ENV_TYPE') == 'deepspeed': |
21 | 23 | from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint |
22 | 24 | else: |
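The mpu imports gated on `ENV_TYPE == 'deepspeed+mpu'` are what let the forward pass below feed activations into a vocab-parallel projection. As a point of reference, here is a conceptual sketch of what `copy_to_model_parallel_region` does in the Megatron-LM pattern this follows; it is not the `flagai.mpu` implementation, and the model-parallel process group is omitted for brevity.

```python
# Conceptual sketch only (Megatron-LM style), not the flagai.mpu implementation:
# copy_to_model_parallel_region is an identity in the forward pass and an
# all-reduce of the incoming gradient across model-parallel ranks in backward.
import torch
import torch.distributed as dist

class CopyToModelParallelRegionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        return input_                         # forward: pass activations through unchanged

    @staticmethod
    def backward(ctx, grad_output):
        if dist.is_initialized():
            dist.all_reduce(grad_output)      # backward: sum gradients from every shard
        return grad_output

def copy_to_model_parallel_region_sketch(x: torch.Tensor) -> torch.Tensor:
    return CopyToModelParallelRegionSketch.apply(x)
```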
@@ -224,6 +226,7 @@ def __init__(self, config, **kwargs): |
224 | 226 | do_layer_norm_before=self.config.get("do_layer_norm_before", True), |
225 | 227 | ) |
226 | 228 | self.config = config_gpt |
| 229 | + self.parallel_output = True |
227 | 230 |
|
228 | 231 | self.transformer = GPT2Stack(self.config) |
229 | 232 | self.lm_head = nn.Linear(self.config.n_embd, |
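`self.parallel_output` is hardcoded to `True` here; in the forward hunk below it decides whether logits are returned still sharded along the vocabulary dimension or gathered across model-parallel ranks. A shape-only sketch of that sharding follows; the world size, padded vocabulary, and tensor sizes are illustrative assumptions, not values from this repo.

```python
# Shape-only illustration of vocab-parallel logits; all sizes are assumptions.
import torch
import torch.nn.functional as F

world_size = 4                                   # model-parallel ranks (assumed)
vocab_size, n_embd = 50304, 768                  # vocab padded to divide evenly (assumed)
vocab_per_rank = vocab_size // world_size        # 12576 logit columns per rank

hidden = torch.randn(2, 16, n_embd)              # [batch, seq, n_embd] transformer output
wte_shard = torch.randn(vocab_per_rank, n_embd)  # this rank's slice of the embedding matrix
logits_shard = F.linear(hidden, wte_shard)       # [2, 16, vocab_per_rank]

# parallel_output=True : return logits_shard as-is; vocab_parallel_cross_entropy
#                        computes the loss without materializing the full logits.
# parallel_output=False: gather_from_model_parallel_region concatenates the shards
#                        into full [2, 16, vocab_size] logits (e.g. for generation).
```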
@@ -266,21 +269,70 @@ def forward( |
266 | 269 | output_attentions=output_attentions, |
267 | 270 | output_hidden_states=output_hidden_states, |
268 | 271 | ) |
269 | | - hidden_states = transformer_outputs |
| 272 | + logits = transformer_outputs |
| 273 | + |
| 274 | + if os.getenv("ENV_TYPE") == 'deepspeed+mpu': |
| 275 | + logits_parallel = copy_to_model_parallel_region(logits) |
| 276 | + else: |
| 277 | + logits_parallel = logits |
270 | 278 |
|
271 | | - lm_logits = self.lm_head(hidden_states) |
 | 279 | + # Project hidden states onto the vocabulary with the tied
 | 280 | + # input-embedding weight; under mpu these logits stay vocab-parallel.
| 281 | + logits_parallel = F.linear(logits_parallel, |
| 282 | + self.transformer.wte.weight) |
272 | 283 |
|
273 | | - return_data = {"logits": lm_logits} |
274 | 284 | if labels is not None: |
275 | | - # Shift so that tokens < n predict n |
276 | | - shift_logits = lm_logits[..., :-1, :].contiguous() |
| 285 | + shift_logits = logits_parallel[..., :-1, :].contiguous() |
277 | 286 | shift_labels = labels[..., 1:].contiguous() |
278 | | - loss_fct = nn.CrossEntropyLoss() |
279 | | - loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), |
280 | | - shift_labels.view(-1)) |
281 | | - return_data["loss"] = loss |
282 | 287 |
|
283 | | - return return_data |
| 288 | + if os.getenv("ENV_TYPE") == 'deepspeed+mpu': |
| 289 | + loss = vocab_parallel_cross_entropy( |
| 290 | + shift_logits.contiguous().float(), shift_labels).mean() |
| 291 | + else: |
| 292 | + loss = F.cross_entropy( |
 | 293 | + shift_logits.view(-1, shift_logits.size(-1)).float(), shift_labels.view(-1).long())
| 294 | + |
| 295 | + if self.parallel_output: # Put in different GPUs |
| 296 | + return { |
| 297 | + 'logits': logits_parallel, |
| 298 | + 'loss': loss, |
| 299 | + 'hidden_states': None, |
| 300 | + } |
| 301 | + else: |
| 302 | + return { |
| 303 | + "logits": |
| 304 | + gather_from_model_parallel_region(logits_parallel), |
| 305 | + "loss": |
| 306 | + loss, |
| 307 | + "hidden_states": |
| 308 | + None, |
| 309 | + } |
| 310 | + else: |
| 311 | + if self.parallel_output: # Put in different GPUs |
| 312 | + return { |
| 313 | + 'logits': logits_parallel, |
| 314 | + 'hidden_states': None, |
| 315 | + } |
| 316 | + else: |
| 317 | + return { |
| 318 | + "logits": |
| 319 | + gather_from_model_parallel_region(logits_parallel), |
| 320 | + "hidden_states": |
| 321 | + None, |
| 322 | + } |
| 323 | + |
| 324 | + # lm_logits = self.lm_head(hidden_states) |
| 325 | + # return_data = {"logits": lm_logits} |
| 326 | + # if labels is not None: |
| 327 | + # # Shift so that tokens < n predict n |
| 328 | + # shift_logits = lm_logits[..., :-1, :].contiguous() |
| 329 | + # shift_labels = labels[..., 1:].contiguous() |
| 330 | + # loss_fct = nn.CrossEntropyLoss() |
| 331 | + # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), |
| 332 | + # shift_labels.view(-1)) |
| 333 | + # return_data["loss"] = loss |
| 334 | + |
| 335 | + # return return_data |
284 | 336 |
|
285 | 337 | def load_weights(self, checkpoint_path): |
286 | 338 | checkpoint = torch.load(checkpoint_path, |
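
Read end to end, the new forward pass ties the output projection to the input embedding and computes a shifted next-token loss: `vocab_parallel_cross_entropy` on sharded logits under `'deepspeed+mpu'`, plain cross-entropy otherwise, with `gather_from_model_parallel_region` used only when `parallel_output` is off. Below is a minimal single-GPU reference sketch of that loss; shapes and tensors are made up, only the math mirrors the diff.

```python
# Single-GPU reference sketch of the tied-projection, next-token loss above.
# Tensor names and sizes are illustrative assumptions.
import torch
import torch.nn.functional as F

batch, seq, n_embd, vocab = 2, 8, 768, 50257
hidden = torch.randn(batch, seq, n_embd)              # transformer output
wte_weight = torch.randn(vocab, n_embd)               # tied input-embedding weight
logits = F.linear(hidden, wte_weight)                 # [batch, seq, vocab]

labels = torch.randint(0, vocab, (batch, seq))
shift_logits = logits[..., :-1, :].contiguous()       # positions 0..seq-2 predict ...
shift_labels = labels[..., 1:].contiguous()           # ... tokens 1..seq-1
loss = F.cross_entropy(shift_logits.view(-1, vocab),  # flatten to [N, vocab]
                       shift_labels.view(-1))         # flatten to [N]
```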
|