From 54fbdcd3b85cc348e31dc0cd0c7d7506c3b91ba5 Mon Sep 17 00:00:00 2001 From: Matt Date: Fri, 30 Aug 2024 16:56:22 +0100 Subject: [PATCH] Fix local repos with remote code not registering for pipelines (#33100) * Extremely experimental fix! * Try removing the clause entirely * Add test * make fixup * stash commit * Remove breakpoint * Add anti-regression test * make fixup * Move repos to hf-internal-testing! --- src/transformers/models/auto/auto_factory.py | 11 ++--------- tests/models/auto/test_modeling_auto.py | 10 ++++++++++ tests/pipelines/test_pipelines_common.py | 10 +++++++++- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/auto/auto_factory.py b/src/transformers/models/auto/auto_factory.py index 6b572b25277984..220ae97f5073c6 100644 --- a/src/transformers/models/auto/auto_factory.py +++ b/src/transformers/models/auto/auto_factory.py @@ -17,7 +17,6 @@ import copy import importlib import json -import os import warnings from collections import OrderedDict @@ -427,10 +426,7 @@ def from_config(cls, config, **kwargs): else: repo_id = config.name_or_path model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) - if os.path.isdir(config._name_or_path): - model_class.register_for_auto_class(cls.__name__) - else: - cls.register(config.__class__, model_class, exist_ok=True) + cls.register(config.__class__, model_class, exist_ok=True) _ = kwargs.pop("code_revision", None) return model_class._from_config(config, **kwargs) elif type(config) in cls._model_mapping.keys(): @@ -552,10 +548,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs ) _ = hub_kwargs.pop("code_revision", None) - if os.path.isdir(pretrained_model_name_or_path): - model_class.register_for_auto_class(cls.__name__) - else: - cls.register(config.__class__, model_class, exist_ok=True) + cls.register(config.__class__, model_class, exist_ok=True) return model_class.from_pretrained( pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs ) diff --git a/tests/models/auto/test_modeling_auto.py b/tests/models/auto/test_modeling_auto.py index 363028c7f22978..61085b9b5d4572 100644 --- a/tests/models/auto/test_modeling_auto.py +++ b/tests/models/auto/test_modeling_auto.py @@ -21,6 +21,7 @@ from pathlib import Path import pytest +from huggingface_hub import Repository import transformers from transformers import BertConfig, GPT2Model, is_safetensors_available, is_torch_available @@ -529,3 +530,12 @@ def test_attr_not_existing(self): _MODEL_MAPPING_NAMES = OrderedDict([("bert", "GPT2Model")]) _MODEL_MAPPING = _LazyAutoMapping(_CONFIG_MAPPING_NAMES, _MODEL_MAPPING_NAMES) self.assertEqual(_MODEL_MAPPING[BertConfig], GPT2Model) + + def test_dynamic_saving_from_local_repo(self): + with tempfile.TemporaryDirectory() as tmp_dir, tempfile.TemporaryDirectory() as tmp_dir_out: + _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-custom-architecture") + model = AutoModelForCausalLM.from_pretrained(tmp_dir, trust_remote_code=True) + model.save_pretrained(tmp_dir_out) + _ = AutoModelForCausalLM.from_pretrained(tmp_dir_out, trust_remote_code=True) + self.assertTrue((Path(tmp_dir_out) / "modeling_fake_custom.py").is_file()) + self.assertTrue((Path(tmp_dir_out) / "configuration_fake_custom.py").is_file()) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f4aa1a27f505d7..e99af89f7c0820 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -22,7 +22,7 @@ import datasets import numpy as np -from huggingface_hub import HfFolder, delete_repo +from huggingface_hub import HfFolder, Repository, delete_repo from requests.exceptions import HTTPError from transformers import ( @@ -226,6 +226,14 @@ def test_torch_dtype_property(self): pipe.model = None self.assertIsNone(pipe.torch_dtype) + @require_torch + def test_auto_model_pipeline_registration_from_local_dir(self): + with tempfile.TemporaryDirectory() as tmp_dir: + _ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-custom-architecture") + pipe = pipeline("text-generation", tmp_dir, trust_remote_code=True) + + self.assertIsInstance(pipe, TextGenerationPipeline) # Assert successful load + @is_pipeline_test class PipelineScikitCompatTest(unittest.TestCase):