-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathquant_flair_model.py
46 lines (40 loc) · 1.52 KB
/
quant_flair_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import Union
from pathlib import Path
import torch
import flair
class QuantizableLanguageModel(flair.models.LanguageModel):
    """Flair ``LanguageModel`` whose forward pass tolerates a replaced RNN module.

    Compared with the straight-line forward pass, the visible difference is
    that ``self.rnn.flatten_parameters()`` is only called when the RNN module
    actually provides that method. NOTE(review): presumably this exists so
    ``self.rnn`` can be swapped for a (dynamically) quantized module that lacks
    ``flatten_parameters`` -- confirm against the code that quantizes it.
    """

    def forward(self, input, hidden, ordered_sequence_lengths=None):
        """Run encoder -> dropout -> RNN -> optional projection -> dropout -> decoder.

        Args:
            input: Tensor fed to ``self.encoder``. Assumed to be token indices
                of shape (seq_len, batch) -- TODO confirm.
            hidden: Initial hidden state, passed unchanged to ``self.rnn``.
            ordered_sequence_lengths: Unused; kept so the signature matches the
                parent ``LanguageModel.forward``.

        Returns:
            Tuple ``(decoded, output, hidden)``:
              * ``decoded`` -- decoder output reshaped to
                ``(output.size(0), output.size(1), decoder_out)``;
              * ``output`` -- RNN output after the optional projection and dropout;
              * ``hidden`` -- hidden state returned by ``self.rnn``.
        """
        emb = self.drop(self.encoder(input))

        # Not every RNN module exposes flatten_parameters(); skip it when absent.
        if hasattr(self.rnn, "flatten_parameters"):
            self.rnn.flatten_parameters()

        output, hidden = self.rnn(emb, hidden)
        if self.proj is not None:
            output = self.proj(output)
        output = self.drop(output)

        # The decoder is applied to a 2-D view (dim0 * dim1, features); its
        # result is then viewed back as (dim0, dim1, decoder_out).
        dim0, dim1, features = output.size(0), output.size(1), output.size(2)
        decoded = self.decoder(output.view(dim0 * dim1, features))
        return (
            decoded.view(dim0, dim1, decoded.size(1)),
            output,
            hidden,
        )

    @classmethod
    def load_language_model(cls, model_file: Union[Path, str]) -> "QuantizableLanguageModel":
        """Build a model from a checkpoint saved as a dict of constructor args + weights.

        The checkpoint must contain the keys ``dictionary``, ``is_forward_lm``,
        ``hidden_size``, ``nlayers``, ``embedding_size``, ``nout``, ``dropout``
        and ``state_dict``; ``document_delimiter`` is optional and defaults to
        ``"\\n"``. The returned model has its weights loaded, is switched to
        eval mode, and is moved to ``flair.device``.

        Security: the file is unpickled (``weights_only=False``), which can
        execute arbitrary code -- only load checkpoints from trusted sources.

        Args:
            model_file: Path to the checkpoint file.

        Returns:
            An instance of ``cls`` in eval mode on ``flair.device``.

        Raises:
            KeyError: If a required key is missing from the checkpoint.
        """
        import inspect

        load_kwargs = {"map_location": flair.device}
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle
        # non-tensor entries such as the ``dictionary`` object stored here, so
        # opt out explicitly. Older torch versions do not accept the flag at
        # all, so only pass it when the installed torch.load supports it.
        if "weights_only" in inspect.signature(torch.load).parameters:
            load_kwargs["weights_only"] = False
        state = torch.load(str(model_file), **load_kwargs)

        model = cls(
            dictionary=state["dictionary"],
            is_forward_lm=state["is_forward_lm"],
            hidden_size=state["hidden_size"],
            nlayers=state["nlayers"],
            embedding_size=state["embedding_size"],
            nout=state["nout"],
            document_delimiter=state.get("document_delimiter", "\n"),
            dropout=state["dropout"],
        )
        model.load_state_dict(state["state_dict"])
        model.eval()
        model.to(flair.device)
        return model