
Commit d295416

change strategy: overriding HFShim.from_bytes
1 parent 601738c commit d295416

3 files changed: +111 -209 lines changed
Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import sys
+from typing import Any
+from io import BytesIO
+from pathlib import Path
+import srsly
+import torch
+from spacy.util import SimpleFrozenDict
+from spacy.vectors import get_current_ops
+
+from spacy_transformers.layers import hf_shim
+from spacy_transformers.layers.hf_shim import HFShim
+from spacy_transformers.data_classes import HFObjects
+from spacy_transformers.util import make_tempdir
+
+from transformers import AutoModel, AutoConfig, AutoTokenizer
+
+
+def override_hf_shims_from_bytes():
+    assert hf_shim.HFShim.from_bytes is not HFShimCustom.from_bytes
+    origin = hf_shim.HFShim.from_bytes
+    hf_shim.HFShim.from_bytes = HFShimCustom.from_bytes
+    return origin
+
+def recover_hf_shims_from_bytes(origin):
+    assert hf_shim.HFShim.from_bytes is HFShimCustom.from_bytes
+    hf_shim.HFShim.from_bytes = origin
+
+
+class HFShimCustom(HFShim):
+
+    def from_bytes(self, bytes_data):
+        msg = srsly.msgpack_loads(bytes_data)
+        config_dict = msg["config"]
+        tok_dict = msg["tokenizer"]
+        if config_dict:
+            with make_tempdir() as temp_dir:
+                config_file = temp_dir / "config.json"
+                srsly.write_json(config_file, config_dict)
+                config = AutoConfig.from_pretrained(config_file)
+                for x, x_bytes in tok_dict.items():
+                    Path(temp_dir / x).write_bytes(x_bytes)
+                tokenizer = None
+                try:
+                    tokenizer = AutoTokenizer.from_pretrained(str(temp_dir.absolute()))
+                except (ValueError, OSError):
+                    pass
+                if tokenizer is None:
+                    tok_config = srsly.read_json(str((temp_dir / "tokenizer_config.json").absolute()))
+                    tokenizer_class_name = tok_config["tokenizer_class"].split(".")
+                    if tokenizer_class_name == ["ElectraSudachipyTokenizer"]:
+                        from sudachitra.tokenization_electra_sudachipy import ElectraSudachipyTokenizer as tokenizer_class
+                        tokenizer = tokenizer_class(vocab_file=str((temp_dir / "vocab.txt").absolute()), **tok_config)
+                    else:
+                        from importlib import import_module
+                        tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
+                        tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
+
+                vocab_file_contents = None
+                if hasattr(tokenizer, "vocab_file"):
+                    vocab_file_name = tokenizer.vocab_files_names["vocab_file"]
+                    vocab_file_path = str((temp_dir / vocab_file_name).absolute())
+                    with open(vocab_file_path, "rb") as fileh:
+                        vocab_file_contents = fileh.read()
+
+            try:
+                transformer = AutoModel.from_config(config)
+            except OSError as e:
+                try:
+                    transformer = AutoModel.from_pretrained(config._name_or_path, local_files_only=True)
+                except OSError as e2:
+                    print("trying to download model from huggingface hub:", config._name_or_path, "...", file=sys.stderr)
+                    transformer = AutoModel.from_pretrained(config._name_or_path)
+                    print("succeeded", file=sys.stderr)
+
+            self._hfmodel = HFObjects(
+                tokenizer,
+                transformer,
+                vocab_file_contents,
+                SimpleFrozenDict(),
+                SimpleFrozenDict(),
+            )
+            self._model = transformer
+            filelike = BytesIO(msg["state"])
+            filelike.seek(0)
+            ops = get_current_ops()
+            if ops.device_type == "cpu":
+                map_location = "cpu"
+            else:  # pragma: no cover
+                device_id = torch.cuda.current_device()
+                map_location = f"cuda:{device_id}"
+            self._model.load_state_dict(torch.load(filelike, map_location=map_location))
+            self._model.to(map_location)
+        else:
+            self._hfmodel = HFObjects(
+                None,
+                None,
+                None,
+                msg["_init_tokenizer_config"],
+                msg["_init_transformer_config"],
+            )
+        return self
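
The two module-level helpers above are the whole strategy named in the commit message: temporarily swap HFShim.from_bytes for HFShimCustom.from_bytes, run the deserialization, then restore the original method. A minimal usage sketch, assuming the new file is importable as ginza_transformers.layers.hf_shim_custom and using a hypothetical pipeline directory:

import spacy

from ginza_transformers.layers.hf_shim_custom import (  # assumed module path
    override_hf_shims_from_bytes,
    recover_hf_shims_from_bytes,
)

# Patch, load, and always restore the original method, even on failure.
origin = override_hf_shims_from_bytes()
try:
    nlp = spacy.load("/path/to/serialized/pipeline")  # hypothetical path
finally:
    recover_hf_shims_from_bytes(origin)
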
Lines changed: 3 additions & 182 deletions
@@ -1,184 +1,5 @@
-import copy
-import sys
-from typing import Callable, Dict, Optional, Tuple, Union
-from pathlib import Path
+from spacy_transformers.layers.transformer_model import TransformerModel
 
-from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedTokenizerBase
 
-from thinc.api import CupyOps, Model, get_current_ops
-
-from spacy_transformers.align import get_alignment
-from spacy_transformers.data_classes import WordpieceBatch, HFObjects
-from spacy_transformers.layers._util import replace_listener, replace_listener_cfg
-from spacy_transformers.layers.hf_wrapper import HFWrapper
-from spacy_transformers.layers.transformer_model import (
-    TransformerModel,
-    _convert_transformer_inputs,
-    _convert_transformer_outputs,
-    forward,
-    huggingface_tokenize,
-    set_pytorch_transformer,
-)
-from spacy_transformers.truncate import truncate_oversize_splits
-
-
-class TransformerModelCustom(Model):
-    def __init__(
-        self,
-        name: str,
-        get_spans: Callable,
-        tokenizer_config: dict = {},
-        transformer_config: dict = {},
-        mixed_precision: bool = False,
-        grad_scaler_config: dict = {},
-    ):
-        """
-        get_spans (Callable[[List[Doc]], List[Span]]):
-            A function to extract spans from the batch of Doc objects.
-            This is used to manage long documents, by cutting them into smaller
-            sequences before running the transformer. The spans are allowed to
-            overlap, and you can also omit sections of the Doc if they are not
-            relevant.
-        tokenizer_config (dict): Settings to pass to the transformers tokenizer.
-        transformer_config (dict): Settings to pass to the transformers forward pass.
-        """
-        hf_model = HFObjects(None, None, None, tokenizer_config, transformer_config)
-        wrapper = HFWrapper(
-            hf_model,
-            convert_inputs=_convert_transformer_inputs,
-            convert_outputs=_convert_transformer_outputs,
-            mixed_precision=mixed_precision,
-            grad_scaler_config=grad_scaler_config,
-        )
-        super().__init__(
-            "transformer",
-            forward,
-            init=init_custom,
-            layers=[wrapper],
-            dims={"nO": None},
-            attrs={
-                "get_spans": get_spans,
-                "name": name,
-                "set_transformer": set_pytorch_transformer,
-                "has_transformer": False,
-                "flush_cache_chance": 0.0,
-                "replace_listener": replace_listener,
-                "replace_listener_cfg": replace_listener_cfg,
-            },
-        )
-
-    @property
-    def tokenizer(self):
-        return self.layers[0].shims[0]._hfmodel.tokenizer
-
-    @property
-    def transformer(self):
-        return self.layers[0].shims[0]._hfmodel.transformer
-
-    @property
-    def _init_tokenizer_config(self):
-        return self.layers[0].shims[0]._hfmodel._init_tokenizer_config
-
-    @property
-    def _init_transformer_config(self):
-        return self.layers[0].shims[0]._hfmodel._init_transformer_config
-
-    def copy(self):
-        """
-        Create a copy of the model, its attributes, and its parameters. Any child
-        layers will also be deep-copied. The copy will receive a distinct `model.id`
-        value.
-        """
-        copied = TransformerModel(self.name, self.attrs["get_spans"])
-        params = {}
-        for name in self.param_names:
-            params[name] = self.get_param(name) if self.has_param(name) else None
-        copied.params = copy.deepcopy(params)
-        copied.dims = copy.deepcopy(self._dims)
-        copied.layers[0] = copy.deepcopy(self.layers[0])
-        for name in self.grad_names:
-            copied.set_grad(name, self.get_grad(name).copy())
-        return copied
-
-
-def init_custom(model: Model, X=None, Y=None):
-    if model.attrs["has_transformer"]:
-        return
-    name = model.attrs["name"]
-    tok_cfg = model._init_tokenizer_config
-    trf_cfg = model._init_transformer_config
-    tokenizer, hf_model = huggingface_from_pretrained_custom(name, tok_cfg, trf_cfg, model.attrs["name"])
-    model.attrs["set_transformer"](model, hf_model)
-    # Call the model with a batch of inputs to infer the width
-    if X:
-        # If we're dealing with actual texts, do the work to setup the wordpieces
-        # batch properly
-        docs = X
-        get_spans = model.attrs["get_spans"]
-        nested_spans = get_spans(docs)
-        flat_spans = []
-        for doc_spans in nested_spans:
-            flat_spans.extend(doc_spans)
-        token_data = huggingface_tokenize(tokenizer, [span.text for span in flat_spans])
-        wordpieces = WordpieceBatch.from_batch_encoding(token_data)
-        align = get_alignment(
-            flat_spans, wordpieces.strings, tokenizer.all_special_tokens
-        )
-        wordpieces, align = truncate_oversize_splits(
-            wordpieces, align, tokenizer.model_max_length
-        )
-    else:
-        texts = ["hello world", "foo bar"]
-        token_data = huggingface_tokenize(tokenizer, texts)
-        wordpieces = WordpieceBatch.from_batch_encoding(token_data)
-    model.layers[0].initialize(X=wordpieces)
-    model_output = model.layers[0].predict(wordpieces)
-    model.set_dim("nO", model_output.last_hidden_state.shape[-1])
-
-
-def huggingface_from_pretrained_custom(
-    source: Union[Path, str], tok_config: Dict, trf_config: Dict, model_name: Optional[str] = None,
-) -> Tuple[PreTrainedTokenizerBase, HFObjects]:
-    """Create a Huggingface transformer model from pretrained weights. Will
-    download the model if it is not already downloaded.
-
-    source (Union[str, Path]): The name of the model or a path to it, such as
-        'bert-base-cased'.
-    tok_config (dict): Settings to pass to the tokenizer.
-    trf_config (dict): Settings to pass to the transformer.
-    """
-    if hasattr(source, "absolute"):
-        str_path = str(source.absolute())
-    else:
-        str_path = source
-
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
-    except ValueError as e:
-        if "tokenizer_class" not in tok_config:
-            raise e
-        tokenizer_class_name = tok_config["tokenizer_class"].split(".")
-        from importlib import import_module
-        tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
-        tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
-        tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **tok_config)
-    vocab_file_contents = None
-    if hasattr(tokenizer, "vocab_file"):
-        with open(tokenizer.vocab_file, "rb") as fileh:
-            vocab_file_contents = fileh.read()
-
-    try:
-        trf_config["return_dict"] = True
-        config = AutoConfig.from_pretrained(str_path, **trf_config)
-        transformer = AutoModel.from_pretrained(str_path, config=config)
-    except OSError as e:
-        try:
-            transformer = AutoModel.from_pretrained(model_name, local_files_only=True)
-        except OSError as e2:
-            print("trying to download model from huggingface hub:", model_name, "...", file=sys.stderr)
-            transformer = AutoModel.from_pretrained(model_name)
-            print("succeded", file=sys.stderr)
-    ops = get_current_ops()
-    if isinstance(ops, CupyOps):
-        transformer.cuda()
-    return tokenizer, HFObjects(tokenizer, transformer, vocab_file_contents)
+class TransformerModelCustom(TransformerModel):
+    pass
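
With the old copy of the layer code removed, TransformerModelCustom is now just a pass-through subclass and the custom deserialization lives entirely in HFShimCustom. A hedged sketch of how an architecture entry point such as "ginza-transformers.TransformerModel.v3" (referenced in the pipeline config below) could be registered; the function name, module path, and defaults here are assumptions, not the package's actual registration code:

from typing import Callable

from spacy.util import registry

from ginza_transformers.layers.transformer_model import TransformerModelCustom  # assumed module path


@registry.architectures("ginza-transformers.TransformerModel.v3")
def create_transformer_model_v3(
    name: str,
    get_spans: Callable,
    tokenizer_config: dict = {},
    transformer_config: dict = {},
    mixed_precision: bool = False,
    grad_scaler_config: dict = {},
):
    # TransformerModelCustom inherits its constructor unchanged from
    # spacy_transformers.layers.transformer_model.TransformerModel.
    return TransformerModelCustom(
        name,
        get_spans,
        tokenizer_config=tokenizer_config,
        transformer_config=transformer_config,
        mixed_precision=mixed_precision,
        grad_scaler_config=grad_scaler_config,
    )
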

ginza_transformers/pipeline_component.py

Lines changed: 7 additions & 27 deletions
@@ -3,15 +3,13 @@
 from typing import List, Callable, Iterable, Union
 
 from spacy.language import Language
-from spacy.pipeline.pipe import deserialize_config
 from spacy.tokens import Doc
-from spacy import util, Errors
 from thinc.api import Model, Config
 
 from spacy_transformers.data_classes import FullTransformerBatch
 from spacy_transformers.pipeline_component import Transformer, DOC_EXT_ATTR
 
-from .layers.transformer_model import huggingface_from_pretrained_custom
+from .layers.hf_shim_custom import override_hf_shims_from_bytes, recover_hf_shims_from_bytes
 
 
 DEFAULT_CONFIG_STR = """
@@ -22,7 +20,7 @@
 @annotation_setters = "spacy-transformers.null_annotation_setter.v1"
 
 [transformer_custom.model]
-@architectures = "ginza-transformers.TransformerModel.v1"
+@architectures = "ginza-transformers.TransformerModel.v3"
 
 [transformer_custom.model.get_spans]
 @span_getters = "spacy-transformers.strided_spans.v1"
@@ -60,27 +58,9 @@ def from_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
     ) -> "TransformerCustom":
 
-        def load_model(p):
-            try:
-                with open(p, "rb") as mfile:
-                    self.model.from_bytes(mfile.read())
-            except AttributeError:
-                raise ValueError(Errors.E149) from None
-            except (IsADirectoryError, PermissionError):
-                p = Path(p).absolute()
-                tokenizer, hf_model = huggingface_from_pretrained_custom(
-                    p,
-                    self.model._init_tokenizer_config,
-                    self.model._init_transformer_config,
-                    self.model.attrs["name"],
-                )
-                self.model.attrs["tokenizer"] = tokenizer
-                self.model.attrs["set_transformer"](self.model, hf_model)
-
-        deserialize = {
-            "vocab": self.vocab.from_disk,
-            "cfg": lambda p: self.cfg.update(deserialize_config(p)),
-            "model": load_model,
-        }
-        util.from_disk(path, deserialize, exclude)
+        origin = override_hf_shims_from_bytes()
+        try:
+            super().from_disk(path, exclude=exclude)
+        finally:
+            recover_hf_shims_from_bytes(origin)
         return self
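
Because the patch is applied and reverted inside a try/finally, HFShim.from_bytes is restored even when deserialization fails, and callers never touch the monkey-patch themselves. A hypothetical end-to-end check (the pipeline package name is assumed, following the GiNZA ELECTRA model that uses this component):

import spacy

from spacy_transformers.layers.hf_shim import HFShim

original_from_bytes = HFShim.from_bytes
nlp = spacy.load("ja_ginza_electra")  # assumed: a pipeline built on transformer_custom
doc = nlp("銀座でランチをご一緒しましょう。")
print([(token.text, token.pos_) for token in doc])
assert HFShim.from_bytes is original_from_bytes  # the override did not leak
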
