@@ -7,7 +7,7 @@
 from torch.utils._pytree import tree_map
 
 from colossalai.accelerator import get_accelerator
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils import get_current_device
@@ -327,9 +327,7 @@ def run_forward_only(
             self.send_forward(output_obj)
 
         if outputs is not None:
-            if isinstance(model, ModelWrapper):
-                model = model.unwrap()
-            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+            outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
 
     def run_forward_backward(
@@ -412,9 +410,7 @@ def run_forward_backward(
         assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
 
         if outputs is not None:
-            if isinstance(model, ModelWrapper):
-                model = model.unwrap()
-            outputs = merge_batch(outputs, getattr(model, "batch_size_dim", 0))
+            outputs = merge_batch(outputs)
         return {"loss": accum_loss, "outputs": outputs}
 
     def forward_backward_step(
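For context on the change: both call sites now invoke `merge_batch(outputs)` without an explicit `batch_size_dim`, so the per-microbatch outputs are merged along the helper's default batch dimension, and the schedule no longer needs to unwrap `ModelWrapper` to read a `batch_size_dim` attribute off the model. Below is a minimal sketch of what a `merge_batch`-style helper could look like under that assumption; the name `merge_batch_sketch` and the body are illustrative only, not ColossalAI's actual implementation.

```python
# Illustrative sketch only (NOT ColossalAI's merge_batch): assumes the helper
# concatenates identically-structured per-microbatch outputs along the batch
# dimension, defaulting to dim 0 as the updated call sites rely on.
from typing import Any, List

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def merge_batch_sketch(data: List[Any], batch_size_dim: int = 0) -> Any:
    """Merge a list of microbatch outputs (tensors or pytrees of tensors) into one batch."""
    if len(data) == 0:
        return None
    # Flatten each microbatch output into (leaves, spec); all specs are assumed identical.
    flattened, specs = zip(*(tree_flatten(d) for d in data))
    merged_leaves = []
    for leaves in zip(*flattened):  # corresponding leaves across microbatches
        if isinstance(leaves[0], torch.Tensor):
            merged_leaves.append(torch.cat(leaves, dim=batch_size_dim))
        else:
            merged_leaves.append(list(leaves))  # non-tensor leaves are just collected
    return tree_unflatten(merged_leaves, specs[0])


# Example: two microbatches of logits with batch size 2 each -> one batch of 4.
if __name__ == "__main__":
    mb1 = {"logits": torch.randn(2, 10)}
    mb2 = {"logits": torch.randn(2, 10)}
    merged = merge_batch_sketch([mb1, mb2])
    assert merged["logits"].shape == (4, 10)
```

If this reading is right, the simplification bakes in the assumption that pipeline outputs batch along dim 0, which is the default most model outputs use, in exchange for dropping the per-model `getattr` lookup and the wrapper unwrapping at the call sites.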