
Commit e6cf942

add ngram models
1 parent 6cea0eb commit e6cf942

4 files changed: +581 -0 lines changed


ngram/README.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# N-Gram Sequence Models

The `ngram.py` module implements [n-gram models](https://en.wikipedia.org/wiki/N-gram) with different smoothing techniques:

- Maximum likelihood (no smoothing)
- [Additive smoothing](https://en.wikipedia.org/wiki/Additive_smoothing) (incl. Laplace smoothing, expected likelihood estimation, etc.)

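A minimal usage sketch of the module described in the README above. The class and method names (`MLENGram`, `AdditiveNGram`, `train`, `log_prob`, `completions`) come from `ngram/ngram.py` in this commit; the corpus path, encoding, and example tokens are hypothetical placeholders, and running it assumes the repo's sibling `preprocessing` package is importable (e.g. from the repository root).

```python
# Hedged usage sketch: "corpus.txt" and the example tokens are placeholders.
from ngram import MLENGram, AdditiveNGram

# unsmoothed bigram model
mle = MLENGram(N=2, filter_stopwords=False, filter_punctuation=True)
mle.train("corpus.txt", encoding="utf-8")

# Laplace-smoothed (K=1) bigram model on the same corpus
add1 = AdditiveNGram(N=2, K=1, filter_stopwords=False, filter_punctuation=True)
add1.train("corpus.txt", encoding="utf-8")

# log P(sequence) under each model
seq = ["the", "cat", "sat"]
print(mle.log_prob(seq, N=2))   # -inf if any bigram was never observed
print(add1.log_prob(seq, N=2))  # finite, thanks to the +K pseudocounts

# distribution over proposed next words after a one-word context,
# sorted from most to least probable
print(sorted(add1.completions(["the"], N=2), key=lambda t: -t[1])[:5])
```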
ngram/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .ngram import *

ngram/ngram.py

Lines changed: 340 additions & 0 deletions
@@ -0,0 +1,340 @@
import sys
import textwrap
from abc import ABC, abstractmethod
from collections import Counter

sys.path.append("..")

import numpy as np

from preprocessing.nlp import tokenize_words, ngrams


class NGramBase(ABC):
    def __init__(self, N, unk=True, filter_stopwords=True, filter_punctuation=True):
        """
        A simple N-gram language model.

        NB. This is not optimized code and will be slow for large corpora. To
        see how industry-scale N-gram models are handled, see the SRILM toolkit:

            http://www.speech.sri.com/projects/srilm/
        """
        self.N = N
        self.unk = unk
        self.filter_stopwords = filter_stopwords
        self.filter_punctuation = filter_punctuation

        self.hyperparameters = {
            "N": N,
            "unk": unk,
            "filter_stopwords": filter_stopwords,
            "filter_punctuation": filter_punctuation,
        }

        super().__init__()

    def train(self, corpus_fp, vocab=None, encoding=None):
        """
        Compile the n-gram counts for the text(s) in `corpus_fp`. Upon
        completion, the `self.counts` attribute will store dictionaries of the
        N, N-1, ..., 1-gram counts.

        Parameters
        ----------
        corpus_fp : str
            The path to a newline-separated text corpus file
        vocab : `preprocessing.nlp.Vocabulary` instance (default: None)
            If not `None`, only the words in `vocab` will be used to construct
            the language model
        encoding : str (default: None)
            Specifies the text encoding for the corpus. Common entries are
            'utf-8', 'utf-8-sig', 'utf-16'.
        """
        H = self.hyperparameters
        grams = {N: [] for N in range(1, self.N + 1)}
        counts = {N: Counter() for N in range(1, self.N + 1)}
        filter_punc, filter_stop = H["filter_punctuation"], H["filter_stopwords"]

        _n_words = 0
        tokens = set(["<unk>"])
        bol, eol = ["<bol>"], ["<eol>"]

        with open(corpus_fp, "r", encoding=encoding) as text:
            for line in text:
                words = tokenize_words(line, filter_punc, filter_stop)

                if vocab is not None:
                    words = vocab.filter(words, H["unk"])

                if len(words) == 0:
                    continue

                _n_words += len(words)
                tokens.update(words)

                # calculate n, n-1, ... 1-grams
                for N in range(1, self.N + 1):
                    words_padded = bol * max(1, N - 1) + words + eol * max(1, N - 1)
                    grams[N].extend(ngrams(words_padded, N))

        for N in counts.keys():
            counts[N].update(grams[N])

        n_words = {N: np.sum(list(counts[N].values())) for N in range(1, self.N + 1)}
        n_words[1] = _n_words

        n_tokens = {N: len(counts[N]) for N in range(2, self.N + 1)}
        n_tokens[1] = len(vocab) if vocab is not None else len(tokens)

        self.counts = counts
        self.n_words = n_words
        self.n_tokens = n_tokens

    def completions(self, words, N):
        """
        Return the distribution over proposed next words under the `N`-gram
        language model.

        Parameters
        ----------
        words : list or tuple of strings
            The initial sequence of words
        N : int
            The gram-size of the language model to use to generate completions

        Returns
        -------
        probs : list of (word, log_prob) tuples
            The list of possible next words and their log probabilities under
            the `N`-gram language model (unsorted)
        """
        N = min(N, len(words) + 1)
        assert N in self.counts, "You do not have counts for {}-grams".format(N)
        assert len(words) >= N - 1, "`words` must have at least {} words".format(N - 1)

        probs = []
        base = tuple(w.lower() for w in words[-N + 1 :])
        for k in self.counts[N].keys():
            if k[:-1] == base:
                c_prob = self._log_ngram_prob(base + k[-1:])
                probs.append((k[-1], c_prob))
        return probs

    def generate(self, N, seed_words=["<bol>"], n_sentences=5):
        """
        Use the `N`-gram language model to generate sentences.

        Parameters
        ----------
        N : int
            The gram-size of the model to generate from
        seed_words : list of strs (default: ["<bol>"])
            A list of seed words to use to condition the initial sentence
            generation
        n_sentences : int (default: 5)
            The number of sentences to generate from the `N`-gram model

        Returns
        -------
        sentences : list of lists of strs
            The generated sentences as lists of tokens. Each completed
            sentence is also pretty-printed to stdout.
        """
        counter = 0
        sentences = []
        words = seed_words.copy()
        while counter < n_sentences:
            nextw, probs = zip(*self.completions(words, N))

            # renormalize so the (possibly smoothed) probabilities sum to 1
            probs = np.exp(probs)
            next_word = np.random.choice(nextw, p=probs / probs.sum())

            # if we reach the end of a sentence, save it and start a new one
            if next_word == "<eol>":
                S = " ".join([w for w in words if w != "<bol>"])
                S = textwrap.fill(S, 90, initial_indent="", subsequent_indent=" ")
                print(S)
                sentences.append(words)
                words = seed_words.copy()
                counter += 1
                continue

            words.append(next_word)
        return sentences

    def _log_prob(self, words, N):
        """Calculate the log probability of a sequence of words under the `N`-gram model"""
        assert N in self.counts, "You do not have counts for {}-grams".format(N)

        if N > len(words):
            err = "Not enough words for a gram-size of {}: {}".format(N, len(words))
            raise ValueError(err)

        total_prob = 0
        for ngram in ngrams(words, N):
            total_prob += self._log_ngram_prob(ngram)
        return total_prob

    def _n_completions(self, words, N):
        """
        Return the number of unique word tokens that could follow the sequence
        `words` under the *unsmoothed* `N`-gram language model.
        """
        assert N in self.counts, "You do not have counts for {}-grams".format(N)
        assert len(words) >= N - 1, "Need > {} words to use {}-grams".format(N - 2, N)

        if isinstance(words, list):
            words = tuple(words)

        base = words[-N + 1 :]
        return len([k[-1] for k in self.counts[N].keys() if k[:-1] == base])

    def _num_grams_with_count(self, C, N):
        """
        Return the number of unique `N`-gram tokens that occur exactly `C`
        times
        """
        assert C > 0
        assert N in self.counts, "You do not have counts for {}-grams".format(N)

        # cache count values for future calls
        if not hasattr(self, "_NC"):
            self._NC = {N: {} for N in range(1, self.N + 1)}
        if C not in self._NC[N]:
            self._NC[N][C] = len([k for k, v in self.counts[N].items() if v == C])
        return self._NC[N][C]

    @abstractmethod
    def log_prob(self, words, N):
        raise NotImplementedError

    @abstractmethod
    def _log_ngram_prob(self, ngram):
        raise NotImplementedError

class MLENGram(NGramBase):
    def __init__(self, N, unk=True, filter_stopwords=True, filter_punctuation=True):
        """
        A simple, unsmoothed N-gram model.

        Parameters
        ----------
        N : int
            The maximum length (in words) of the context-window to use in the
            language model. The model will compute all n-grams from 1, ..., N
        unk : bool (default: True)
            Whether to include the <unk> (unknown) token in the LM
        filter_stopwords : bool (default: True)
            Whether to remove stopwords before training
        filter_punctuation : bool (default: True)
            Whether to remove punctuation before training
        """
        super().__init__(N, unk, filter_stopwords, filter_punctuation)
        self.hyperparameters["id"] = "MLENGram"

    def log_prob(self, words, N):
        """
        Compute the log probability of a sequence of words under the
        unsmoothed, maximum-likelihood `N`-gram language model. For a bigram,
        this amounts to:

            P(w_i | w_{i-1}) = Count(w_{i-1}, w_i) / Count(w_{i-1})

        Parameters
        ----------
        words : list of strings
            A sequence of words
        N : int
            The gram-size of the language model to use when calculating the log
            probabilities of the sequence

        Returns
        -------
        total_prob : float
            The total log-probability of the sequence `words` under the
            `N`-gram language model
        """
        return self._log_prob(words, N)

    def _log_ngram_prob(self, ngram):
        """Return the unsmoothed log probability of the ngram"""
        N = len(ngram)
        num = self.counts[N][ngram]
        den = self.counts[N - 1][ngram[:-1]] if N > 1 else self.n_words[1]
        return np.log(num) - np.log(den) if (den > 0 and num > 0) else -np.inf

class AdditiveNGram(NGramBase):
    def __init__(
        self, N, K=1, unk=True, filter_stopwords=True, filter_punctuation=True
    ):
        """
        An N-gram model with smoothed probabilities calculated via additive /
        Lidstone smoothing. The resulting estimates correspond to the expected
        value of the posterior, p(ngram_prob | counts), when using a symmetric
        Dirichlet prior with parameter K on the n-gram probabilities.

        Parameters
        ----------
        N : int
            The maximum length (in words) of the context-window to use in the
            language model. The model will compute all n-grams from 1, ..., N
        K : float (default: 1)
            The pseudocount to add to each observation. Larger values allocate
            more probability mass to unseen events. When K = 1, the estimator
            is known as Laplace smoothing; when K = 0.5, it is known as
            expected likelihood estimation (ELE) or the Jeffreys-Perks law
        unk : bool (default: True)
            Whether to include the <unk> (unknown) token in the LM
        filter_stopwords : bool (default: True)
            Whether to remove stopwords before training
        filter_punctuation : bool (default: True)
            Whether to remove punctuation before training
        """
        super().__init__(N, unk, filter_stopwords, filter_punctuation)
        self.hyperparameters["id"] = "AdditiveNGram"
        self.hyperparameters["K"] = K

    def log_prob(self, words, N):
        """
        Compute the smoothed log probability of a sequence of words under the
        `N`-gram language model with additive smoothing. For a bigram, this
        amounts to:

            P(w_i | w_{i-1}) = (A + K) / (B + K * V)

        where

            A = Count(w_{i-1}, w_i)
            B = sum_j Count(w_{i-1}, w_j)
            V = the number of unique word tokens in the vocabulary

        This is equivalent to pretending we've seen every possible N-gram
        sequence an additional `K` times. This can be problematic, as it:
            - Treats each predicted word in the same way (uniform prior counts)
            - Can assign too much probability mass to unseen N-grams (too aggressive)

        Parameters
        ----------
        words : list of strings
            A sequence of words
        N : int
            The gram-size of the language model to use when calculating the log
            probabilities of the sequence

        Returns
        -------
        total_prob : float
            The total log-probability of the sequence `words` under the
            `N`-gram language model
        """
        return self._log_prob(words, N)

    def _log_ngram_prob(self, ngram):
        """Return the smoothed log probability of the ngram"""
        N = len(ngram)
        K = self.hyperparameters["K"]
        counts, n_words, n_tokens = self.counts, self.n_words[1], self.n_tokens[1]

        ctx = ngram[:-1]
        ctx_count = counts[N - 1][ctx] if N > 1 else n_words
        num = counts[N][ngram] + K
        den = ctx_count + K * n_tokens
        return np.log(num / den) if den != 0 else -np.inf
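A quick standalone check of the additive-smoothing formula documented in `AdditiveNGram.log_prob`, P(w_i | w_{i-1}) = (A + K) / (B + K * V). The counts below are toy numbers invented for illustration, not outputs of the model; they show how larger K shifts probability mass toward unseen events.

```python
import numpy as np

# Toy counts, invented for illustration only.
V = 10_000   # vocabulary size (number of unique unigram tokens)
A = 3        # Count(w_{i-1}, w_i): times the bigram was observed
B = 1_000    # times the context word w_{i-1} was observed

for K in (0.0, 0.5, 1.0):  # MLE, ELE / Jeffreys-Perks, Laplace
    p = (A + K) / (B + K * V)
    print(f"K={K}: P = {p:.6f}, log P = {np.log(p):.4f}")

# An unseen bigram (A = 0) gets zero probability under MLE but a small,
# non-zero probability under any K > 0:
for K in (0.5, 1.0):
    print(f"K={K}: P(unseen) = {(0 + K) / (B + K * V):.6f}")
```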

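A small sketch of the sentence-boundary padding used in `NGramBase.train` (`bol * max(1, N - 1) + words + eol * max(1, N - 1)`): even the unigram case gets one `<bol>`/`<eol>` token per side, while higher orders get N - 1. The `ngrams` helper below is a local stand-in with assumed behavior, since the real `preprocessing.nlp.ngrams` is not part of this commit.

```python
# Stand-in for preprocessing.nlp.ngrams (assumed behavior: consecutive n-tuples).
def ngrams(tokens, n):
    return list(zip(*[tokens[i:] for i in range(n)]))

words = ["the", "cat", "sat"]
bol, eol = ["<bol>"], ["<eol>"]

for N in (1, 2, 3):
    # same padding rule as NGramBase.train
    padded = bol * max(1, N - 1) + words + eol * max(1, N - 1)
    print(N, ngrams(padded, N))
```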
0 commit comments
