Skip to content

Commit b5f4946

Browse files
committed
Protect ParallelInterface
1 parent 113424b commit b5f4946

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@
125125
"jaxlib>=0.4.1,<=0.4.13",
126126
"jieba",
127127
"jinja2>=3.1.0",
128-
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
128+
"kenlm",
129129
# Keras pin - this is to make sure Keras 3 doesn't destroy us. Remove or change when we have proper support.
130130
"keras>2.9,<2.16",
131131
"keras-nlp>=0.3.1,<0.14.0", # keras-nlp 0.14 doesn't support keras 2, see pin on keras.
@@ -315,7 +315,7 @@ def run(self):
315315
"librosa",
316316
"pyctcdecode",
317317
"phonemizer",
318-
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
318+
"kenlm",
319319
)
320320
# `pip install ".[speech]"` is deprecated and `pip install ".[torch-speech]"` should be used instead
321321
extras["speech"] = deps_list("torchaudio") + extras["audio"]

src/transformers/dependency_versions_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"jaxlib": "jaxlib>=0.4.1,<=0.4.13",
3333
"jieba": "jieba",
3434
"jinja2": "jinja2>=3.1.0",
35-
"kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5": "kenlm@git+https://github.com/ydshieh/kenlm@78f664fb3dafe1468d868d71faf19534530698d5",
35+
"kenlm": "kenlm",
3636
"keras": "keras>2.9,<2.16",
3737
"keras-nlp": "keras-nlp>=0.3.1,<0.14.0",
3838
"kernels": "kernels>=0.4.4,<0.5",

src/transformers/integrations/tensor_parallel.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -729,23 +729,24 @@ class ParallelInterface(MutableMapping):
729729

730730
# Class instance object, so that a call to `register` can be reflected into all other files correctly, even if
731731
# a new instance is created (in order to locally override a given function)
732-
_global_mapping = {
733-
"colwise": ColwiseParallel(),
734-
"rowwise": RowwiseParallel(),
735-
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
736-
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
737-
"local_colwise": ColwiseParallel(use_dtensor=False),
738-
"local_rowwise": RowwiseParallel(use_dtensor=False),
739-
"local": IsolatedParallel(),
740-
"gather": GatherParallel(),
741-
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
742-
"sequence_parallel": SequenceParallel(),
743-
"replicate": ReplicateParallel(),
744-
}
745732

746733
def __init__(self):
747734
self._local_mapping = {}
748735

736+
ParallelInterface._global_mapping = {
737+
"colwise": ColwiseParallel(),
738+
"rowwise": RowwiseParallel(),
739+
"colwise_rep": ColwiseParallel(output_layouts=Replicate()),
740+
"rowwise_rep": RowwiseParallel(input_layouts=Replicate()),
741+
"local_colwise": ColwiseParallel(use_dtensor=False),
742+
"local_rowwise": RowwiseParallel(use_dtensor=False),
743+
"local": IsolatedParallel(),
744+
"gather": GatherParallel(),
745+
"local_packed_rowwise": PackedRowwiseParallel(use_dtensor=False),
746+
"sequence_parallel": SequenceParallel(),
747+
"replicate": ReplicateParallel(),
748+
}
749+
749750
def __getitem__(self, key):
750751
# First check if instance has a local override
751752
if key in self._local_mapping:
@@ -775,7 +776,11 @@ def valid_keys(self) -> List[str]:
775776

776777

777778
# Global AttentionInterface shared by all models which do not need to overwrite any of the existing ones
778-
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
779+
780+
if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
781+
ALL_PARALLEL_STYLES: ParallelInterface = ParallelInterface()
782+
else:
783+
ALL_PARALLEL_STYLES = None
779784

780785

781786
def convert_local_tensor_to_dtensor(

0 commit comments

Comments
 (0)