@@ -729,23 +729,24 @@ class ParallelInterface(MutableMapping):
729
729
730
730
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
731
731
# a new instance is created (in order to locally override a given function)
732
- _global_mapping = {
733
- "colwise" : ColwiseParallel (),
734
- "rowwise" : RowwiseParallel (),
735
- "colwise_rep" : ColwiseParallel (output_layouts = Replicate ()),
736
- "rowwise_rep" : RowwiseParallel (input_layouts = Replicate ()),
737
- "local_colwise" : ColwiseParallel (use_dtensor = False ),
738
- "local_rowwise" : RowwiseParallel (use_dtensor = False ),
739
- "local" : IsolatedParallel (),
740
- "gather" : GatherParallel (),
741
- "local_packed_rowwise" : PackedRowwiseParallel (use_dtensor = False ),
742
- "sequence_parallel" : SequenceParallel (),
743
- "replicate" : ReplicateParallel (),
744
- }
745
732
746
733
def __init__(self):
    """Create an interface with an empty set of instance-local overrides.

    The first instantiation also fills the class-level ``_global_mapping``
    with the built-in parallel styles. Later instantiations leave that shared
    mapping untouched, so anything added to it in the meantime is still
    visible after a new instance is created.
    """
    self._local_mapping = {}
    # Build the shared defaults only once. Re-assigning the dict on every
    # instantiation would reset the mapping shared by all instances each time
    # a new instance is created (e.g. to hold local overrides).
    if getattr(ParallelInterface, "_global_mapping", None) is None:
        ParallelInterface._global_mapping = {
            "colwise": ColwiseParallel(),
            "rowwise": RowwiseParallel(),
            "colwise_rep": ColwiseParallel(output_layouts=Replicate()),
            "rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
            "local_colwise": ColwiseParallel(use_dtensor=False),
            "local_rowwise": RowwiseParallel(use_dtensor=False),
            "local": IsolatedParallel(),
            "gather": GatherParallel(),
            "local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
            "sequence_parallel": SequenceParallel(),
            "replicate": ReplicateParallel(),
        }
749
+
749
750
def __getitem__ (self , key ):
750
751
# First check if instance has a local override
751
752
if key in self ._local_mapping :
@@ -775,7 +776,11 @@ def valid_keys(self) -> List[str]:
775
776
776
777
777
778
# Global ParallelInterface shared by all models which do not need to overwrite any of the existing ones
778
- ALL_PARALLEL_STYLES : ParallelInterface = ParallelInterface ()
779
+
780
# The shared registry is only instantiated when torch >= 2.5 is installed and
# torch.distributed is available; otherwise the name is bound to None.
# NOTE(review): consumers presumably need to handle the None case - confirm at call sites.
if not (is_torch_greater_or_equal("2.5") and _torch_distributed_available):
    ALL_PARALLEL_STYLES = None
else:
    ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
779
784
780
785
781
786
def convert_local_tensor_to_dtensor (
0 commit comments