Skip to content

Commit dbbe02f

Browse files
committed
linter fixes
1 parent 52577a6 commit dbbe02f

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

MaxText/pyconfig.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,20 @@ def pipeline_first_axis(raw_keys):
601601
raw_keys["dcn_expert_parallelism"],
602602
raw_keys["dcn_autoregressive_parallelism"],
603603
]
604-
mesh_axes = ["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "expert", "autoregressive"]
605-
data_sharding = [["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "expert", "autoregressive"]]
604+
mesh_axes = [
605+
"stage",
606+
"data",
607+
"fsdp",
608+
"fsdp_transpose",
609+
"sequence",
610+
"tensor",
611+
"tensor_sequence",
612+
"expert",
613+
"autoregressive",
614+
]
615+
data_sharding = [
616+
["stage", "data", "fsdp", "fsdp_transpose", "sequence", "tensor", "tensor_sequence", "expert", "autoregressive"]
617+
]
606618

607619
raw_keys["ici_parallelism"] = ici_parallelism
608620
raw_keys["dcn_parallelism"] = dcn_parallelism
@@ -651,7 +663,12 @@ def validate_megablox_parallelism(raw_keys):
651663
using_sequence_parallelism(raw_keys) or using_pipeline_parallelism(raw_keys) or using_expert_parallelism(raw_keys)
652664
):
653665
raise ValueError("Currently we only support Megablox with data and tensor parallelism.")
654-
tensor_parallelism = raw_keys["ici_tensor_parallelism"] * raw_keys["dcn_tensor_parallelism"] * raw_keys["ici_tensor_sequence_parallelism"] * raw_keys["dcn_tensor_sequence_parallelism"]
666+
tensor_parallelism = (
667+
raw_keys["ici_tensor_parallelism"]
668+
* raw_keys["dcn_tensor_parallelism"]
669+
* raw_keys["ici_tensor_sequence_parallelism"]
670+
* raw_keys["dcn_tensor_sequence_parallelism"]
671+
)
655672
if raw_keys["megablox"] and using_tensor_parallelism(raw_keys) and (raw_keys["emb_dim"] % tensor_parallelism):
656673
raise ValueError(
657674
f"The embedding dimension {raw_keys['emb_dim']} is not divisible by tensor parallelism setting {tensor_parallelism}."
@@ -769,7 +786,12 @@ def using_pipeline_parallelism(raw_keys) -> bool:
769786

770787

771788
def using_tensor_parallelism(raw_keys) -> bool:
772-
return int(raw_keys["ici_tensor_parallelism"]) > 1 or int(raw_keys["dcn_tensor_parallelism"]) > 1 or int(raw_keys["ici_tensor_sequence_parallelism"]) > 1 or int(raw_keys["dcn_tensor_sequence_parallelism"]) > 1
789+
return (
790+
int(raw_keys["ici_tensor_parallelism"]) > 1
791+
or int(raw_keys["dcn_tensor_parallelism"]) > 1
792+
or int(raw_keys["ici_tensor_sequence_parallelism"]) > 1
793+
or int(raw_keys["dcn_tensor_sequence_parallelism"]) > 1
794+
)
773795

774796

775797
def using_sequence_parallelism(raw_keys) -> bool:

0 commit comments

Comments
 (0)