Skip to content

Commit b71ad08

Browse files
authored
Merge pull request #208 from gpengzhi/decoder-issue
Add dynamic_decode and Fix decoder issue
2 parents bd2dbe4 + 548da95 commit b71ad08

File tree

9 files changed

+400
-58
lines changed

9 files changed

+400
-58
lines changed

examples/seq2seq_exposure_bias/interpolation_decoder.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -107,16 +107,18 @@ def step(self, time, inputs, state, name=None):
107107
sample_ids = self._helper.sample(
108108
time=time, outputs=logits, state=[decoded_ids, wrapper_state])
109109

110-
(finished, next_inputs, next_state) = self._helper.next_inputs(
111-
time=time,
112-
outputs=logits,
113-
state=[decoded_ids, wrapper_state],
114-
sample_ids=sample_ids)
115-
116110
attention_scores = wrapper_state.alignments
117111
attention_context = wrapper_state.attention
118112
outputs = AttentionRNNDecoderOutput(
119113
logits, sample_ids, wrapper_outputs,
120114
attention_scores, attention_context)
121115

122-
return (outputs, next_state, next_inputs, finished)
116+
return (outputs, wrapper_state)
117+
118+
def next_inputs(self, time, outputs, state):
119+
(finished, next_inputs, next_state) = self._helper.next_inputs(
120+
time=time,
121+
outputs=outputs.logits,
122+
state=[state[0], state],
123+
sample_ids=outputs.sample_id)
124+
return (finished, next_inputs, next_state)

texar/tf/modules/decoders/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
# pylint: disable=wildcard-import
2323

