|
| 1 | +from typing import List, Callable, Iterable, Union |
| 2 | +from pathlib import Path |
| 3 | + |
| 4 | +from spacy.language import Language |
| 5 | +from spacy.pipeline.pipe import deserialize_config |
| 6 | +from spacy.tokens import Doc |
| 7 | +from spacy import util |
| 8 | +from thinc.api import Model, Config |
| 9 | + |
| 10 | +from spacy_transformers.data_classes import FullTransformerBatch |
| 11 | +from spacy_transformers.pipeline_component import Transformer, DOC_EXT_ATTR |
| 12 | + |
| 13 | +from ginza_transformers.util import huggingface_from_pretrained_custom |
| 14 | + |
| 15 | + |
# Default configuration for the "transformer_custom" pipeline factory.
# Mirrors spacy-transformers' stock transformer config, but swaps in the
# ginza-transformers model architecture so the HuggingFace weights can be
# loaded through the custom from_pretrained path (see TransformerCustom).
# NOTE: this is runtime data consumed by thinc's Config parser — do not
# edit the string contents casually; keys/values are part of behavior.
DEFAULT_CONFIG_STR = """
[transformer_custom]
max_batch_items = 4096

[transformer_custom.set_extra_annotations]
@annotation_setters = "spacy-transformers.null_annotation_setter.v1"

[transformer_custom.model]
@architectures = "ginza-transformers.TransformerModel.v1"
name = "electra-base-ud-japanese-discriminator"
tokenizer_config = {"use_fast": false, "tokenizer_class": "sudachitra.tokenization_electra_sudachipy.ElectraSudachipyTokenizer"}

[transformer_custom.model.get_spans]
@span_getters = "spacy-transformers.strided_spans.v1"
window = 128
stride = 96
"""

# Parsed thinc Config; the factory below reads its ["transformer_custom"]
# section as the component's default_config.
DEFAULT_CONFIG = Config().from_str(DEFAULT_CONFIG_STR)
| 35 | + |
| 36 | + |
@Language.factory(
    "transformer_custom",
    assigns=[f"doc._.{DOC_EXT_ATTR}"],
    default_config=DEFAULT_CONFIG["transformer_custom"],
)
def make_transformer_custom(
    nlp: Language,
    name: str,
    model: Model[List[Doc], FullTransformerBatch],
    set_extra_annotations: Callable[[List[Doc], FullTransformerBatch], None],
    max_batch_items: int,
):
    """Factory entry point for the ``transformer_custom`` component.

    Registered with spaCy so that ``nlp.add_pipe("transformer_custom")``
    builds a :class:`TransformerCustom` from the (possibly overridden)
    default config above.

    Args:
        nlp: The owning pipeline; only its vocab is passed through.
        name: Component instance name within the pipeline.
        model: Thinc model mapping a batch of Docs to transformer output.
        set_extra_annotations: Callback invoked with the docs and the
            transformer batch to attach extra annotations.
        max_batch_items: Upper bound on items per padded batch.

    Returns:
        A freshly constructed TransformerCustom component.
    """
    component = TransformerCustom(
        nlp.vocab,
        model,
        set_extra_annotations,
        max_batch_items=max_batch_items,
        name=name,
    )
    return component
| 56 | + |
| 57 | + |
class TransformerCustom(Transformer):
    """Transformer pipeline component whose on-disk model is restored via
    ginza_transformers' custom HuggingFace ``from_pretrained`` helper
    instead of spacy-transformers' stock loader."""

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
    ) -> "TransformerCustom":
        """Deserialize the component from *path*.

        Restores, unless excluded: the shared vocab, the component config,
        and the transformer weights/tokenizer (through
        ``huggingface_from_pretrained_custom``).

        Args:
            path: Directory the component was serialized into.
            exclude: Serialization field names to skip.

        Returns:
            Self, for chaining.
        """

        def _load_hf_model(model_path):
            # Resolve to an absolute path before handing it to the custom
            # HuggingFace loader, then wire tokenizer + transformer back
            # into the thinc model's attrs.
            resolved = Path(model_path).absolute()
            tokenizer, hf_transformer = huggingface_from_pretrained_custom(
                resolved, self.model.attrs["tokenizer_config"]
            )
            self.model.attrs["tokenizer"] = tokenizer
            self.model.attrs["set_transformer"](self.model, hf_transformer)

        def _load_cfg(cfg_path):
            self.cfg.update(deserialize_config(cfg_path))

        readers = {
            "vocab": self.vocab.from_disk,
            "cfg": _load_cfg,
            "model": _load_hf_model,
        }
        util.from_disk(path, readers, exclude)
        return self