Skip to content

Commit e4984a8

Browse files
add pipeline_component
1 parent ac6b5c6 commit e4984a8

File tree

3 files changed

+83
-2
lines changed

3 files changed

+83
-2
lines changed

config/ja_electra_base_parser_ner_accuracy.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ pooling = {"@layers":"reduce_mean.v1"}
6666
upstream = "*"
6767

6868
[components.transformer]
69-
factory = "transformer"
69+
factory = "transformer_custom"
7070
max_batch_items = 4096
7171
set_extra_annotations = {"@annotation_setters":"spacy-transformers.null_annotation_setter.v1"}
7272

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from typing import List, Callable, Iterable, Union
2+
from pathlib import Path
3+
4+
from spacy.language import Language
5+
from spacy.pipeline.pipe import deserialize_config
6+
from spacy.tokens import Doc
7+
from spacy import util
8+
from thinc.api import Model, Config
9+
10+
from spacy_transformers.data_classes import FullTransformerBatch
11+
from spacy_transformers.pipeline_component import Transformer, DOC_EXT_ATTR
12+
13+
from ginza_transformers.util import huggingface_from_pretrained_custom
14+
15+
16+
# Default component configuration for the "transformer_custom" factory.
# NOTE: this is runtime config text parsed by thinc's Config — the section
# names and keys must match what `make_transformer_custom` expects.
DEFAULT_CONFIG_STR = """
[transformer_custom]
max_batch_items = 4096

[transformer_custom.set_extra_annotations]
@annotation_setters = "spacy-transformers.null_annotation_setter.v1"

[transformer_custom.model]
@architectures = "ginza-transformers.TransformerModel.v1"
name = "electra-base-ud-japanese-discriminator"
tokenizer_config = {"use_fast": false, "tokenizer_class": "sudachitra.tokenization_electra_sudachipy.ElectraSudachipyTokenizer"}

[transformer_custom.model.get_spans]
@span_getters = "spacy-transformers.strided_spans.v1"
window = 128
stride = 96
"""

# Parsed form of the default config; sliced per-section when registering
# the factory below.
DEFAULT_CONFIG = Config().from_str(DEFAULT_CONFIG_STR)
35+
36+
37+
@Language.factory(
    "transformer_custom",
    assigns=[f"doc._.{DOC_EXT_ATTR}"],
    default_config=DEFAULT_CONFIG["transformer_custom"],
)
def make_transformer_custom(
    nlp: Language,
    name: str,
    model: Model[List[Doc], FullTransformerBatch],
    set_extra_annotations: Callable[[List[Doc], FullTransformerBatch], None],
    max_batch_items: int,
) -> "TransformerCustom":
    """Factory for the "transformer_custom" pipeline component.

    Builds a TransformerCustom — a Transformer subclass whose from_disk
    loads the Hugging Face model via ginza_transformers' custom loader.
    Parameter names mirror the [transformer_custom] config section keys.
    """
    component = TransformerCustom(
        nlp.vocab,
        model,
        set_extra_annotations,
        max_batch_items=max_batch_items,
        name=name,
    )
    return component
56+
57+
58+
class TransformerCustom(Transformer):
    """Transformer pipe variant that restores its HF model through
    ``huggingface_from_pretrained_custom`` instead of the stock loader."""

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
    ) -> "TransformerCustom":
        """Load the component from *path*, skipping keys listed in *exclude*.

        Returns self so calls can be chained, matching the spaCy
        serialization convention.
        """

        def restore_model(model_path):
            # Resolve to an absolute path before handing it to the HF loader.
            resolved = Path(model_path).absolute()
            tokenizer, transformer = huggingface_from_pretrained_custom(
                resolved, self.model.attrs["tokenizer_config"]
            )
            # Re-attach the freshly loaded tokenizer/transformer to the
            # thinc model wrapper.
            self.model.attrs["tokenizer"] = tokenizer
            self.model.attrs["set_transformer"](self.model, transformer)

        # Order matters: vocab and cfg are restored before the model weights.
        readers = {
            "vocab": self.vocab.from_disk,
            "cfg": lambda p: self.cfg.update(deserialize_config(p)),
            "model": restore_model,
        }
        util.from_disk(path, readers, exclude)
        return self

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
author_email="ginza@megagon.ai",
77
description="ginza-transformers",
88
entry_points={
9+
"spacy_factories": [
10+
"transformer_custom = ginza_transformers.pipeline_component:make_transformer_custom",
11+
],
912
"spacy_architectures": [
1013
"ginza-transformers.TransformerModel.v1 = ginza_transformers:architectures.TransformerModelCustom",
1114
],
@@ -17,5 +20,5 @@
1720
name="ginza-transformers",
1821
packages=find_packages(include=["ginza_transformers", "ginza_transformers.layers"]),
1922
url="https://github.com/megagonlabs/ginza-transformers",
20-
version='0.1.1',
23+
version='0.2.0',
2124
)

0 commit comments

Comments
 (0)