Skip to content

Commit 46e2232

Browse files
authored
Merge pull request #52 from Juanets/interactivity
Add interactive mode
2 parents 7850b91 + cadb167 commit 46e2232

File tree

4 files changed

+84
-19
lines changed

4 files changed

+84
-19
lines changed

README.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,23 @@ Kubernetes by Google’s Bern
5656
```
5757

5858
You can also train a new model, with support for word level embeddings and bidirectional RNN layers by adding `new_model=True` to any train function.
59+
60+
### Interactive mode
61+
It's also possible to get involved in how the output unfolds, step by step. Interactive mode will suggest you the *top N* options for the next char/word, and allows you to pick one.
62+
63+
Just pass `interactive=True` and `top=N`. N defaults to 3.
5964

65+
```python
66+
from textgenrnn import textgenrnn
67+
68+
textgen = textgenrnn()
69+
textgen.generate(interactive=True, top_n=5)
70+
```
71+
72+
![word_level_demo](/docs/word_level_demo.gif)
73+
74+
This can add a *human touch* to the output; it feels like you're the writer! ([reference](https://fivethirtyeight.com/features/some-like-it-bot/))
75+
6076
## Usage
6177

6278
textgenrnn can be installed [from pypi](https://pypi.python.org/pypi/textgenrnn) via `pip`:
@@ -107,8 +123,6 @@ Additionally, the retraining is done with a momentum-based optimizer and a linea
107123

108124
* A way to visualize the attention-layer outputs to see how the network "learns."
109125

110-
* Supervised text generation mode: allow the model to present the top *n* options and user select the next char/word ([reference](https://fivethirtyeight.com/features/some-like-it-bot/))
111-
112126
* A mode to allow the model architecture to be used for chatbot conversations (may be released as a separate project)
113127

114128
* More depth toward context (positional context + allowing multiple context labels)

docs/word_level_demo.gif

47.5 KB
Loading

textgenrnn/textgenrnn.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ def __init__(self, weights_path=None,
6666
self.indices_char = dict((self.vocab[c], c) for c in self.vocab)
6767

6868
def generate(self, n=1, return_as_list=False, prefix=None,
69-
temperature=0.5, max_gen_length=300):
69+
temperature=0.5, max_gen_length=300, interactive=False,
70+
top_n=3):
7071
gen_texts = []
7172
for _ in range(n):
7273
gen_text = textgenrnn_generate(self.model,
@@ -79,7 +80,9 @@ def generate(self, n=1, return_as_list=False, prefix=None,
7980
self.config['word_level'],
8081
self.config.get(
8182
'single_text', False),
82-
max_gen_length)
83+
max_gen_length,
84+
interactive,
85+
top_n)
8386
if not return_as_list:
8487
print("{}\n".format(gen_text))
8588
gen_texts.append(gen_text)

textgenrnn/utils.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import re
1212

1313

14-
def textgenrnn_sample(preds, temperature):
14+
def textgenrnn_sample(preds, temperature, interactive=False, top_n=3):
1515
'''
1616
Samples predicted probabilities of the next character to allow
1717
for the network to show "creativity."
@@ -26,12 +26,18 @@ def textgenrnn_sample(preds, temperature):
2626
exp_preds = np.exp(preds)
2727
preds = exp_preds / np.sum(exp_preds)
2828
probas = np.random.multinomial(1, preds, 1)
29-
index = np.argmax(probas)
30-
31-
# prevent function from being able to choose 0 (placeholder)
32-
# choose 2nd best index from preds
33-
if index == 0:
34-
index = np.argsort(preds)[-2]
29+
30+
if not interactive:
31+
index = np.argmax(probas)
32+
33+
# prevent function from being able to choose 0 (placeholder)
34+
# choose 2nd best index from preds
35+
if index == 0:
36+
index = np.argsort(preds)[-2]
37+
else:
38+
# return list of top N chars/words
39+
# descending order, based on probability
40+
index = (-preds).argsort()[:top_n]
3541

3642
return index
3743

@@ -41,11 +47,15 @@ def textgenrnn_generate(model, vocab,
4147
maxlen=40, meta_token='<s>',
4248
word_level=False,
4349
single_text=False,
44-
max_gen_length=300):
50+
max_gen_length=300,
51+
interactive=False,
52+
top_n=3):
4553
'''
4654
Generates and returns a single text.
4755
'''
4856

57+
collapse_char = ' ' if word_level else ''
58+
4959
# If generating word level, must add spaces around each punctuation.
5060
# https://stackoverflow.com/a/3645946/9314418
5161
if word_level and prefix:
@@ -72,15 +82,53 @@ def textgenrnn_generate(model, vocab,
7282

7383
while next_char != meta_token and len(text) < max_gen_length:
7484
encoded_text = textgenrnn_encode_sequence(text[-maxlen:],
75-
vocab, maxlen)
85+
vocab, maxlen)
7686
next_temperature = temperature[(len(text) - 1) % len(temperature)]
77-
next_index = textgenrnn_sample(
78-
model.predict(encoded_text, batch_size=1)[0],
79-
next_temperature)
80-
next_char = indices_char[next_index]
81-
text += [next_char]
8287

83-
collapse_char = ' ' if word_level else ''
88+
if not interactive:
89+
# auto-generate text without user intervention
90+
next_index = textgenrnn_sample(
91+
model.predict(encoded_text, batch_size=1)[0],
92+
next_temperature)
93+
next_char = indices_char[next_index]
94+
text += [next_char]
95+
else:
96+
# ask user what the next char/word should be
97+
options_index = textgenrnn_sample(
98+
model.predict(encoded_text, batch_size=1)[0],
99+
next_temperature,
100+
interactive=interactive,
101+
top_n=top_n
102+
)
103+
options = [indices_char[idx] for idx in options_index]
104+
print('Controls:\n\ts: stop.\tx: backspace.\to: write your own.')
105+
print('\nOptions:')
106+
107+
for i, option in enumerate(options, 1):
108+
print('\t{}: {}'.format(i, option))
109+
110+
print('\nProgress: {}'.format(collapse_char.join(text)[3:]))
111+
print('\nYour choice?')
112+
user_input = input('> ')
113+
114+
try:
115+
user_input = int(user_input)
116+
next_char = options[user_input-1]
117+
text += [next_char]
118+
except ValueError:
119+
if user_input == 's':
120+
next_char = '<s>'
121+
text += [next_char]
122+
elif user_input == 'o':
123+
other = input('> ')
124+
text += [other]
125+
elif user_input == 'x':
126+
try:
127+
del text[-1]
128+
except IndexError:
129+
pass
130+
else:
131+
print('That\'s not an option!')
84132

85133
# if single text, ignore sequences generated w/ padding
86134
# if not single text, strip the <s> meta_tokens

0 commit comments

Comments
 (0)