@@ -1,37 +1,114 @@
-from spacy_transformers.layers.transformer_model import *
-
-from ginza_transformers.util import huggingface_from_pretrained_custom
-
-
-def TransformerModelCustom(
-    name: str, get_spans: Callable, tokenizer_config: dict
-) -> Model[List[Doc], FullTransformerBatch]:
-    return Model(
-        "transformer",
-        forward,
-        init=init_custom,
-        layers=[],
-        dims={"nO": None},
-        attrs={
-            "tokenizer": None,
-            "get_spans": get_spans,
-            "name": name,
-            "tokenizer_config": tokenizer_config,
-            "set_transformer": set_pytorch_transformer,
-            "has_transformer": False,
-            "flush_cache_chance": 0.0,
-        },
-    )
+import copy
+import sys
+from typing import Callable, Dict, Optional, Tuple, Union
+from pathlib import Path
+
+from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedTokenizerBase
+
+from thinc.api import CupyOps, Model, get_current_ops
+
+from spacy_transformers.align import get_alignment
+from spacy_transformers.data_classes import WordpieceBatch, HFObjects
+from spacy_transformers.layers._util import replace_listener, replace_listener_cfg
+from spacy_transformers.layers.hf_wrapper import HFWrapper
+from spacy_transformers.layers.transformer_model import (
+    TransformerModel,
+    _convert_transformer_inputs,
+    _convert_transformer_outputs,
+    forward,
+    huggingface_tokenize,
+    set_pytorch_transformer,
+)
+from spacy_transformers.truncate import truncate_oversize_splits
+
+
+class TransformerModelCustom(Model):
+    def __init__(
+        self,
+        name: str,
+        get_spans: Callable,
+        tokenizer_config: dict = {},
+        transformer_config: dict = {},
+        mixed_precision: bool = False,
+        grad_scaler_config: dict = {},
+    ):
+        """
+        get_spans (Callable[[List[Doc]], List[Span]]):
+            A function to extract spans from the batch of Doc objects.
+            This is used to manage long documents, by cutting them into smaller
+            sequences before running the transformer. The spans are allowed to
+            overlap, and you can also omit sections of the Doc if they are not
+            relevant.
+        tokenizer_config (dict): Settings to pass to the transformers tokenizer.
+        transformer_config (dict): Settings to pass to the transformers forward pass.
+        """
+        hf_model = HFObjects(None, None, None, tokenizer_config, transformer_config)
+        wrapper = HFWrapper(
+            hf_model,
+            convert_inputs=_convert_transformer_inputs,
+            convert_outputs=_convert_transformer_outputs,
+            mixed_precision=mixed_precision,
+            grad_scaler_config=grad_scaler_config,
+        )
+        super().__init__(
+            "transformer",
+            forward,
+            init=init_custom,
+            layers=[wrapper],
+            dims={"nO": None},
+            attrs={
+                "get_spans": get_spans,
+                "name": name,
+                "set_transformer": set_pytorch_transformer,
+                "has_transformer": False,
+                "flush_cache_chance": 0.0,
+                "replace_listener": replace_listener,
+                "replace_listener_cfg": replace_listener_cfg,
+            },
+        )
+
+    @property
+    def tokenizer(self):
+        return self.layers[0].shims[0]._hfmodel.tokenizer
+
+    @property
+    def transformer(self):
+        return self.layers[0].shims[0]._hfmodel.transformer
+
+    @property
+    def _init_tokenizer_config(self):
+        return self.layers[0].shims[0]._hfmodel._init_tokenizer_config
+
+    @property
+    def _init_transformer_config(self):
+        return self.layers[0].shims[0]._hfmodel._init_transformer_config
+
+    def copy(self):
+        """
+        Create a copy of the model, its attributes, and its parameters. Any child
+        layers will also be deep-copied. The copy will receive a distinct `model.id`
+        value.
+        """
+        copied = TransformerModel(self.name, self.attrs["get_spans"])
+        params = {}
+        for name in self.param_names:
+            params[name] = self.get_param(name) if self.has_param(name) else None
+        copied.params = copy.deepcopy(params)
+        copied.dims = copy.deepcopy(self._dims)
+        copied.layers[0] = copy.deepcopy(self.layers[0])
+        for name in self.grad_names:
+            copied.set_grad(name, self.get_grad(name).copy())
+        return copied
 
 
 def init_custom(model: Model, X=None, Y=None):
     if model.attrs["has_transformer"]:
         return
     name = model.attrs["name"]
-    tok_cfg = model.attrs["tokenizer_config"]
-    tokenizer, transformer = huggingface_from_pretrained_custom(name, tok_cfg)
-    model.attrs["tokenizer"] = tokenizer
-    model.attrs["set_transformer"](model, transformer)
+    tok_cfg = model._init_tokenizer_config
+    trf_cfg = model._init_transformer_config
+    tokenizer, hf_model = huggingface_from_pretrained_custom(name, tok_cfg, trf_cfg, model.attrs["name"])
+    model.attrs["set_transformer"](model, hf_model)
     # Call the model with a batch of inputs to infer the width
     if X:
         # If we're dealing with actual texts, do the work to setup the wordpieces
@@ -42,26 +119,66 @@ def init_custom(model: Model, X=None, Y=None):
         flat_spans = []
         for doc_spans in nested_spans:
             flat_spans.extend(doc_spans)
-        token_data = huggingface_tokenize(
-            model.attrs["tokenizer"],
-            [span.text for span in flat_spans]
-        )
+        token_data = huggingface_tokenize(tokenizer, [span.text for span in flat_spans])
         wordpieces = WordpieceBatch.from_batch_encoding(token_data)
         align = get_alignment(
-            flat_spans,
-            wordpieces.strings, model.attrs["tokenizer"].all_special_tokens
+            flat_spans, wordpieces.strings, tokenizer.all_special_tokens
         )
         wordpieces, align = truncate_oversize_splits(
             wordpieces, align, tokenizer.model_max_length
         )
     else:
         texts = ["hello world", "foo bar"]
-        token_data = huggingface_tokenize(
-            model.attrs["tokenizer"],
-            texts
-        )
+        token_data = huggingface_tokenize(tokenizer, texts)
         wordpieces = WordpieceBatch.from_batch_encoding(token_data)
     model.layers[0].initialize(X=wordpieces)
-    tensors = model.layers[0].predict(wordpieces)
-    t_i = find_last_hidden(tensors)
-    model.set_dim("nO", tensors[t_i].shape[-1])
+    model_output = model.layers[0].predict(wordpieces)
+    model.set_dim("nO", model_output.last_hidden_state.shape[-1])
+
+
+def huggingface_from_pretrained_custom(
+    source: Union[Path, str], tok_config: Dict, trf_config: Dict, model_name: Optional[str] = None,
+) -> Tuple[PreTrainedTokenizerBase, HFObjects]:
+    """Create a Huggingface transformer model from pretrained weights. Will
+    download the model if it is not already downloaded.
+
+    source (Union[str, Path]): The name of the model or a path to it, such as
+        'bert-base-cased'.
+    tok_config (dict): Settings to pass to the tokenizer.
+    trf_config (dict): Settings to pass to the transformer.
+    """
+    if hasattr(source, "absolute"):
+        str_path = str(source.absolute())
+    else:
+        str_path = source
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
+    except ValueError as e:
+        if "tokenizer_class" not in tok_config:
+            raise e
+        tokenizer_class_name = tok_config["tokenizer_class"].split(".")
+        from importlib import import_module
+        tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
+        tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
+        tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **tok_config)
+    vocab_file_contents = None
+    if hasattr(tokenizer, "vocab_file"):
+        with open(tokenizer.vocab_file, "rb") as fileh:
+            vocab_file_contents = fileh.read()
+
+    try:
+        trf_config["return_dict"] = True
+        config = AutoConfig.from_pretrained(str_path, **trf_config)
+        transformer = AutoModel.from_pretrained(str_path, config=config)
+    except OSError as e:
+        try:
+            transformer = AutoModel.from_pretrained(model_name, local_files_only=True)
+        except OSError as e2:
+            print("trying to download model from huggingface hub:", model_name, "...", file=sys.stderr)
+            transformer = AutoModel.from_pretrained(model_name)
+            print("succeeded", file=sys.stderr)
+    ops = get_current_ops()
+    if isinstance(ops, CupyOps):
+        transformer.cuda()
+    return tokenizer, HFObjects(tokenizer, transformer, vocab_file_contents)
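
For context, a minimal usage sketch of the loader added in this commit. The path, configs, and fallback model id below are illustrative placeholders, not values from the commit: `source` is tried first as a local path or hub id, and on an OSError the loader falls back to `model_name`, trying the local cache before downloading from the Hugging Face hub.

    # Hypothetical call; huggingface_from_pretrained_custom is defined above.
    tokenizer, hf_objects = huggingface_from_pretrained_custom(
        source="/path/to/exported/model",  # tried first; may not exist locally
        tok_config={"use_fast": True},
        trf_config={},
        model_name="bert-base-cased",      # fallback hub id
    )
    print(type(tokenizer), type(hf_objects.transformer))

Returning the tokenizer separately from the HFObjects bundle lets init_custom tokenize its width-inference batch directly, without reaching back into the shim.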
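Similarly, a sketch of constructing the new layer class directly, assuming `get_doc_spans` from spacy_transformers.span_getters as the span getter; the model name is again a placeholder:

    from spacy_transformers.span_getters import get_doc_spans

    model = TransformerModelCustom(
        name="bert-base-cased",   # placeholder model name
        get_spans=get_doc_spans,  # one span covering each whole Doc
        tokenizer_config={"use_fast": True},
    )
    # initialize() runs init_custom: it loads the tokenizer and weights via
    # huggingface_from_pretrained_custom and infers "nO" from a dummy batch.
    model.initialize()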