
Commit a2dee78

Resolve comments

1 parent 2307d99, commit a2dee78

8 files changed: +45 -41 lines

examples/seq2seq_exposure_bias/interpolation_decoder.py

+4 -4

@@ -113,12 +113,12 @@ def step(self, time, inputs, state, name=None):
             logits, sample_ids, wrapper_outputs,
             attention_scores, attention_context)
 
-        return (outputs, sample_ids, logits, wrapper_state)
+        return (outputs, wrapper_state)
 
-    def next_inputs(self, sample_ids, time, outputs, state):
+    def next_inputs(self, time, outputs, state):
         (finished, next_inputs, next_state) = self._helper.next_inputs(
             time=time,
-            outputs=outputs,
+            outputs=outputs.logits,
             state=[state[0], state],
-            sample_ids=sample_ids)
+            sample_ids=outputs.sample_id)
         return (finished, next_inputs, next_state)
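
This file adopts the slimmer decoder interface introduced by this commit: `step()` returns just `(outputs, state)`, and `next_inputs()` recovers logits and sample ids from the output namedtuple instead of taking them as arguments. A minimal sketch of the resulting calling convention, assuming a driver loop like the one in `dynamic_decode` (last file below); `decoder`, `time`, `inputs`, and `state` are placeholders:

    # step() now returns only the decoder-output namedtuple plus state;
    # logits and sample ids travel inside the namedtuple.
    outputs, state = decoder.step(time, inputs, state)

    # next_inputs() unpacks what used to be separate positional arguments:
    #   outputs.logits    -> formerly the `logits` / `outputs` argument
    #   outputs.sample_id -> formerly the `sample_ids` argument
    finished, next_inputs, state = decoder.next_inputs(time, outputs, state)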

texar/tf/modules/decoders/__init__.py

+1 -1

@@ -21,9 +21,9 @@
 
 # pylint: disable=wildcard-import
 
+from texar.tf.modules.decoders.beam_search_decode import *
 from texar.tf.modules.decoders.rnn_decoder_base import *
 from texar.tf.modules.decoders.rnn_decoders import *
 from texar.tf.modules.decoders.tf_helpers import *
 from texar.tf.modules.decoders.rnn_decoder_helpers import *
 from texar.tf.modules.decoders.transformer_decoders import *
-from texar.tf.modules.decoders.beam_search_decode import *
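
The commit does not say why this import moved, but with `import *` the order is load-bearing: when two modules export the same name, the later import wins. Moving `beam_search_decode` to the top means any colliding name resolves in favor of the decoder modules imported after it. A minimal illustration with hypothetical modules `a` and `b` that both define `helper`:

    # a.py: helper = "from a"
    # b.py: helper = "from b"
    from a import *
    from b import *
    print(helper)  # "from b": the later wildcard import shadows the earlier one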

texar/tf/modules/decoders/rnn_decoders.py

+8 -8

@@ -257,14 +257,14 @@ def step(self, time, inputs, state, name=None):
         sample_ids = self._helper.sample(
             time=time, outputs=logits, state=cell_state)
         outputs = BasicRNNDecoderOutput(logits, sample_ids, cell_outputs)
-        return outputs, sample_ids, logits, cell_state
+        return outputs, cell_state
 
-    def next_inputs(self, sample_ids, time, logits, state):
+    def next_inputs(self, time, outputs, state):
         (finished, next_inputs, next_state) = self._helper.next_inputs(
             time=time,
-            outputs=logits,
+            outputs=outputs.logits,
             state=state,
-            sample_ids=sample_ids,)
+            sample_ids=outputs.sample_id)
         return finished, next_inputs, next_state
 
     def finalize(self, outputs, final_state, sequence_lengths):

@@ -601,14 +601,14 @@ def step(self, time, inputs, state, name=None):
             logits, sample_ids, wrapper_outputs,
             attention_scores, attention_context)
 
-        return (outputs, sample_ids, logits, wrapper_state)
+        return (outputs, wrapper_state)
 
-    def next_inputs(self, sample_ids, time, outputs, state):
+    def next_inputs(self, time, outputs, state):
         (finished, next_inputs, state) = self._helper.next_inputs(
             time=time,
-            outputs=outputs,
+            outputs=outputs.logits,
             state=state,
-            sample_ids=sample_ids)
+            sample_ids=outputs.sample_id)
         return (finished, next_inputs, state)
 
     def finalize(self, outputs, final_state, sequence_lengths):
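
The same two-method split is applied uniformly here, in the example's interpolation decoder above, and in the transformer decoder below, so `dynamic_decode` can drive all of them through one interface. The namedtuple carrying logits and sample ids between the two calls, sketched from its usage in this hunk (the positional construction `BasicRNNDecoderOutput(logits, sample_ids, cell_outputs)` and the `outputs.logits` / `outputs.sample_id` reads imply this field order; the actual definition lives elsewhere in the file):

    import collections

    # Sketch inferred from usage only; the real definition is elsewhere
    # in rnn_decoders.py.
    BasicRNNDecoderOutput = collections.namedtuple(
        "BasicRNNDecoderOutput", ["logits", "sample_id", "cell_output"])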

texar/tf/modules/decoders/rnn_decoders_test.py

+1

@@ -385,5 +385,6 @@ def test_beam_search_cell(self):
         for tvar in beam_cell.trainable_variables:
             self.assertTrue(tvar in decoder.trainable_variables)
 
+
 if __name__ == "__main__":
     tf.test.main()

texar/tf/modules/decoders/tf_helpers.py

+1 -4

@@ -615,13 +615,10 @@ def sample(self, time, outputs, state, name=None):
         sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
         return sample_ids
 
-    def next_inputs(self, time, outputs, state, sample_ids, name=None,
-                    reach_max_time=None):
+    def next_inputs(self, time, outputs, state, sample_ids, name=None):
         """Gets the inputs for next step."""
         finished = math_ops.equal(sample_ids, self._end_token)
         all_finished = math_ops.reduce_all(finished)
-        if reach_max_time is not None:
-            all_finished = tf.logical_or(all_finished, reach_max_time)
 
         if self._embedding_args_cnt == 1:
             del time, outputs  # unused by next_inputs_fn
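
With `reach_max_time` gone, the helper no longer needs to know about the decoding-length cap. The equivalent cutoff now lives in the driver loop in texar/tf/utils/dynamic_decode.py (the last file in this commit), condensed:

    # The loop body, not the helper, now detects the cap and bypasses
    # next_inputs() entirely via tf.cond (see the dynamic_decode.py diff below).
    reach_max = tf.equal(time + 1, maximum_iterations)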

texar/tf/modules/decoders/transformer_decoders.py

+4 -4

@@ -826,14 +826,14 @@ def step(self, time, inputs, state, name=None):
         wrapper_outputs = TransformerDecoderOutput(
             logits=outputs,
             sample_id=sample_ids)
-        return (wrapper_outputs, sample_ids, outputs, state)
+        return (wrapper_outputs, state)
 
-    def next_inputs(self, sample_ids, time, outputs, state):
+    def next_inputs(self, time, outputs, state):
         (finished, next_inputs, state) = self._helper.next_inputs(
             time=time,
-            outputs=outputs,
+            outputs=outputs.logits,
             state=state,
-            sample_ids=sample_ids)
+            sample_ids=outputs.sample_id)
         return (finished, next_inputs, state)
 
     def finalize(self, outputs, final_state, sequence_lengths):

texar/tf/utils/beam_search.py

+21 -12

@@ -16,8 +16,8 @@
 # Modifications copyright (C) 2019 Texar
 # ==============================================================================
 """
-Implemetation of beam seach with penalties.
-Adapted from tensor2tensor repositor.
+Implementation of beam search with penalties.
+Adapted from tensor2tensor repository.
 """
 
 from __future__ import absolute_import

@@ -32,6 +32,7 @@
 # Default value for INF
 INF = 1. * 1e7
 
+
 def _merge_beam_dim(tensor):
     """Reshapes first two dimensions in to single dimension.
 

@@ -41,6 +42,8 @@ def _merge_beam_dim(tensor):
     Returns:
         Reshaped tensor of shape [A*B, ...]
     """
+    if not isinstance(tensor, tf.Tensor) or not tensor.get_shape().as_list():
+        return tensor
     shape = shape_list(tensor)
     shape[0] *= shape[1]  # batch -> batch * beam_size
     shape.pop(1)  # Remove beam dim

@@ -58,6 +61,8 @@ def _unmerge_beam_dim(tensor, batch_size, beam_size):
     Returns:
         Reshaped tensor of shape [batch_size, beam_size, ...]
     """
+    if not isinstance(tensor, tf.Tensor) or not tensor.get_shape().as_list():
+        return tensor
     shape = shape_list(tensor)
     new_shape = [batch_size] + [beam_size] + shape[1:]
     return tf.reshape(tensor, new_shape)

@@ -73,6 +78,8 @@ def _expand_to_beam_size(tensor, beam_size):
     Returns:
         Tiled tensor [batch_size, beam_size, ...]
     """
+    if not isinstance(tensor, tf.Tensor) or not tensor.get_shape().as_list():
+        return tensor
     tensor = tf.expand_dims(tensor, axis=1)
     tile_dims = [1] * tensor.shape.ndims
     tile_dims[1] = beam_size

@@ -173,6 +180,9 @@ def compute_topk_scores_and_seq(sequences, scores, scores_to_gather, flags,
     # operations with tfdbg. Clients can capture these tensors by watching
     # these node names.
     def gather(tensor, name):
+        if not isinstance(tensor,
+                          tf.Tensor) or not tensor.get_shape().as_list():
+            return tensor
         return tf.gather_nd(tensor, top_coordinates, name=(prefix + name))
     topk_seq = gather(sequences, "_topk_seq")
     topk_flags = gather(flags, "_topk_flags")

@@ -196,7 +206,7 @@ def beam_search(symbols_to_logits_fn,
                 stop_early=True):
     """Beam search with length penalties.
 
-    Requires a function that can take the currently decoded sybmols and
+    Requires a function that can take the currently decoded symbols and
     return the logits for the next symbol. The implementation is inspired
     by https://arxiv.org/abs/1609.08144.
 

@@ -255,11 +265,11 @@
     # Expand each batch and state to beam_size
     alive_seq = _expand_to_beam_size(initial_ids, beam_size)
     alive_seq = tf.expand_dims(alive_seq, axis=2)
-    #(batch_size, beam_size, 1)
+
+    # (batch_size, beam_size, 1)
     if states:
         states = nest.map_structure(
-            lambda state: _expand_to_beam_size(state, beam_size),
-            states)
+            lambda state: _expand_to_beam_size(state, beam_size), states)
     else:
         states = {}
 

@@ -384,7 +394,7 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
         if states:
             flat_states = nest.map_structure(_merge_beam_dim, states)
             flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i,
-                                                           flat_states)
+                                                            flat_states)
             states = nest.map_structure(
                 lambda t: _unmerge_beam_dim(t, batch_size, beam_size),
                 flat_states)

@@ -435,20 +445,19 @@ def grow_topk(i, alive_seq, alive_log_probs, states):
         topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
         if states:
             states = nest.map_structure(
-                lambda state: tf.gather_nd(state, topk_coordinates),
-                states)
+                lambda state: tf.gather_nd(state, topk_coordinates), states)
 
         # Append the most probable alive
         topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)],
-                            axis=2)
+                             axis=2)
 
         topk_finished = tf.equal(topk_ids, eos_id)
 
         return topk_seq, topk_log_probs, topk_scores, topk_finished, states
 
     def inner_loop(i, alive_seq, alive_log_probs, finished_seq,
-            finished_scores, finished_flags, states):
-        """Inner beam seach loop.
+                   finished_scores, finished_flags, states):
+        """Inner beam search loop.
 
         There are three groups of tensors, alive, finished, and topk.
         The alive group contains information about the current alive
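
The repeated guard added to the reshaping helpers makes beam search tolerant of state entries that are not batched tensors: Python values fail the `isinstance` check, and rank-0 tensors fail the `get_shape().as_list()` check (an empty list is falsy), so both are returned untouched instead of being tiled, merged, or gathered. A hedged sketch of what this permits (`batch_size`, `hidden_dim`, and `beam_size` are placeholder values; `nest` is the module already used in this file):

    # A state dict may now mix per-batch tensors with scalars; only the
    # former are expanded to [batch_size, beam_size, ...].
    states = {
        "cache": tf.zeros([batch_size, hidden_dim]),  # tiled per beam
        "step_count": 0,                              # passed through unchanged
    }
    states = nest.map_structure(
        lambda s: _expand_to_beam_size(s, beam_size), states)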

texar/tf/utils/dynamic_decode.py

+5 -8

@@ -182,8 +182,6 @@ def dynamic_decode(decoder,
                         type(decoder))
 
     with tf.variable_scope(scope, "decoder") as varscope:
-        # Determine context types.
-
         if maximum_iterations is not None:
             maximum_iterations = tf.convert_to_tensor(
                 maximum_iterations, dtype=tf.int32, name="maximum_iterations")

@@ -249,14 +247,13 @@ def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
             `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
             next_sequence_lengths)`.
         """
-        (next_outputs, sample_ids, logits, state) = decoder.step(
-            time, inputs, state)
-        reach_max = tf.equal(time+1, decoder.max_decoding_length)
+        (next_outputs, state) = decoder.step(time, inputs, state)
+        reach_max = tf.equal(time+1, maximum_iterations)
         (decoder_finished, next_inputs, decoder_state) = tf.cond(
             reach_max,
-            lambda: (tf.cast(tf.ones(tf.shape(sample_ids)[0]), tf.bool),
-                     decoder._helper._start_inputs, state),
-            lambda: decoder.next_inputs(sample_ids, time, logits, state)
+            lambda: (tf.cast(tf.ones_like(finished), tf.bool),
+                     inputs, state),
+            lambda: decoder.next_inputs(time, next_outputs, state)
         )
         if decoder.tracks_own_finished:
             next_finished = decoder_finished
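
A few details of the new branch are worth noting. `tf.ones_like(finished)` takes its shape from the loop's own `finished` tensor, so the cutoff no longer assumes anything about the shape of `sample_ids`; passing the current `inputs` and `state` through avoids touching the private `decoder._helper._start_inputs`; and `maximum_iterations`, validated at the top of `dynamic_decode`, replaces the decoder attribute `max_decoding_length`. A condensed reading of the loop body under the new interface (names as in the diff above; TensorArray writes omitted):

    # Step, then either force-finish at the length cap (recycling the
    # current inputs/state, whose values are discarded once everything
    # is finished) or let the decoder derive next inputs from its outputs.
    (next_outputs, state) = decoder.step(time, inputs, state)
    reach_max = tf.equal(time + 1, maximum_iterations)
    (decoder_finished, next_inputs, decoder_state) = tf.cond(
        reach_max,
        lambda: (tf.cast(tf.ones_like(finished), tf.bool), inputs, state),
        lambda: decoder.next_inputs(time, next_outputs, state))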
