Skip to content

Commit 2ea0b17

Browse files
committed
style: polish code
1 parent c4deb7a commit 2ea0b17

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

colossalai/shardformer/policies/base_policy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,4 +242,4 @@ def get_stage_index(
242242
end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
243243
stage_indices.append([start_idx, end_idx])
244244

245-
return stage_indices[0] if num_model_chunks == 1 else stage_indices
245+
return stage_indices[0] if num_model_chunks == 1 else stage_indices

tests/test_shardformer/test_layer/test_dist_crossentropy.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
3838
org_loss, dist_loss, atol=1e-5
3939
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
4040

41-
4241
target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
43-
assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
42+
assert torch.allclose(
43+
target_grad, dist_pred.grad
44+
), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
4445

4546

4647
@pytest.mark.dist

tests/test_shardformer/test_model/test_shard_gptj.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def check_gptj_3d(rank, world_size, port):
207207
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
208208
run_gptj_3d_test()
209209

210+
210211
@pytest.mark.skip("TODO check_gptj has something wrong.")
211212
@pytest.mark.dist
212213
@rerun_if_address_is_in_use()

0 commit comments

Comments (0)