-
Notifications
You must be signed in to change notification settings - Fork 0
/
gen.py
59 lines (44 loc) · 1.43 KB
/
gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import codecs
import sys
import markovify
# Number of sentences printed in each output batch of main().
BATCH_SIZE = 5
# Dataset names; each is interpolated into a 'data/<name>/...' path by
# get_model(), and all of them are merged by get_all_models().
DATASETS = ['cs', 'phil', 'wiwi']
def get_all_models(state_size):
    """Return a single model merging the models of every entry in DATASETS.

    ``state_size`` is forwarded unchanged to ``get_model`` for each dataset.
    The per-dataset models are merged with ``markovify.combine`` without
    explicit weights.
    """
    per_dataset = []
    for name in DATASETS:
        per_dataset.append(get_model(state_size, name))
    return markovify.combine(per_dataset)
def get_model(state_size, dataset):
    """Build the combined Markov model for one dataset.

    Two UTF-8 text files are read, ``data/<dataset>/units.txt`` and
    ``data/<dataset>/abstract_units.txt``. Each is turned into a
    ``markovify.NewlineText`` model with the given ``state_size``, and the
    two models are combined with weights 1.5 (units) and 1 (abstract units).

    Returns the combined model.
    """
    # Order matters: the weights passed to combine() below are positional.
    sources = (
        'data/{0}/units.txt'.format(dataset),
        'data/{0}/abstract_units.txt'.format(dataset),
    )

    chains = []
    for path in sources:
        with codecs.open(path, 'r', 'utf-8') as handle:
            corpus = handle.read()
        chains.append(markovify.NewlineText(corpus, state_size=state_size))

    return markovify.combine(chains, [1.5, 1])
def main(state_size=1, dataset='phil'):
    """Print three batches of generated sentences to stdout.

    ``dataset`` selects the model: the special value ``'ALL'`` uses
    ``get_all_models``; any other value is passed to ``get_model``.

    The batches, each of ``BATCH_SIZE`` lines, are:
      1. ``make_sentence()``
      2. ``make_short_sentence(140)``
      3. ``make_sentence_with_start("Die")``
    A separator line is printed after the first and second batch.
    """
    separator = "\n----------------\n"
    model = (
        get_all_models(state_size)
        if dataset == 'ALL'
        else get_model(state_size, dataset)
    )

    for _ in range(BATCH_SIZE):
        print(model.make_sentence())
    print(separator)

    for _ in range(BATCH_SIZE):
        print(model.make_short_sentence(140))
    print(separator)

    # Best effort: a KeyError raised while generating from the fixed start
    # word ends this batch silently instead of aborting the script.
    try:
        for _ in range(BATCH_SIZE):
            print(model.make_sentence_with_start("Die"))
    except KeyError:
        pass
if __name__ == '__main__':
    # Optional positional arguments: <state_size> [<dataset>].
    # Anything omitted falls back to the defaults declared by main().
    cli_args = sys.argv[1:]
    overrides = {}
    if len(cli_args) >= 1:
        overrides['state_size'] = int(cli_args[0])
    if len(cli_args) >= 2:
        overrides['dataset'] = cli_args[1]
    main(**overrides)