16
16
# Modifications copyright (C) 2019 Texar
17
17
# ==============================================================================
18
18
"""
19
- Implemetation of beam seach with penalties.
20
- Adapted from tensor2tensor repositor .
19
+ Implementation of beam search with penalties.
20
+ Adapted from tensor2tensor repository .
21
21
"""
22
22
23
23
from __future__ import absolute_import
32
32
# Default value for INF
33
33
INF = 1. * 1e7
34
34
35
+
35
36
def _merge_beam_dim (tensor ):
36
37
"""Reshapes first two dimensions in to single dimension.
37
38
@@ -41,6 +42,8 @@ def _merge_beam_dim(tensor):
41
42
Returns:
42
43
Reshaped tensor of shape [A*B, ...]
43
44
"""
45
+ if not isinstance (tensor , tf .Tensor ) or not tensor .get_shape ().as_list ():
46
+ return tensor
44
47
shape = shape_list (tensor )
45
48
shape [0 ] *= shape [1 ] # batch -> batch * beam_size
46
49
shape .pop (1 ) # Remove beam dim
@@ -58,6 +61,8 @@ def _unmerge_beam_dim(tensor, batch_size, beam_size):
58
61
Returns:
59
62
Reshaped tensor of shape [batch_size, beam_size, ...]
60
63
"""
64
+ if not isinstance (tensor , tf .Tensor ) or not tensor .get_shape ().as_list ():
65
+ return tensor
61
66
shape = shape_list (tensor )
62
67
new_shape = [batch_size ] + [beam_size ] + shape [1 :]
63
68
return tf .reshape (tensor , new_shape )
@@ -73,6 +78,8 @@ def _expand_to_beam_size(tensor, beam_size):
73
78
Returns:
74
79
Tiled tensor [batch_size, beam_size, ...]
75
80
"""
81
+ if not isinstance (tensor , tf .Tensor ) or not tensor .get_shape ().as_list ():
82
+ return tensor
76
83
tensor = tf .expand_dims (tensor , axis = 1 )
77
84
tile_dims = [1 ] * tensor .shape .ndims
78
85
tile_dims [1 ] = beam_size
@@ -173,6 +180,9 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,
173
180
# operations with tfdbg. Clients can capture these tensors by watching
174
181
# these node names.
175
182
def gather (tensor , name ):
183
+ if not isinstance (tensor ,
184
+ tf .Tensor ) or not tensor .get_shape ().as_list ():
185
+ return tensor
176
186
return tf .gather_nd (tensor , top_coordinates , name = (prefix + name ))
177
187
topk_seq = gather (sequences , "_topk_seq" )
178
188
topk_flags = gather (flags , "_topk_flags" )
@@ -196,7 +206,7 @@ def beam_search(symbols_to_logits_fn,
196
206
stop_early = True ):
197
207
"""Beam search with length penalties.
198
208
199
- Requires a function that can take the currently decoded sybmols and
209
+ Requires a function that can take the currently decoded symbols and
200
210
return the logits for the next symbol. The implementation is inspired
201
211
by https://arxiv.org/abs/1609.08144.
202
212
@@ -255,11 +265,11 @@ def beam_search(symbols_to_logits_fn,
255
265
# Expand each batch and state to beam_size
256
266
alive_seq = _expand_to_beam_size (initial_ids , beam_size )
257
267
alive_seq = tf .expand_dims (alive_seq , axis = 2 )
258
- #(batch_size, beam_size, 1)
268
+
269
+ # (batch_size, beam_size, 1)
259
270
if states :
260
271
states = nest .map_structure (
261
- lambda state : _expand_to_beam_size (state , beam_size ),
262
- states )
272
+ lambda state : _expand_to_beam_size (state , beam_size ), states )
263
273
else :
264
274
states = {}
265
275
@@ -384,7 +394,7 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
384
394
if states :
385
395
flat_states = nest .map_structure (_merge_beam_dim , states )
386
396
flat_logits , flat_states = symbols_to_logits_fn (flat_ids , i ,
387
- flat_states )
397
+ flat_states )
388
398
states = nest .map_structure (
389
399
lambda t : _unmerge_beam_dim (t , batch_size , beam_size ),
390
400
flat_states )
@@ -435,20 +445,19 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
435
445
topk_seq = tf .gather_nd (alive_seq , topk_coordinates )
436
446
if states :
437
447
states = nest .map_structure (
438
- lambda state : tf .gather_nd (state , topk_coordinates ),
439
- states )
448
+ lambda state : tf .gather_nd (state , topk_coordinates ), states )
440
449
441
450
# Append the most probable alive
442
451
topk_seq = tf .concat ([topk_seq , tf .expand_dims (topk_ids , axis = 2 )],
443
- axis = 2 )
452
+ axis = 2 )
444
453
445
454
topk_finished = tf .equal (topk_ids , eos_id )
446
455
447
456
return topk_seq , topk_log_probs , topk_scores , topk_finished , states
448
457
449
458
def inner_loop (i , alive_seq , alive_log_probs , finished_seq ,
450
- finished_scores , finished_flags , states ):
451
- """Inner beam seach loop.
459
+ finished_scores , finished_flags , states ):
460
+ """Inner beam search loop.
452
461
453
462
There are three groups of tensors, alive, finished, and topk.
454
463
The alive group contains information about the current alive
0 commit comments