24+
from texar.tf.modules.decoders.beam_search_decode import *
2425
from texar.tf.modules.decoders.rnn_decoder_base import *
2526
from texar.tf.modules.decoders.rnn_decoders import *
2627
from texar.tf.modules.decoders.tf_helpers import *
2728
from texar.tf.modules.decoders.rnn_decoder_helpers import *
2829
from texar.tf.modules.decoders.transformer_decoders import *
29-
from texar.tf.modules.decoders.beam_search_decode import *
+338
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# Modifications copyright (C) 2019 Texar
16+
# ==============================================================================
17+
"""
18+
Utility functions for decoding. This file is modified from
19+
`tf.contrib.seq2seq.dynamic_decode`.
20+
"""
21+
22+
from __future__ import absolute_import
23+
from __future__ import print_function
24+
from __future__ import division
25+
from __future__ import unicode_literals
26+
27+
# pylint: disable=invalid-name, no-member, protected-access
28+
29+
import tensorflow as tf
30+
from tensorflow.contrib.seq2seq import Decoder as TFDecoder
31+
from tensorflow.python.framework import tensor_shape
32+
from tensorflow.python.util import nest
33+
34+
35+
__all__ = [
36+
"dynamic_decode"
37+
]
38+
39+
40+
def _concat(prefix, suffix, static=False):
41+
r"""Concat that enables int, Tensor, or TensorShape values.
42+
This function takes a size specification, which can be an integer, a
43+
TensorShape, or a Tensor, and converts it into a concatenated Tensor
44+
(if static = False) or a list of integers (if static = True).
45+
46+
Args:
47+
prefix: The prefix; usually the batch size (and/or time step size).
48+
(TensorShape, int, or Tensor.)
49+
suffix: TensorShape, int, or Tensor.
50+
static: If `True`, return a python list with possibly unknown
51+
dimensions. Otherwise return a `Tensor`.
52+
53+
Returns:
54+
shape: the concatenation of prefix and suffix.
55+
56+
Raises:
57+
ValueError: if `suffix` is not a scalar or vector (or TensorShape).
58+
ValueError: if prefix or suffix was `None` and asked for dynamic
59+
Tensors out.
60+
"""
61+
if isinstance(prefix, tf.Tensor):
62+
p = prefix
63+
p_static = tf.get_static_value(prefix)
64+
if p.shape.ndims == 0:
65+
p = tf.expand_dims(p, 0)
66+
elif p.shape.ndims != 1:
67+
raise ValueError("prefix tensor must be either a scalar or vector, "
68+
"but saw tensor: %s" % p)
69+
else:
70+
p = tensor_shape.as_shape(prefix)
71+
p_static = p.as_list() if p.ndims is not None else None
72+
p = (
73+
tf.constant(p.as_list(), dtype=tf.int32)
74+
if p.is_fully_defined() else None)
75+
if isinstance(suffix, tf.Tensor):
76+
s = suffix
77+
s_static = tf.get_static_value(suffix)
78+
if s.shape.ndims == 0:
79+
s = tf.expand_dims(s, 0)
80+
elif s.shape.ndims != 1:
81+
raise ValueError("suffix tensor must be either a scalar or vector, "
82+
"but saw tensor: %s" % s)
83+
else:
84+
s = tensor_shape.as_shape(suffix)
85+
s_static = s.as_list() if s.ndims is not None else None
86+
s = (
87+
tf.constant(s.as_list(), dtype=tf.int32)
88+
if s.is_fully_defined() else None)
89+
90+
if static:
91+
shape = tensor_shape.as_shape(p_static).concatenate(s_static)
92+
shape = shape.as_list() if shape.ndims is not None else None
93+
else:
94+
if p is None or s is None:
95+
raise ValueError("Provided a prefix or suffix of None: %s and %s" %
96+
(prefix, suffix))
97+
shape = tf.concat((p, s), 0)
98+
return shape
99+
100+
101+
def _zero_state_tensors(state_size, batch_size, dtype):
102+
r"""Create tensors of zeros based on state_size, batch_size, and dtype."""
103+
104+
def get_state_shape(s):
105+
r"""Combine s with batch_size to get a proper tensor shape."""
106+
107+
c = _concat(batch_size, s)
108+
size = tf.zeros(c, dtype=dtype)
109+
return size
110+
111+
return nest.map_structure(get_state_shape, state_size)
112+
113+
114+
def _create_zero_outputs(size, dtype, batch_size):
115+
r"""Create a zero outputs Tensor structure."""
116+
117+
def _create(s, d):
118+
return _zero_state_tensors(s, batch_size, d)
119+
120+
return nest.map_structure(_create, size, dtype)
121+
122+
123+
def _transpose_batch_time(x):
124+
r"""Transposes the batch and time dimensions of a Tensor.
125+
126+
If the input tensor has rank < 2 it returns the original tensor. Retains as
127+
much of the static shape information as possible.
128+
129+
Args:
130+
x: A Tensor.
131+
132+
Returns:
133+
x transposed along the first two dimensions.
134+
"""
135+
x_static_shape = x.get_shape()
136+
if x_static_shape.rank is not None and x_static_shape.rank < 2:
137+
return x
138+
139+
x_rank = tf.rank(x)
140+
x_t = tf.transpose(
141+
x, tf.concat(([1, 0], tf.range(2, x_rank)), axis=0))
142+
x_t.set_shape(
143+
tensor_shape.TensorShape(
144+
[x_static_shape.dims[1].value,
145+
x_static_shape.dims[0].value]).concatenate(x_static_shape[2:]))
146+
return x_t
147+
148+
149+
def dynamic_decode(decoder,
150+
output_time_major=False,
151+
impute_finished=False,
152+
maximum_iterations=None,
153+
parallel_iterations=32,
154+
swap_memory=False,
155+
scope=None):
156+
r"""Perform dynamic decoding with `decoder`.
157+
158+
Calls initialize() once and step() repeatedly on the Decoder object.
159+
160+
Args:
161+
decoder: A `Decoder` instance.
162+
output_time_major: Python boolean. Default: `False` (batch major). If
163+
`True`, outputs are returned as time major tensors (this mode is faster).
164+
Otherwise, outputs are returned as batch major tensors (this adds extra
165+
time to the computation).
166+
impute_finished: Python boolean. If `True`, then states for batch
167+
entries which are marked as finished get copied through and the
168+
corresponding outputs get zeroed out. This causes some slowdown at
169+
each time step, but ensures that the final state and outputs have
170+
the correct values and that backprop ignores time steps that were
171+
marked as finished.
172+
maximum_iterations: `int32` scalar, maximum allowed number of decoding
173+
steps. Default is `None` (decode until the decoder is fully done).
174+
parallel_iterations: Argument passed to `tf.while_loop`.
175+
swap_memory: Argument passed to `tf.while_loop`.
176+
scope: Optional variable scope to use.
177+
178+
Returns:
179+
`(final_outputs, final_state, final_sequence_lengths)`.
180+
Raises:
181+
TypeError: if `decoder` is not an instance of `Decoder`.
182+
ValueError: if `maximum_iterations` is provided but is not a scalar.
183+
"""
184+
if not isinstance(decoder, TFDecoder):
185+
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
186+
type(decoder))
187+
188+
with tf.variable_scope(scope, "decoder") as varscope:
189+
if maximum_iterations is not None:
190+
maximum_iterations = tf.convert_to_tensor(
191+
maximum_iterations, dtype=tf.int32, name="maximum_iterations")
192+
if maximum_iterations.get_shape().ndims != 0:
193+
raise ValueError("maximum_iterations must be a scalar")
194+
195+
initial_finished, initial_inputs, initial_state = decoder.initialize()
196+
197+
zero_outputs = _create_zero_outputs(decoder.output_size,
198+
decoder.output_dtype,
199+
decoder.batch_size)
200+
201+
if maximum_iterations is not None:
202+
initial_finished = tf.logical_or(
203+
initial_finished, 0 >= maximum_iterations)
204+
initial_sequence_lengths = tf.zeros_like(
205+
initial_finished, dtype=tf.int32)
206+
initial_time = tf.constant(0, dtype=tf.int32)
207+
208+
def _shape(batch_size, from_shape):
209+
if (not isinstance(from_shape, tensor_shape.TensorShape) or
210+
from_shape.ndims == 0):
211+
return None
212+
else:
213+
batch_size = tf.get_static_value(
214+
tf.convert_to_tensor(
215+
batch_size, name="batch_size"))
216+
return tensor_shape.TensorShape([batch_size]).\
217+
concatenate(from_shape)
218+
219+
dynamic_size = True
220+
221+
def _create_ta(s, d):
222+
return tf.TensorArray(
223+
dtype=d,
224+
size=0 if dynamic_size else maximum_iterations,
225+
dynamic_size=dynamic_size,
226+
element_shape=_shape(decoder.batch_size, s))
227+
228+
initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
229+
decoder.output_dtype)
230+
231+
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
232+
finished, unused_sequence_lengths):
233+
cond = tf.logical_not(tf.reduce_all(finished))
234+
cond_time = (maximum_iterations is None or
235+
unused_time < maximum_iterations)
236+
ret = tf.logical_and(cond, tf.convert_to_tensor(cond_time))
237+
return ret
238+
239+
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
240+
r"""Internal while_loop body.
241+
242+
Args:
243+
time: scalar int32 tensor.
244+
outputs_ta: structure of TensorArray.
245+
state: (structure of) state tensors and TensorArrays.
246+
inputs: (structure of) input tensors.
247+
finished: bool tensor (keeping track of what's finished).
248+
sequence_lengths: int32 tensor (keeping track of time of finish).
249+
250+
Returns:
251+
`(time + 1, outputs_ta, next_state, next_inputs, next_finished,
252+
next_sequence_lengths)`.
253+
"""
254+
(next_outputs, state) = decoder.step(time, inputs, state)
255+
256+
# Check if the maximum iteration is met. If it is met, do not compute
257+
# the next inputs.
258+
reach_max = tf.equal(time+1, maximum_iterations)
259+
(decoder_finished, next_inputs, decoder_state) = tf.cond(
260+
reach_max,
261+
lambda: (tf.cast(tf.ones_like(finished), tf.bool),
262+
inputs, state),
263+
lambda: decoder.next_inputs(time, next_outputs, state)
264+
)
265+
if decoder.tracks_own_finished:
266+
next_finished = decoder_finished
267+
else:
268+
next_finished = tf.logical_or(decoder_finished, finished)
269+
next_sequence_lengths = tf.where(
270+
tf.logical_not(finished),
271+
tf.fill(tf.shape(sequence_lengths), time + 1),
272+
sequence_lengths)
273+
274+
nest.assert_same_structure(state, decoder_state)
275+
nest.assert_same_structure(outputs_ta, next_outputs)
276+
nest.assert_same_structure(inputs, next_inputs)
277+
278+
# Zero out output values past finish
279+
if impute_finished:
280+
emit = nest.map_structure(
281+
lambda out, zero: tf.where(finished, zero, out),
282+
next_outputs,
283+
zero_outputs)
284+
else:
285+
emit = next_outputs
286+
287+
# Copy through states past finish
288+
def _maybe_copy_state(new, cur):
289+
# TensorArrays and scalar states get passed through.
290+
if isinstance(cur, tf.TensorArray):
291+
pass_through = True
292+
else:
293+
new.set_shape(cur.shape)
294+
pass_through = (new.shape.ndims == 0)
295+
return new if pass_through else tf.where(finished, cur, new)
296+
297+
if impute_finished:
298+
next_state = nest.map_structure(
299+
_maybe_copy_state, decoder_state, state)
300+
else:
301+
next_state = decoder_state
302+
303+
outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
304+
outputs_ta, emit)
305+
return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
306+
next_sequence_lengths)
307+
308+
res = tf.while_loop(
309+
condition,
310+
body,
311+
loop_vars=(
312+
initial_time,
313+
initial_outputs_ta,
314+
initial_state,
315+
initial_inputs,
316+
initial_finished,
317+
initial_sequence_lengths,
318+
),
319+
parallel_iterations=parallel_iterations,
320+
maximum_iterations=maximum_iterations,
321+
swap_memory=swap_memory)
322+
323+
final_outputs_ta = res[1]
324+
final_state = res[2]
325+
final_sequence_lengths = res[5]
326+
327+
final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
328+
329+
try:
330+
final_outputs, final_state = decoder.finalize(
331+
final_outputs, final_state, final_sequence_lengths)
332+
except NotImplementedError:
333+
pass
334+
335+
if not output_time_major:
336+
final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)
337+
338+
return final_outputs, final_state, final_sequence_lengths

texar/tf/modules/decoders/rnn_decoder_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,13 @@
2727

2828
import tensorflow as tf
2929
from tensorflow.contrib.seq2seq import Decoder as TFDecoder
30-
from tensorflow.contrib.seq2seq import dynamic_decode
3130
from tensorflow.python.framework import tensor_shape
3231
from tensorflow.python.util import nest
3332

3433
from texar.tf.core import layers
3534
from texar.tf.utils import utils
3635
from texar.tf.utils.mode import is_train_mode, is_train_mode_py
36+
from texar.tf.modules.decoders.dynamic_decode import dynamic_decode
3737
from texar.tf.module_base import ModuleBase
3838
from texar.tf.modules.decoders import rnn_decoder_helpers
3939
from texar.tf.utils.dtypes import is_callable

0 commit comments

Comments
 (0)