Skip to content

Commit 6056e02

Browse files
committed
add plots
1 parent 1c432c9 commit 6056e02

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

ngram/plots.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import numpy as np
2+
3+
import matplotlib.pyplot as plt
4+
import seaborn as sns
5+
6+
# https://seaborn.pydata.org/generated/seaborn.set_context.html
7+
# https://seaborn.pydata.org/generated/seaborn.set_style.html
8+
sns.set_style("white")
9+
sns.set_context("notebook", font_scale=1)
10+
11+
from ngram import MLENGram, AdditiveNGram, GoodTuringNGram
12+
13+
14+
def plot_count_models(GT, N):
15+
NC = GT._num_grams_with_count
16+
mod = GT._count_models[N]
17+
max_n = max(GT.counts[N].values())
18+
emp = [NC(n + 1, N) for n in range(max_n)]
19+
prd = [np.exp(mod.predict(np.array([n + 1]))) for n in range(max_n + 10)]
20+
plt.scatter(range(max_n), emp, c="r", label="actual")
21+
plt.plot(range(max_n + 10), prd, "-", label="model")
22+
plt.ylim([-1, 100])
23+
plt.xlabel("Count ($r$)")
24+
plt.ylabel("Count-of-counts ($N_r$)")
25+
plt.legend()
26+
plt.savefig("test.png")
27+
plt.close()
28+
29+
30+
def compare_probs(fp, N):
31+
MLE = MLENGram(N, unk=False, filter_punctuation=False, filter_stopwords=False)
32+
MLE.train(fp, encoding="utf-8-sig")
33+
34+
add_y, mle_y, gtt_y = [], [], []
35+
addu_y, mleu_y, gttu_y = [], [], []
36+
seen = ("<bol>", "the")
37+
unseen = ("<bol>", "asdf")
38+
39+
GTT = GoodTuringNGram(
40+
N, conf=1.96, unk=False, filter_stopwords=False, filter_punctuation=False
41+
)
42+
GTT.train(fp, encoding="utf-8-sig")
43+
44+
gtt_prob = GTT.log_prob(seen, N)
45+
gtt_prob_u = GTT.log_prob(unseen, N)
46+
47+
for K in np.linspace(0, 10, 20):
48+
ADD = AdditiveNGram(
49+
N, K, unk=False, filter_punctuation=False, filter_stopwords=False
50+
)
51+
ADD.train(fp, encoding="utf-8-sig")
52+
53+
add_prob = ADD.log_prob(seen, N)
54+
mle_prob = MLE.log_prob(seen, N)
55+
56+
add_y.append(add_prob)
57+
mle_y.append(mle_prob)
58+
gtt_y.append(gtt_prob)
59+
60+
mle_prob_u = MLE.log_prob(unseen, N)
61+
add_prob_u = ADD.log_prob(unseen, N)
62+
63+
addu_y.append(add_prob_u)
64+
mleu_y.append(mle_prob_u)
65+
gttu_y.append(gtt_prob_u)
66+
67+
plt.plot(np.linspace(0, 10, 20), add_y, label="Additive (seen ngram)")
68+
plt.plot(np.linspace(0, 10, 20), addu_y, label="Additive (unseen ngram)")
69+
# plt.plot(np.linspace(0, 10, 20), gtt_y, label="Good-Turing (seen ngram)")
70+
# plt.plot(np.linspace(0, 10, 20), gttu_y, label="Good-Turing (unseen ngram)")
71+
plt.plot(np.linspace(0, 10, 20), mle_y, "--", label="MLE (seen ngram)")
72+
plt.xlabel("K")
73+
plt.ylabel("log P(sequence)")
74+
plt.legend()
75+
plt.savefig("img/add_smooth.png")
76+
plt.close("all")
77+
78+
79+
def plot_gt_freqs(fp):
80+
"""
81+
Draws a scatterplot of the empirical frequencies of the counted species
82+
versus their Simple Good Turing smoothed values, in rank order. Depends on
83+
pylab and matplotlib.
84+
"""
85+
MLE = MLENGram(1, filter_punctuation=False, filter_stopwords=False)
86+
MLE.train(fp, encoding="utf-8-sig")
87+
counts = dict(MLE.counts[1])
88+
89+
GT = GoodTuringNGram(1, filter_stopwords=False, filter_punctuation=False)
90+
GT.train(fp, encoding="utf-8-sig")
91+
92+
ADD = AdditiveNGram(1, 1, filter_punctuation=False, filter_stopwords=False)
93+
ADD.train(fp, encoding="utf-8-sig")
94+
95+
tot = float(sum(counts.values()))
96+
freqs = dict([(token, cnt / tot) for token, cnt in counts.items()])
97+
sgt_probs = dict([(tok, np.exp(GT.log_prob(tok, 1))) for tok in counts.keys()])
98+
as_probs = dict([(tok, np.exp(ADD.log_prob(tok, 1))) for tok in counts.keys()])
99+
100+
X, Y = np.arange(len(freqs)), sorted(freqs.values(), reverse=True)
101+
plt.loglog(X, Y, "k+", alpha=0.25, label="MLE")
102+
103+
X, Y = np.arange(len(sgt_probs)), sorted(sgt_probs.values(), reverse=True)
104+
plt.loglog(X, Y, "r+", alpha=0.25, label="simple Good-Turing")
105+
106+
X, Y = np.arange(len(as_probs)), sorted(as_probs.values(), reverse=True)
107+
plt.loglog(X, Y, "b+", alpha=0.25, label="Laplace smoothing")
108+
109+
plt.xlabel("Rank")
110+
plt.ylabel("Probability")
111+
plt.legend()
112+
plt.tight_layout()
113+
plt.savefig("img/rank_probs.png")
114+
plt.close("all")

0 commit comments

Comments
 (0)