Commit 6091953

change strategy: overriding huggingface_from_pretrained
1 parent d295416 commit 6091953
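
Instead of subclassing spacy-transformers' TransformerModel (the previous TransformerModelCustom stub, removed below), this commit monkey-patches the module-level huggingface_from_pretrained function: a custom loader is swapped in around initialize() and restored afterwards, adding a fallback that downloads the model from the Hugging Face Hub when local files are missing.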

2 files changed: +83 −5 lines changed

ginza_transformers/layers/transformer_model.py

Lines changed: 68 additions & 3 deletions

@@ -1,5 +1,70 @@
-from spacy_transformers.layers.transformer_model import TransformerModel
+from pathlib import Path
+import sys
+from typing import Union, Dict
 
+from transformers import AutoConfig, AutoModel, AutoTokenizer
 
-class TransformerModelCustom(TransformerModel):
-    pass
+from spacy_transformers.layers import transformer_model
+from spacy_transformers.data_classes import HFObjects
+
+from thinc.api import get_current_ops, CupyOps
+
+
+def override_huggingface_from_pretrained():
+    assert transformer_model.huggingface_from_pretrained is not huggingface_from_pretrained_custom
+    origin = transformer_model.huggingface_from_pretrained
+    transformer_model.huggingface_from_pretrained = huggingface_from_pretrained_custom
+    return origin
+
+
+def recover_huggingface_from_pretrained(origin):
+    assert transformer_model.huggingface_from_pretrained is huggingface_from_pretrained_custom
+    transformer_model.huggingface_from_pretrained = origin
+
+
+def huggingface_from_pretrained_custom(
+    source: Union[Path, str], tok_config: Dict, trf_config: Dict
+) -> HFObjects:
+    """Create a Huggingface transformer model from pretrained weights. Will
+    download the model if it is not already downloaded.
+
+    source (Union[str, Path]): The name of the model or a path to it, such as
+        'bert-base-cased'.
+    tok_config (dict): Settings to pass to the tokenizer.
+    trf_config (dict): Settings to pass to the transformer.
+    """
+    if hasattr(source, "absolute"):
+        str_path = str(source.absolute())
+    else:
+        str_path = source
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
+    except ValueError as e:
+        if "tokenizer_class" not in tok_config:
+            raise e
+        tokenizer_class_name = tok_config["tokenizer_class"].split(".")
+        from importlib import import_module
+        tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
+        tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
+        tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **tok_config)
+    vocab_file_contents = None
+    if hasattr(tokenizer, "vocab_file"):
+        with open(tokenizer.vocab_file, "rb") as fileh:
+            vocab_file_contents = fileh.read()
+
+    try:
+        trf_config["return_dict"] = True
+        config = AutoConfig.from_pretrained(str_path, **trf_config)
+        transformer = AutoModel.from_pretrained(str_path, config=config)
+    except OSError as e:
+        try:
+            transformer = AutoModel.from_pretrained(str_path, local_files_only=True)
+        except OSError as e2:
+            model_name = str(source)
+            print("trying to download model from huggingface hub:", model_name, "...", file=sys.stderr)
+            transformer = AutoModel.from_pretrained(model_name)
+            print("succeeded", file=sys.stderr)
+    ops = get_current_ops()
+    if isinstance(ops, CupyOps):
+        transformer.cuda()
+    return HFObjects(tokenizer, transformer, vocab_file_contents)
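
For orientation, a hypothetical usage sketch of the new loader. The import path is inferred from the package layout, the model id comes from the docstring, the empty configs and the HFObjects attribute access are assumptions, not part of this commit:

# Hypothetical usage sketch of the custom loader (assumptions noted above).
from ginza_transformers.layers.transformer_model import huggingface_from_pretrained_custom

hf = huggingface_from_pretrained_custom(
    "bert-base-cased",  # hub id or a local path
    tok_config={},      # settings forwarded to AutoTokenizer.from_pretrained
    trf_config={},      # settings forwarded to AutoConfig.from_pretrained
)
tokenizer, transformer = hf.tokenizer, hf.transformer

If local loading fails twice, the loader prints a notice to stderr and retries against the Hugging Face Hub by model name before giving up.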

ginza_transformers/pipeline_component.py

Lines changed: 15 additions & 2 deletions

@@ -1,15 +1,17 @@
 from pathlib import Path
 import sys
-from typing import List, Callable, Iterable, Union
+from typing import List, Callable, Iterable, Optional, Union
 
 from spacy.language import Language
 from spacy.tokens import Doc
 from thinc.api import Model, Config
 
 from spacy_transformers.data_classes import FullTransformerBatch
 from spacy_transformers.pipeline_component import Transformer, DOC_EXT_ATTR
+from spacy.training import Example
 
 from .layers.hf_shim_custom import override_hf_shims_from_bytes, recover_hf_shims_from_bytes
+from .layers.transformer_model import override_huggingface_from_pretrained, recover_huggingface_from_pretrained
 
 
 DEFAULT_CONFIG_STR = """
@@ -54,10 +56,21 @@ def make_transformer_custom(
 
 class TransformerCustom(Transformer):
 
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable[Example]],
+        *,
+        nlp: Optional[Language] = None,
+    ):
+        origin = override_huggingface_from_pretrained()
+        try:
+            super().initialize(get_examples, nlp=nlp)
+        finally:
+            recover_huggingface_from_pretrained(origin)
+
     def from_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
     ) -> "TransformerCustom":
-
         origin = override_hf_shims_from_bytes()
         try:
             super().from_disk(path, exclude=exclude)
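
A possible follow-up, not part of this commit: the override/recover pair could be wrapped in a context manager so the restore step cannot be skipped. A minimal sketch under that assumption:

# Hypothetical refactor sketch: same patch-then-restore behavior as
# initialize() above, expressed as a context manager.
from contextlib import contextmanager

from ginza_transformers.layers.transformer_model import (
    override_huggingface_from_pretrained,
    recover_huggingface_from_pretrained,
)

@contextmanager
def patched_huggingface_from_pretrained():
    # Install the custom loader, then guarantee the original is restored.
    origin = override_huggingface_from_pretrained()
    try:
        yield
    finally:
        recover_huggingface_from_pretrained(origin)

initialize() would then reduce to: with patched_huggingface_from_pretrained(): super().initialize(get_examples, nlp=nlp).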
