@@ -601,8 +601,20 @@ def pipeline_first_axis(raw_keys):
601601 raw_keys ["dcn_expert_parallelism" ],
602602 raw_keys ["dcn_autoregressive_parallelism" ],
603603 ]
604- mesh_axes = ["stage" , "data" , "fsdp" , "fsdp_transpose" , "sequence" , "tensor" , "tensor_sequence" , "expert" , "autoregressive" ]
605- data_sharding = [["stage" , "data" , "fsdp" , "fsdp_transpose" , "sequence" , "tensor" , "tensor_sequence" , "expert" , "autoregressive" ]]
604+ mesh_axes = [
605+ "stage" ,
606+ "data" ,
607+ "fsdp" ,
608+ "fsdp_transpose" ,
609+ "sequence" ,
610+ "tensor" ,
611+ "tensor_sequence" ,
612+ "expert" ,
613+ "autoregressive" ,
614+ ]
615+ data_sharding = [
616+ ["stage" , "data" , "fsdp" , "fsdp_transpose" , "sequence" , "tensor" , "tensor_sequence" , "expert" , "autoregressive" ]
617+ ]
606618
607619 raw_keys ["ici_parallelism" ] = ici_parallelism
608620 raw_keys ["dcn_parallelism" ] = dcn_parallelism
@@ -651,7 +663,12 @@ def validate_megablox_parallelism(raw_keys):
651663 using_sequence_parallelism (raw_keys ) or using_pipeline_parallelism (raw_keys ) or using_expert_parallelism (raw_keys )
652664 ):
653665 raise ValueError ("Currently we only support Megablox with data and tensor parallelism." )
654- tensor_parallelism = raw_keys ["ici_tensor_parallelism" ] * raw_keys ["dcn_tensor_parallelism" ] * raw_keys ["ici_tensor_sequence_parallelism" ] * raw_keys ["dcn_tensor_sequence_parallelism" ]
666+ tensor_parallelism = (
667+ raw_keys ["ici_tensor_parallelism" ]
668+ * raw_keys ["dcn_tensor_parallelism" ]
669+ * raw_keys ["ici_tensor_sequence_parallelism" ]
670+ * raw_keys ["dcn_tensor_sequence_parallelism" ]
671+ )
655672 if raw_keys ["megablox" ] and using_tensor_parallelism (raw_keys ) and (raw_keys ["emb_dim" ] % tensor_parallelism ):
656673 raise ValueError (
657674 f"The embedding dimension { raw_keys ['emb_dim' ]} is not divisible by tensor parallelism setting { tensor_parallelism } ."
@@ -769,7 +786,12 @@ def using_pipeline_parallelism(raw_keys) -> bool:
769786
770787
771788def using_tensor_parallelism (raw_keys ) -> bool :
772- return int (raw_keys ["ici_tensor_parallelism" ]) > 1 or int (raw_keys ["dcn_tensor_parallelism" ]) > 1 or int (raw_keys ["ici_tensor_sequence_parallelism" ]) > 1 or int (raw_keys ["dcn_tensor_sequence_parallelism" ]) > 1
789+ return (
790+ int (raw_keys ["ici_tensor_parallelism" ]) > 1
791+ or int (raw_keys ["dcn_tensor_parallelism" ]) > 1
792+ or int (raw_keys ["ici_tensor_sequence_parallelism" ]) > 1
793+ or int (raw_keys ["dcn_tensor_sequence_parallelism" ]) > 1
794+ )
773795
774796
775797def using_sequence_parallelism (raw_keys ) -> bool :
0 commit comments