-from spacy_transformers.layers.transformer_model import TransformerModel
+from pathlib import Path
+import sys
+from typing import Union, Dict
 
+from transformers import AutoConfig, AutoModel, AutoTokenizer
 
-class TransformerModelCustom(TransformerModel):
-    pass
+from spacy_transformers.layers import transformer_model
+from spacy_transformers.data_classes import HFObjects
+
+from thinc.api import get_current_ops, CupyOps
+
+
+def override_huggingface_from_pretrained():
+    # Monkey-patch spacy-transformers: swap its loader for the custom one
+    # below and hand back the original so it can be restored later.
+    assert transformer_model.huggingface_from_pretrained is not huggingface_from_pretrained_custom
+    origin = transformer_model.huggingface_from_pretrained
+    transformer_model.huggingface_from_pretrained = huggingface_from_pretrained_custom
+    return origin
+
+
+def recover_huggingface_from_pretrained(origin):
+    # Undo the patch by restoring the original loader.
+    assert transformer_model.huggingface_from_pretrained is huggingface_from_pretrained_custom
+    transformer_model.huggingface_from_pretrained = origin
+
+
+def huggingface_from_pretrained_custom(
+    source: Union[Path, str], tok_config: Dict, trf_config: Dict
+) -> HFObjects:
+    """Create a Huggingface transformer model from pretrained weights. Will
+    download the model if it is not already downloaded.
+
+    source (Union[str, Path]): The name of the model or a path to it, such as
+        'bert-base-cased'.
+    tok_config (dict): Settings to pass to the tokenizer.
+    trf_config (dict): Settings to pass to the transformer.
+    """
+    # Accept either a Path or a model-name string.
+    if hasattr(source, "absolute"):
+        str_path = str(source.absolute())
+    else:
+        str_path = source
+
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(str_path, **tok_config)
+    except ValueError:
+        # AutoTokenizer could not resolve the class; fall back to importing
+        # the tokenizer class named in the config and loading it directly.
+        if "tokenizer_class" not in tok_config:
+            raise
+        tokenizer_class_name = tok_config["tokenizer_class"].split(".")
+        from importlib import import_module
+
+        tokenizer_module = import_module(".".join(tokenizer_class_name[:-1]))
+        tokenizer_class = getattr(tokenizer_module, tokenizer_class_name[-1])
+        tokenizer = tokenizer_class(vocab_file=str_path + "/vocab.txt", **tok_config)
+    vocab_file_contents = None
+    if hasattr(tokenizer, "vocab_file"):
+        with open(tokenizer.vocab_file, "rb") as fileh:
+            vocab_file_contents = fileh.read()
+
+    try:
+        trf_config["return_dict"] = True
+        config = AutoConfig.from_pretrained(str_path, **trf_config)
+        transformer = AutoModel.from_pretrained(str_path, config=config)
+    except OSError:
+        # Retry with local files only; if that also fails, fall back to
+        # downloading the model from the Hugging Face Hub by name.
+        try:
+            transformer = AutoModel.from_pretrained(str_path, local_files_only=True)
+        except OSError:
+            model_name = str(source)
+            print("trying to download model from huggingface hub:", model_name, "...", file=sys.stderr)
+            transformer = AutoModel.from_pretrained(model_name)
+            print("succeeded", file=sys.stderr)
+    ops = get_current_ops()
+    if isinstance(ops, CupyOps):
+        transformer.cuda()
+    return HFObjects(tokenizer, transformer, vocab_file_contents)
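
A minimal usage sketch (not part of the commit): the override has to be active while spacy-transformers builds the pipeline, and the original loader should be restored afterwards. This assumes loading a transformer pipeline routes through transformer_model.huggingface_from_pretrained; the pipeline name below is only an illustrative placeholder.

    import spacy

    origin = override_huggingface_from_pretrained()
    try:
        # While patched, model loading goes through huggingface_from_pretrained_custom.
        nlp = spacy.load("en_core_web_trf")
    finally:
        recover_huggingface_from_pretrained(origin)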