
Commit 4a61ace

migrate to spacy-transformers 1.1.2 #3
1 parent: ab7db08

4 files changed: +179 -93 lines

ginza_transformers/layers/transformer_model.py

Lines changed: 158 additions & 41 deletions
```diff
@@ -1,37 +1,114 @@
-from spacy_transformers.layers.transformer_model import *
-
-from ginza_transformers.util import huggingface_from_pretrained_custom
-
-
-def TransformerModelCustom(
-    name: str, get_spans: Callable, tokenizer_config: dict
-) -> Model[List[Doc], FullTransformerBatch]:
-    return Model(
-        "transformer",
-        forward,
-        init=init_custom,
-        layers=[],
-        dims={"nO": None},
-        attrs={
-            "tokenizer": None,
-            "get_spans": get_spans,
-            "name": name,
-            "tokenizer_config": tokenizer_config,
-            "set_transformer": set_pytorch_transformer,
-            "has_transformer": False,
-            "flush_cache_chance": 0.0,
-        },
-    )
+import copy
+import sys
+from typing import Callable, Dict, Optional, Tuple, Union
+from pathlib import Path
+
+from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedTokenizerBase
+
+from thinc.api import CupyOps, Model, get_current_ops
+
+from spacy_transformers.align import get_alignment
+from spacy_transformers.data_classes import WordpieceBatch, HFObjects
+from spacy_transformers.layers._util import replace_listener, replace_listener_cfg
+from spacy_transformers.layers.hf_wrapper import HFWrapper
+from spacy_transformers.layers.transformer_model import (
+    TransformerModel,
+    _convert_transformer_inputs,
+    _convert_transformer_outputs,
+    forward,
+    huggingface_tokenize,
+    set_pytorch_transformer,
+)
+from spacy_transformers.truncate import truncate_oversize_splits
+
+
+class TransformerModelCustom(Model):
+    def __init__(
+        self,
+        name: str,
+        get_spans: Callable,
+        tokenizer_config: dict = {},
+        transformer_config: dict = {},
+        mixed_precision: bool = False,
+        grad_scaler_config: dict = {},
+    ):
+        """
+        get_spans (Callable[[List[Doc]], List[Span]]):
+            A function to extract spans from the batch of Doc objects.
+            This is used to manage long documents, by cutting them into smaller
+            sequences before running the transformer. The spans are allowed to
+            overlap, and you can also omit sections of the Doc if they are not
+            relevant.
+        tokenizer_config (dict): Settings to pass to the transformers tokenizer.
+        transformer_config (dict): Settings to pass to the transformers forward pass.
+        """
+        hf_model = HFObjects(None, None, None, tokenizer_config, transformer_config)
+        wrapper = HFWrapper(
+            hf_model,
+            convert_inputs=_convert_transformer_inputs,
+            convert_outputs=_convert_transformer_outputs,
+            mixed_precision=mixed_precision,
+            grad_scaler_config=grad_scaler_config,
+        )
+        super().__init__(
+            "transformer",
+            forward,
+            init=init_custom,
+            layers=[wrapper],
+            dims={"nO": None},
+            attrs={
+                "get_spans": get_spans,
+                "name": name,
+                "set_transformer": set_pytorch_transformer,
+                "has_transformer": False,
+                "flush_cache_chance": 0.0,
+                "replace_listener": replace_listener,
+                "replace_listener_cfg": replace_listener_cfg,
+            },
+        )
+
+    @property
+    def tokenizer(self):
+        return self.layers[0].shims[0]._hfmodel.tokenizer
+
+    @property
+    def transformer(self):
+        return self.layers[0].shims[0]._hfmodel.transformer
+
+    @property
+    def _init_tokenizer_config(self):
+        return self.layers[0].shims[0]._hfmodel._init_tokenizer_config
+
+    @property
+    def _init_transformer_config(self):
+        return self.layers[0].shims[0]._hfmodel._init_transformer_config
+
+    def copy(self):
+        """
+        Create a copy of the model, its attributes, and its parameters. Any child
+        layers will also be deep-copied. The copy will receive a distinct `model.id`
+        value.
+        """
+        copied = TransformerModel(self.name, self.attrs["get_spans"])
+        params = {}
+        for name in self.param_names:
+            params[name] = self.get_param(name) if self.has_param(name) else None
+        copied.params = copy.deepcopy(params)
+        copied.dims = copy.deepcopy(self._dims)
+        copied.layers[0] = copy.deepcopy(self.layers[0])
+        for name in self.grad_names:
+            copied.set_grad(name, self.get_grad(name).copy())
+        return copied
 
 
 def init_custom(model: Model, X=None, Y=None):
     if model.attrs["has_transformer"]:
         return
     name = model.attrs["name"]
-    tok_cfg = model.attrs["tokenizer_config"]
-    tokenizer, transformer = huggingface_from_pretrained_custom(name, tok_cfg)
-    model.attrs["tokenizer"] = tokenizer
-    model.attrs["set_transformer"](model, transformer)
+    tok_cfg = model._init_tokenizer_config
+    trf_cfg = model._init_transformer_config
+    tokenizer, hf_model = huggingface_from_pretrained_custom(name, tok_cfg, trf_cfg, model.attrs["name"])
+    model.attrs["set_transformer"](model, hf_model)
     # Call the model with a batch of inputs to infer the width
     if X:
         # If we're dealing with actual texts, do the work to setup the wordpieces
```
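This first hunk replaces the old factory function with a `Model` subclass: the tokenizer and transformer now travel in an `HFObjects` bundle inside an `HFWrapper` shim (the single entry in `layers=[wrapper]`), and the new properties delegate to that shim instead of stashing objects in `model.attrs`. A minimal sketch of the new access path; the model name and the inline span getter are illustrative assumptions, not values from this commit:

```python
# Sketch only: construct the custom layer without loading any weights.
# "bert-base-cased" and the lambda span getter are placeholders.
model = TransformerModelCustom(
    "bert-base-cased",
    get_spans=lambda docs: [doc[:] for doc in docs],  # one whole-doc span per Doc
)

# Pre-1.1.x code read model.attrs["tokenizer"]; that attribute is gone.
# The tokenizer property now reaches through the HFWrapper shim (it stays
# None until init_custom loads the pretrained objects).
assert model.tokenizer is model.layers[0].shims[0]._hfmodel.tokenizer
```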
```diff
@@ -42,26 +119,66 @@ def init_custom(model: Model, X=None, Y=None):
         flat_spans = []
         for doc_spans in nested_spans:
             flat_spans.extend(doc_spans)
-        token_data = huggingface_tokenize(
-            model.attrs["tokenizer"],
-            [span.text for span in flat_spans]
-        )
+        token_data = huggingface_tokenize(tokenizer, [span.text for span in flat_spans])
         wordpieces = WordpieceBatch.from_batch_encoding(token_data)
         align = get_alignment(
-            flat_spans,
-            wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
+            flat_spans, wordpieces.strings, tokenizer.all_special_tokens
         )
         wordpieces, align = truncate_oversize_splits(
             wordpieces, align, tokenizer.model_max_length
         )
     else:
         texts = ["hello world", "foo bar"]
-        token_data = huggingface_tokenize(
-            model.attrs["tokenizer"],
-            texts
-        )
+        token_data = huggingface_tokenize(tokenizer, texts)
         wordpieces = WordpieceBatch.from_batch_encoding(token_data)
     model.layers[0].initialize(X=wordpieces)
-    tensors = model.layers[0].predict(wordpieces)
-    t_i = find_last_hidden(tensors)
-    model.set_dim("nO", tensors[t_i].shape[-1])
+    model_output = model.layers[0].predict(wordpieces)
+    model.set_dim("nO", model_output.last_hidden_state.shape[-1])
+
+
+def huggingface_from_pretrained_custom(
+    source: Union[Path, str], tok_config: Dict, trf_config: Dict, model_name: Optional[str] = None,
+) -> Tuple[PreTrainedTokenizerBase, HFObjects]:
+    """Create a Huggingface transformer model from pretrained weights. Will
+    download the model if it is not already downloaded.
+
+    source (Union[str, Path]): The name of the model or a path to it, such as
+        'bert-base-cased'.
+    tok_config (dict): Settings to pass to the tokenizer.
+    trf_config (dict): Settings to pass to the transformer.
+    """
+    if hasattr(source, "absolute"):
+        str_path = str(source.absolute())
+    else:
+        str_path = source
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
+    except ValueError as e:
+        if "tokenizer_class" not in tok_config:
+            raise e
+        tokenizer_class_name = tok_config["tokenizer_class"].split(".")
+        from importlib import import_module
+        tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
+        tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
+        tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **tok_config)
+    vocab_file_contents = None
+    if hasattr(tokenizer, "vocab_file"):
+        with open(tokenizer.vocab_file, "rb") as fileh:
+            vocab_file_contents = fileh.read()
+
+    try:
+        trf_config["return_dict"] = True
+        config = AutoConfig.from_pretrained(str_path, **trf_config)
+        transformer = AutoModel.from_pretrained(str_path, config=config)
+    except OSError as e:
+        try:
+            transformer = AutoModel.from_pretrained(model_name, local_files_only=True)
+        except OSError as e2:
+            print("trying to download model from huggingface hub:", model_name, "...", file=sys.stderr)
+            transformer = AutoModel.from_pretrained(model_name)
+            print("succeeded", file=sys.stderr)
+    ops = get_current_ops()
+    if isinstance(ops, CupyOps):
+        transformer.cuda()
+    return tokenizer, HFObjects(tokenizer, transformer, vocab_file_contents)
```
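The second hunk rewires `init_custom` to use the tokenizer returned by the loader and the `ModelOutput`-style prediction (`last_hidden_state` replaces the old `find_last_hidden` scan), and adds `huggingface_from_pretrained_custom` with a three-step fallback: the local path, then the local Hugging Face cache (`local_files_only=True`), then a fresh hub download by `model_name`. A hedged sketch of calling the loader directly; the directory and model names are illustrative, not from this commit:

```python
from pathlib import Path

tok_cfg = {"use_fast": True}  # forwarded to AutoTokenizer.from_pretrained
trf_cfg = {}                  # forwarded to AutoConfig.from_pretrained

# If the local directory lacks transformer weights, the OSError triggers
# the fallback to the cached or hub copy named by model_name.
tokenizer, hf_objects = huggingface_from_pretrained_custom(
    Path("models/exported"),       # illustrative local path
    tok_cfg,
    trf_cfg,
    model_name="bert-base-cased",  # illustrative hub fallback
)
print(type(hf_objects).__name__)  # HFObjects: tokenizer, transformer, vocab bytes
```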

ginza_transformers/pipeline_component.py

Lines changed: 19 additions & 9 deletions
```diff
@@ -1,16 +1,17 @@
-from typing import List, Callable, Iterable, Union
 from pathlib import Path
+import sys
+from typing import List, Callable, Iterable, Union
 
 from spacy.language import Language
 from spacy.pipeline.pipe import deserialize_config
 from spacy.tokens import Doc
-from spacy import util
+from spacy import util, Errors
 from thinc.api import Model, Config
 
 from spacy_transformers.data_classes import FullTransformerBatch
 from spacy_transformers.pipeline_component import Transformer, DOC_EXT_ATTR
 
-from ginza_transformers.util import huggingface_from_pretrained_custom
+from .layers.transformer_model import huggingface_from_pretrained_custom
 
 
 DEFAULT_CONFIG_STR = """
```
```diff
@@ -60,12 +61,21 @@ def from_disk(
 ) -> "TransformerCustom":
 
     def load_model(p):
-        p = Path(p).absolute()
-        tokenizer, transformer = huggingface_from_pretrained_custom(
-            p, self.model.attrs["tokenizer_config"], self.model.attrs["name"]
-        )
-        self.model.attrs["tokenizer"] = tokenizer
-        self.model.attrs["set_transformer"](self.model, transformer)
+        try:
+            with open(p, "rb") as mfile:
+                self.model.from_bytes(mfile.read())
+        except AttributeError:
+            raise ValueError(Errors.E149) from None
+        except (IsADirectoryError, PermissionError):
+            p = Path(p).absolute()
+            tokenizer, hf_model = huggingface_from_pretrained_custom(
+                p,
+                self.model._init_tokenizer_config,
+                self.model._init_transformer_config,
+                self.model.attrs["name"],
+            )
+            self.model.attrs["tokenizer"] = tokenizer
+            self.model.attrs["set_transformer"](self.model, hf_model)
 
     deserialize = {
         "vocab": self.vocab.from_disk,
```

ginza_transformers/util.py

Lines changed: 0 additions & 41 deletions
This file was deleted; its huggingface_from_pretrained_custom helper moved to ginza_transformers/layers/transformer_model.py, matching the import change above.

setup.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -14,11 +14,11 @@
         ],
     },
     install_requires=[
-        "spacy-transformers>=1.0.4,<1.1.0",
+        "spacy-transformers>=1.1.2,<1.2.0",
     ],
     license="MIT",
     name="ginza-transformers",
     packages=find_packages(include=["ginza_transformers", "ginza_transformers.layers"]),
     url="https://github.com/megagonlabs/ginza-transformers",
-    version='0.3.2',
+    version='0.4.0',
 )
```
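The dependency pin moves from the 1.0.x line to `spacy-transformers>=1.1.2,<1.2.0`, and the package version bumps to 0.4.0 to match. A quick post-upgrade sanity check (a sketch, not part of the commit):

```python
from importlib.metadata import version

# Expect a 1.1.x release and the new package version after upgrading.
print(version("spacy-transformers"))  # e.g. 1.1.2
print(version("ginza-transformers"))  # e.g. 0.4.0
```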
