Skip to content

Commit db8f6bc

Browse files
committed
revert: enable return_outputs when necessary
1 parent e6740ea commit db8f6bc

File tree

4 files changed, with 4 additions and 4 deletions.

4 files changed, with 4 additions and 4 deletions.

examples/images/vit/vit_train_demo.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def run_forward_backward(
37 | 37 |   if isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1:
38 | 38 |   # run pipeline forward backward when enabling pp in hybrid parallel plugin
39 | 39 |   output_dict = booster.execute_pipeline(
40 |    | - data_iter, model, criterion, optimizer, return_loss=True, return_outputs=False
   | 40 | + data_iter, model, criterion, optimizer, return_loss=True, return_outputs=True
41 | 41 |   )
42 | 42 |   loss, outputs = output_dict["loss"], output_dict["outputs"]
43 | 43 |   else:

examples/language/bert/finetune.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,7 @@ def evaluate_subset(dataloader: DataLoader):
70 | 70 |   current_rank = dist.get_rank()
71 | 71 |   batch = iter([batch])
72 | 72 |
73 |    | - outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=False)
   | 73 | + outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
74 | 74 |
75 | 75 |   if is_pp_last_device:
76 | 76 |   logits = outputs["outputs"]["logits"]

examples/language/gpt/hybridparallelism/finetune.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -64,7 +64,7 @@ def evaluate_subset(dataloader: DataLoader):
64 | 64 |   current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
65 | 65 |   current_rank = dist.get_rank()
66 | 66 |   batch = iter([batch])
67 |    | - outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=False)
   | 67 | + outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)
68 | 68 |
69 | 69 |   if is_pp_last_stage:
70 | 70 |   logits = outputs["outputs"]["logits"]

tests/test_shardformer/test_model/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -173,7 +173,7 @@ def _criterion(outputs, inputs):
173 | 173 |
174 | 174 |   data_iter = iter([data])
175 | 175 |   sharded_output = booster.execute_pipeline(
176 |     | - data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=False
    | 176 | + data_iter, sharded_model, _criterion, sharded_optimizer, return_loss=True, return_outputs=True
177 | 177 |   )
178 | 178 |   sharded_loss = sharded_output["loss"]
179 | 179 |   else:

0 commit comments

Comments
 (0)