-import copy
-import sys
-from typing import Callable, Dict, Optional, Tuple, Union
-from pathlib import Path
+from spacy_transformers.layers.transformer_model import TransformerModel

-from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedTokenizerBase

-from thinc.api import CupyOps, Model, get_current_ops
-
-from spacy_transformers.align import get_alignment
-from spacy_transformers.data_classes import WordpieceBatch, HFObjects
-from spacy_transformers.layers._util import replace_listener, replace_listener_cfg
-from spacy_transformers.layers.hf_wrapper import HFWrapper
-from spacy_transformers.layers.transformer_model import (
-    TransformerModel,
-    _convert_transformer_inputs,
-    _convert_transformer_outputs,
-    forward,
-    huggingface_tokenize,
-    set_pytorch_transformer,
-)
-from spacy_transformers.truncate import truncate_oversize_splits
-
-
-class TransformerModelCustom(Model):
-    def __init__(
-        self,
-        name: str,
-        get_spans: Callable,
-        tokenizer_config: dict = {},
-        transformer_config: dict = {},
-        mixed_precision: bool = False,
-        grad_scaler_config: dict = {},
-    ):
-        """
-        get_spans (Callable[[List[Doc]], List[Span]]):
-            A function to extract spans from the batch of Doc objects.
-            This is used to manage long documents, by cutting them into smaller
-            sequences before running the transformer. The spans are allowed to
-            overlap, and you can also omit sections of the Doc if they are not
-            relevant.
-        tokenizer_config (dict): Settings to pass to the transformers tokenizer.
-        transformer_config (dict): Settings to pass to the transformers forward pass.
-        """
-        hf_model = HFObjects(None, None, None, tokenizer_config, transformer_config)
-        wrapper = HFWrapper(
-            hf_model,
-            convert_inputs=_convert_transformer_inputs,
-            convert_outputs=_convert_transformer_outputs,
-            mixed_precision=mixed_precision,
-            grad_scaler_config=grad_scaler_config,
-        )
-        super().__init__(
-            "transformer",
-            forward,
-            init=init_custom,
-            layers=[wrapper],
-            dims={"nO": None},
-            attrs={
-                "get_spans": get_spans,
-                "name": name,
-                "set_transformer": set_pytorch_transformer,
-                "has_transformer": False,
-                "flush_cache_chance": 0.0,
-                "replace_listener": replace_listener,
-                "replace_listener_cfg": replace_listener_cfg,
-            },
-        )
-
-    @property
-    def tokenizer(self):
-        return self.layers[0].shims[0]._hfmodel.tokenizer
-
-    @property
-    def transformer(self):
-        return self.layers[0].shims[0]._hfmodel.transformer
-
-    @property
-    def _init_tokenizer_config(self):
-        return self.layers[0].shims[0]._hfmodel._init_tokenizer_config
-
-    @property
-    def _init_transformer_config(self):
-        return self.layers[0].shims[0]._hfmodel._init_transformer_config
-
-    def copy(self):
-        """
-        Create a copy of the model, its attributes, and its parameters. Any child
-        layers will also be deep-copied. The copy will receive a distinct `model.id`
-        value.
-        """
-        copied = TransformerModel(self.name, self.attrs["get_spans"])
-        params = {}
-        for name in self.param_names:
-            params[name] = self.get_param(name) if self.has_param(name) else None
-        copied.params = copy.deepcopy(params)
-        copied.dims = copy.deepcopy(self._dims)
-        copied.layers[0] = copy.deepcopy(self.layers[0])
-        for name in self.grad_names:
-            copied.set_grad(name, self.get_grad(name).copy())
-        return copied
-
-
-def init_custom(model: Model, X=None, Y=None):
-    if model.attrs["has_transformer"]:
-        return
-    name = model.attrs["name"]
-    tok_cfg = model._init_tokenizer_config
-    trf_cfg = model._init_transformer_config
-    tokenizer, hf_model = huggingface_from_pretrained_custom(name, tok_cfg, trf_cfg, model.attrs["name"])
-    model.attrs["set_transformer"](model, hf_model)
-    # Call the model with a batch of inputs to infer the width
-    if X:
-        # If we're dealing with actual texts, do the work to setup the wordpieces
-        # batch properly
-        docs = X
-        get_spans = model.attrs["get_spans"]
-        nested_spans = get_spans(docs)
-        flat_spans = []
-        for doc_spans in nested_spans:
-            flat_spans.extend(doc_spans)
-        token_data = huggingface_tokenize(tokenizer, [span.text for span in flat_spans])
-        wordpieces = WordpieceBatch.from_batch_encoding(token_data)
-        align = get_alignment(
-            flat_spans, wordpieces.strings, tokenizer.all_special_tokens
-        )
-        wordpieces, align = truncate_oversize_splits(
-            wordpieces, align, tokenizer.model_max_length
-        )
-    else:
-        texts = ["hello world", "foo bar"]
-        token_data = huggingface_tokenize(tokenizer, texts)
-        wordpieces = WordpieceBatch.from_batch_encoding(token_data)
-    model.layers[0].initialize(X=wordpieces)
-    model_output = model.layers[0].predict(wordpieces)
-    model.set_dim("nO", model_output.last_hidden_state.shape[-1])
-
-
-def huggingface_from_pretrained_custom(
-    source: Union[Path, str], tok_config: Dict, trf_config: Dict, model_name: Optional[str] = None,
-) -> Tuple[PreTrainedTokenizerBase, HFObjects]:
-    """Create a Huggingface transformer model from pretrained weights. Will
-    download the model if it is not already downloaded.
-
-    source (Union[str, Path]): The name of the model or a path to it, such as
-        'bert-base-cased'.
-    tok_config (dict): Settings to pass to the tokenizer.
-    trf_config (dict): Settings to pass to the transformer.
-    """
-    if hasattr(source, "absolute"):
-        str_path = str(source.absolute())
-    else:
-        str_path = source
-
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
-    except ValueError as e:
-        if "tokenizer_class" not in tok_config:
-            raise e
-        tokenizer_class_name = tok_config["tokenizer_class"].split(".")
-        from importlib import import_module
-        tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
-        tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
-        tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **tok_config)
-    vocab_file_contents = None
-    if hasattr(tokenizer, "vocab_file"):
-        with open(tokenizer.vocab_file, "rb") as fileh:
-            vocab_file_contents = fileh.read()
-
-    try:
-        trf_config["return_dict"] = True
-        config = AutoConfig.from_pretrained(str_path, **trf_config)
-        transformer = AutoModel.from_pretrained(str_path, config=config)
-    except OSError as e:
-        try:
-            transformer = AutoModel.from_pretrained(model_name, local_files_only=True)
-        except OSError as e2:
-            print("trying to download model from huggingface hub:", model_name, "...", file=sys.stderr)
-            transformer = AutoModel.from_pretrained(model_name)
-            print("succeded", file=sys.stderr)
-    ops = get_current_ops()
-    if isinstance(ops, CupyOps):
-        transformer.cuda()
-    return tokenizer, HFObjects(tokenizer, transformer, vocab_file_contents)
+class TransformerModelCustom(TransformerModel):
+    pass
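
With the custom copy gone, `TransformerModelCustom` inherits its constructor and initialization from the upstream `TransformerModel`. A minimal usage sketch, assuming the upstream class keeps the same constructor signature as the removed copy (`name`, `get_spans`, optional `tokenizer_config`/`transformer_config`); the `get_doc_spans` helper here is illustrative, not part of this diff:

```python
from typing import List

from spacy.tokens import Doc, Span
from spacy_transformers.layers.transformer_model import TransformerModel


class TransformerModelCustom(TransformerModel):
    pass


def get_doc_spans(docs: List[Doc]) -> List[List[Span]]:
    # Illustrative span getter: one span covering each whole Doc.
    return [[doc[:]] for doc in docs]


# Construction is cheap: as in the removed code, the Hugging Face tokenizer
# and weights are only loaded when the model is initialized.
model = TransformerModelCustom(
    "bert-base-cased",
    get_spans=get_doc_spans,
    tokenizer_config={"use_fast": True},
)
```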
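The `get_spans` callback described in the removed docstring still controls how long documents are cut into transformer-sized sequences. For documents longer than the model's window, a sentence-level getter is a common alternative; a sketch under the assumption that sentence boundaries are already set (e.g. by a sentencizer or parser):

```python
from typing import List

from spacy.tokens import Doc, Span


def get_sent_spans(docs: List[Doc]) -> List[List[Span]]:
    # One span per sentence keeps each sequence within the transformer's
    # maximum length; per the docstring, spans may also overlap or skip
    # irrelevant sections of the Doc.
    return [list(doc.sents) for doc in docs]
```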