Add dynamic_decode and Fix decoder issue #208
Conversation
`beam_search_decode`, which relies on `BeamSearchDecoder`, is not compatible with the new `dynamic_decode`. Can you update it to use `beam_search` in `texar.utils`, and add a standard beam search mode to `beam_search` (making sure it produces the same results as `BeamSearchDecoder`)? Then `beam_search_decode` should be deleted.

This PR would have a huge impact. Let's make sure it's bug-free by reproducing the example results in at least `transformer/`, `seq2seq_atten`, and `text_style_transfer`.
Please merge once the comments are fixed.
```diff
@@ -41,6 +42,8 @@ def _merge_beam_dim(tensor):
     Returns:
         Reshaped tensor of shape [A*B, ...]
     """
+    if not isinstance(tensor, tf.Tensor):
+        return tensor
```
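To illustrate what this guard does, here is a minimal sketch of the merge-beam-dim idea using numpy as a stand-in for TensorFlow (the function name and shapes follow the docstring above; the numpy version is purely illustrative, not the actual implementation):

```python
import numpy as np

def merge_beam_dim(tensor):
    # Non-tensor state attributes (e.g. an int `time` field on a decoder
    # state) are returned unchanged instead of being reshaped.
    if not isinstance(tensor, np.ndarray):
        return tensor
    shape = tensor.shape
    # Fold the batch dim A and beam dim B into one: [A, B, ...] -> [A*B, ...]
    return tensor.reshape((shape[0] * shape[1],) + shape[2:])

print(merge_beam_dim(3))                          # int passes through -> 3
print(merge_beam_dim(np.zeros((4, 5, 7))).shape)  # -> (20, 7)
```

The same pattern is applied across every leaf of a (possibly nested) decoder state, which is why the early return is needed for leaves that are not tensors.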
Why directly return `tensor` here instead of converting it into a `tf.Tensor` and continuing with the subsequent operations?
When we implement beam search decoding in `AttentionRNNDecoder` using this function, the `state` passed to it is an `AttentionWrapperState`, which contains an `int`-type attribute `time`. This check ensures that only the correct attributes in `state` are reshaped. We have similar checks in texar-pytorch.
Thanks for explaining. Can you add brief comments to make the code more readable? E.g., "if `tensor` is not a `tf.Tensor`, then `tensor` is xxx; return directly".
Resolve #199