forked from flairNLP/flair
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_visual.py
77 lines (61 loc) · 2.45 KB
/
test_visual.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from flair.data import Sentence, Span, Token
from flair.embeddings import FlairEmbeddings
from flair.visual import Highlighter
from flair.visual.ner_html import HTML_PAGE, PARAGRAPH, TAGGED_ENTITY, render_ner_html
from flair.visual.training_curves import Plotter
def test_highlighter(resources_path):
    """Run the highlighter on the first non-empty line of the snippet file.

    Reads ``visual/snippet.txt`` below ``resources_path``, feeds the first
    non-empty line to the ``news-forward`` language model and hands the
    squeezed representation to ``Highlighter.highlight_selection``, which is
    told to write ``visual/highligh.html``. The generated file is removed
    afterwards, whether or not the rendering succeeded.

    Args:
        resources_path: Base directory of the test resources; it is joined
            with ``/`` and opened like a ``pathlib.Path`` (presumably a
            pytest fixture -- confirm in conftest).
    """
    with (resources_path / "visual/snippet.txt").open() as f:
        sentences = [x for x in f.read().split("\n") if x]

    embeddings = FlairEmbeddings("news-forward")
    features = embeddings.lm.get_representation(sentences[0], "", "").squeeze()

    output_path = resources_path / "visual/highligh.html"
    try:
        Highlighter().highlight_selection(
            features,
            sentences[0],
            n=1000,
            file_=str(output_path),
        )
        # The highlighter must actually have produced the output file.
        assert output_path.exists()
    finally:
        # clean up directory, even when rendering or the check above fails,
        # so a broken run does not leave artifacts in the resources folder
        if output_path.exists():
            output_path.unlink()
def test_plotting_training_curves_and_weights(resources_path):
    """Plot training curves and weights from the resource files.

    Calls ``Plotter.plot_training_curves`` on ``visual/loss.tsv`` and
    ``Plotter.plot_weights`` on ``visual/weights.txt``, checks that
    ``visual/training.png`` and ``visual/weights.png`` exist afterwards, and
    removes both images again regardless of the outcome.

    Args:
        resources_path: Base directory of the test resources; joined with
            ``/`` like a ``pathlib.Path`` (presumably a pytest fixture --
            confirm in conftest).
    """
    plotter = Plotter()
    weights_png = resources_path / "visual/weights.png"
    training_png = resources_path / "visual/training.png"

    try:
        plotter.plot_training_curves(resources_path / "visual/loss.tsv")
        plotter.plot_weights(resources_path / "visual/weights.txt")

        # Both plots are expected next to their input files.
        assert weights_png.exists()
        assert training_png.exists()
    finally:
        # clean up directory; each image is removed independently so one
        # missing file cannot prevent the other from being deleted
        for artifact in (weights_png, training_png):
            if artifact.exists():
                artifact.unlink()
def mock_ner_span(text, tag, start, end):
    """Build a stand-in span for the slice ``text[start:end]``.

    An empty ``Span`` is created and labelled with ``tag`` under the
    ``"class"`` label type; the object returned by ``set_label`` then gets
    its ``start_pos``, ``end_pos`` and ``tokens`` attributes overwritten.

    Args:
        text: Text the span refers to.
        tag: Label value stored under the ``"class"`` label type.
        start: Start offset into ``text``, stored as ``start_pos``.
        end: End offset into ``text``, stored as ``end_pos``.

    Returns:
        The labelled object carrying a single token built from
        ``text[start:end]``.
    """
    empty_span = Span([])
    mocked = empty_span.set_label("class", tag)

    mocked.start_pos = start
    mocked.end_pos = end

    # One token holding the covered surface text.
    surface_token = Token(text[start:end])
    mocked.tokens = [surface_token]
    return mocked
def test_html_rendering():
    """Check ``render_ner_html`` output against a hand-assembled page.

    Three token ranges of one sentence are labelled as ``ner`` spans (PER,
    MISC and LOC), the sentence is rendered with an explicit colour mapping,
    and the result is compared with a page assembled from the ``HTML_PAGE``,
    ``PARAGRAPH`` and ``TAGGED_ENTITY`` templates.
    """
    raw_text = (
        "Boris Johnson has been elected new Conservative leader in "
        "a ballot of party members and will become the "
        "next UK prime minister. &"
    )
    sentence = Sentence(raw_text)

    # Token ranges to label, in the order they appear in the sentence.
    labelled_ranges = (
        (slice(0, 2), "PER"),
        (slice(6, 7), "MISC"),
        (slice(19, 20), "LOC"),
    )
    for token_range, ner_tag in labelled_ranges:
        print(sentence[token_range].add_label("ner", ner_tag))

    palette = {
        "PER": "#F7FF53",
        "ORG": "#E8902E",
        "LOC": "yellow",
        "MISC": "#4647EB",
        "O": "#ddd",
    }
    rendered = render_ner_html([sentence], colors=palette)

    # Alternate tagged entities with the plain text between them.
    fragments = [
        TAGGED_ENTITY.format(color="#F7FF53", entity="Boris Johnson", label="PER"),
        " has been elected new ",
        TAGGED_ENTITY.format(color="#4647EB", entity="Conservative", label="MISC"),
        " leader in a ballot of party members and will become the next ",
        TAGGED_ENTITY.format(color="yellow", entity="UK", label="LOC"),
        " prime minister. &",
    ]
    paragraph = PARAGRAPH.format(sentence="".join(fragments))
    expected_page = HTML_PAGE.format(text=paragraph, title="Flair")

    assert expected_page == rendered