Commit b297a9e

Merge pull request #7 from megagonlabs/feature/setup_model
add setup_model.py
2 parents 9ef017e + f38d44e commit b297a9e

2 files changed (+97, -17 lines)

ginza_transformers/layers/hf_shim_custom.py

Lines changed: 66 additions & 17 deletions
@@ -1,5 +1,4 @@
 import sys
-from typing import Any
 from io import BytesIO
 from pathlib import Path
 import srsly
@@ -15,6 +14,17 @@
 from transformers import AutoModel, AutoConfig, AutoTokenizer
 
 
+def override_hf_shims_to_bytes():
+    assert hf_shim.HFShim.to_bytes is not HFShimCustom.to_bytes
+    origin = hf_shim.HFShim.to_bytes
+    hf_shim.HFShim.to_bytes = HFShimCustom.to_bytes
+    return origin
+
+def recover_hf_shims_to_bytes(origin):
+    assert hf_shim.HFShim.to_bytes is HFShimCustom.to_bytes
+    hf_shim.HFShim.to_bytes = origin
+
+
 def override_hf_shims_from_bytes():
     assert hf_shim.HFShim.from_bytes is not HFShimCustom.from_bytes
     origin = hf_shim.HFShim.from_bytes
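The new to_bytes pair mirrors the existing from_bytes helpers: temporarily swap the custom method into spacy-transformers' HFShim, keep the original, and restore it when done. A minimal usage sketch of that pattern (the model name is illustrative, not part of this commit):

import spacy

from ginza_transformers.layers.hf_shim_custom import (
    override_hf_shims_from_bytes,
    recover_hf_shims_from_bytes,
)

# Patch HFShim.from_bytes, keeping the original so it can be restored.
origin = override_hf_shims_from_bytes()
try:
    nlp = spacy.load("ja_ginza_electra")  # illustrative model name
finally:
    # Always undo the monkey-patch, even if loading fails.
    recover_hf_shims_from_bytes(origin)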
@@ -28,6 +38,44 @@ def recover_hf_shims_from_bytes(origin):
 
 class HFShimCustom(HFShim):
 
+    def to_bytes(self):
+        config = {}
+        tok_dict = {}
+        # weights_bytes = {}
+        tok_cfg = {}
+        trf_cfg = {}
+        hf_model = self._hfmodel
+        if hf_model.transformer is not None:
+            tok_dict = {}
+            config = hf_model.transformer.config.to_dict()
+            tokenizer = hf_model.tokenizer
+            with make_tempdir() as temp_dir:
+                if hasattr(tokenizer, "vocab_file"):
+                    vocab_file_name = tokenizer.vocab_files_names["vocab_file"]
+                    vocab_file_path = str((temp_dir / vocab_file_name).absolute())
+                    with open(vocab_file_path, "wb") as fileh:
+                        fileh.write(hf_model.vocab_file_contents)
+                    tokenizer.vocab_file = vocab_file_path
+                tokenizer.save_pretrained(str(temp_dir.absolute()))
+                for x in temp_dir.glob("**/*"):
+                    if x.is_file():
+                        tok_dict[x.name] = x.read_bytes()
+            filelike = BytesIO()
+            torch.save(self._model.state_dict(), filelike)
+            filelike.seek(0)
+            # weights_bytes = filelike.getvalue()
+        else:
+            tok_cfg = hf_model._init_tokenizer_config
+            trf_cfg = hf_model._init_transformer_config
+        msg = {
+            "config": config,
+            # "state": weights_bytes,
+            "tokenizer": tok_dict,
+            "_init_tokenizer_config": tok_cfg,
+            "_init_transformer_config": trf_cfg,
+        }
+        return srsly.msgpack_dumps(msg)
+
     def from_bytes(self, bytes_data):
         msg = srsly.msgpack_loads(bytes_data)
         config_dict = msg["config"]
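Note what the new to_bytes deliberately omits: the weights_bytes lines and the "state" key are commented out, so the serialized payload carries the transformers config, the tokenizer files, and the two init configs, but no model weights. A sketch of the resulting msgpack message (shim stands for a hypothetical HFShimCustom instance):

import srsly

msg = srsly.msgpack_loads(shim.to_bytes())  # shim: hypothetical HFShimCustom instance
assert set(msg) == {"config", "tokenizer", "_init_tokenizer_config", "_init_transformer_config"}
assert "state" not in msg  # weights are not embedded; from_bytes must recover them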
@@ -62,34 +110,35 @@ def from_bytes(self, bytes_data):
                     with open(vocab_file_path, "rb") as fileh:
                         vocab_file_contents = fileh.read()
 
-            try:
+            ops = get_current_ops()
+            if ops.device_type == "cpu":
+                map_location = "cpu"
+            else:  # pragma: no cover
+                device_id = torch.cuda.current_device()
+                map_location = f"cuda:{device_id}"
+
+            if "state" in msg:
                 transformer = AutoModel.from_config(config)
-            except OSError as e:
+                filelike = BytesIO(msg["state"])
+                filelike.seek(0)
+                transformer.load_state_dict(torch.load(filelike, map_location=map_location))
+            else:
                 try:
-                    transformer = AutoModel.from_pretrained(config["_name_or_path"], local_files_only=True)
+                    transformer = AutoModel.from_pretrained(config._name_or_path, local_files_only=True)
                 except OSError as e2:
-                    print("trying to download model from huggingface hub:", config["_name_or_path"], "...", file=sys.stderr)
-                    transformer = AutoModel.from_pretrained(config["_name_or_path"])
+                    print("trying to download model from huggingface hub:", config._name_or_path, "...", file=sys.stderr)
+                    transformer = AutoModel.from_pretrained(config._name_or_path)
                     print("succeeded", file=sys.stderr)
 
+            transformer.to(map_location)
+            self._model = transformer
             self._hfmodel = HFObjects(
                 tokenizer,
                 transformer,
                 vocab_file_contents,
                 SimpleFrozenDict(),
                 SimpleFrozenDict(),
             )
-            self._model = transformer
-            filelike = BytesIO(msg["state"])
-            filelike.seek(0)
-            ops = get_current_ops()
-            if ops.device_type == "cpu":
-                map_location = "cpu"
-            else:  # pragma: no cover
-                device_id = torch.cuda.current_device()
-                map_location = f"cuda:{device_id}"
-            self._model.load_state_dict(torch.load(filelike, map_location=map_location))
-            self._model.to(map_location)
         else:
             self._hfmodel = HFObjects(
                 None,
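With the "state" key gone from the payload, from_bytes now branches: when embedded weights are present it loads them onto the current device, and otherwise it resolves the model through config._name_or_path, preferring the local huggingface cache and only then downloading from the hub. That fallback, reduced to a sketch (load_transformer is a hypothetical helper, not part of the diff):

from transformers import AutoModel

def load_transformer(name_or_path):
    try:
        # Resolve from the local huggingface cache only; raises OSError on a miss.
        return AutoModel.from_pretrained(name_or_path, local_files_only=True)
    except OSError:
        # Fall back to downloading from the huggingface hub.
        return AutoModel.from_pretrained(name_or_path)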

ginza_transformers/setup_model.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+import sys
+
+import spacy
+
+from .layers.hf_shim_custom import override_hf_shims_to_bytes, recover_hf_shims_to_bytes
+
+
+def main():
+    org_spacy_model_path = sys.argv[1]
+    dst_spacy_model_path = sys.argv[2]
+    transformers_model_name = sys.argv[3]
+
+    nlp = spacy.load(org_spacy_model_path)
+    transformer = nlp.get_pipe("transformer")
+    for i, node in enumerate(transformer.model.walk()):
+        if node.shims:
+            break
+    else:
+        assert False
+    node.shims[0]._hfmodel.transformer.config._name_or_path = transformers_model_name
+    node.shims[0]._hfmodel.tokenizer.save_pretrained(transformers_model_name)
+    node.shims[0]._hfmodel.transformer.save_pretrained(transformers_model_name)
+    origin = override_hf_shims_to_bytes()
+    try:
+        nlp.to_disk(dst_spacy_model_path)
+    finally:
+        recover_hf_shims_to_bytes(origin)
+
+
+if __name__ == "__main__":
+    main()
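Because of the relative import at the top, setup_model.py must be run as a module from the package root. It takes three positional arguments: the source spaCy model path, the destination path, and the transformers model name under which the tokenizer and weights are saved. An illustrative invocation (all three arguments are placeholders):

python -m ginza_transformers.setup_model /path/to/src_spacy_model /path/to/dst_spacy_model my-org/my-transformers-model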
