Skip to content

Commit

Permalink
make fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
gante committed Jul 23, 2024
1 parent 48ed251 commit c824be0
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@
)
_import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "rope_config_validation"]
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
_import_structure["modeling_utils"] = ["PreTrainedModel"]

# PyTorch models structure
Expand Down Expand Up @@ -6011,7 +6011,7 @@
WatermarkLogitsProcessor,
WhisperTimeStampLogitsProcessor,
)
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS, rope_config_validation
from .modeling_rope_utils import ROPE_INIT_FUNCTIONS
from .modeling_utils import PreTrainedModel
from .models.albert import (
AlbertForMaskedLM,
Expand Down
4 changes: 0 additions & 4 deletions src/transformers/utils/dummy_pt_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,10 +488,6 @@ def __init__(self, *args, **kwargs):
ROPE_INIT_FUNCTIONS = None


def rope_config_validation(*args, **kwargs):
requires_backends(rope_config_validation, ["torch"])


class PreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]

Expand Down
8 changes: 6 additions & 2 deletions tests/utils/test_modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from transformers import LlamaConfig
from transformers.testing_utils import is_torch_available, require_torch, torch_device


if is_torch_available():
import torch
from transformers import ROPE_INIT_FUNCTIONS, rope_config_validation

from transformers import ROPE_INIT_FUNCTIONS
from transformers.modeling_rope_utils import rope_config_validation


@require_torch
Expand Down Expand Up @@ -113,4 +116,5 @@ def test_dynamic_rope_function_bc(self):
kwargs_freqs = rope_fn(**rope_kwargs, device=device)[0]
torch.testing.assert_close(config_freqs, kwargs_freqs)

#TODO(joao): numerical checks for the different RoPE fns

# TODO(joao): numerical checks for the different RoPE fns

0 comments on commit c824be0

Please sign in to comment.