Fixed tensor2tensor language modeling decode #1282
Conversation
Changed reuse val from true to tf.AUTO_REUSE in top to allow for proper weight initialization
Fixed language modeling decoding from file (allowing variable batch size without padding and EOS insertion)
Thanks for this, and sorry for waiting so long! I had to remove parts, and in a follow-up I'll unfortunately remove more to make it work in the newest state. But thanks to this, LM decoding works now (at least at batch size 1 after my update). Great thanks!
@lukaszkaiser As it stands, this is only guaranteed to work with batch size 1 for LM problems (since batch sizes > 1 will cause inputs of lesser length to have padding appended, which would cause inaccurate decoding). This means decoding from file for LM will be unnecessarily slow. My fix was meant to group all sequences of the same length together so that padding never needs to be used. In my local testing, the _decode_batch_input_fn_no_padding function I wrote works fine, and the results of LM decoding are as expected (see the sketch below).
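For illustration, here's a minimal sketch of that length-bucketing idea. It is not the actual _decode_batch_input_fn_no_padding from this PR; batch_by_length and its signature are hypothetical. The point is that grouping encoded inputs by exact length makes every batch rectangular, so no padding or EOS token is ever appended.

```python
# Hypothetical sketch of length-bucketed batching; names and signature
# are assumptions, not the PR's actual code.
from collections import defaultdict

def batch_by_length(encoded_inputs, max_batch_size):
    """Yield (original_indices, batch) where every sequence in a batch
    has the same length, so no padding or EOS insertion is needed."""
    buckets = defaultdict(list)
    for idx, ids in enumerate(encoded_inputs):
        buckets[len(ids)].append((idx, ids))
    for length in sorted(buckets):
        items = buckets[length]
        for start in range(0, len(items), max_batch_size):
            chunk = items[start:start + max_batch_size]
            indices = [i for i, _ in chunk]
            batch = [ids for _, ids in chunk]  # rectangular: [batch, length]
            yield indices, batch
```

Decoded outputs can then be scattered back to their original order using the yielded indices.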
I should also mention my function requires _get_sorted_inputs(...) to be called first (which probably should be called regardless of whether or not my function is used, since sequences causing OOM can still be an issue in LM problems). If you're testing on t2t v1.12.0, it is no longer called for language modeling problems (_get_language_modeling_inputs(...) is called instead). To keep support for SubwordTokenizer, _get_language_modeling_inputs(...) can probably be called after _get_sorted_inputs(...) if not has_input. A sketch of this ordering follows.
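As a hedged sketch of the call order suggested above: the real helpers live in tensor2tensor/utils/decoding.py, and the stand-ins and signatures below are assumptions, defined only so the snippet runs on its own.

```python
# Stand-ins for the t2t helpers; the real signatures may differ.
def _get_sorted_inputs(filename, delimiter="\n"):
    # Sort inputs by length so OOM-prone long sequences are handled
    # predictably; return the sorted lines plus a key map to restore
    # the original order after decoding.
    with open(filename) as f:
        lines = [l for l in f.read().split(delimiter) if l]
    order = sorted(range(len(lines)), key=lambda i: len(lines[i]))
    sorted_keys = {new_i: old_i for new_i, old_i in enumerate(order)}
    return [lines[i] for i in order], sorted_keys

def _get_language_modeling_inputs(lines):
    # LM-specific preparation (tokenize without EOS); identity here.
    return lines

def prepare_inputs(filename, has_input):
    # Always sort first (length-based OOM matters for LM problems too),
    # then apply the LM-specific step only when the problem has no input.
    sorted_inputs, sorted_keys = _get_sorted_inputs(filename)
    if not has_input:
        sorted_inputs = _get_language_modeling_inputs(sorted_inputs)
    return sorted_inputs, sorted_keys
```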
@lukaszkaiser A better solution than what I previously implemented would be to left pad the timing signal/positional embedding as well as left pad the inputs. That way the shape is maintained, so batching can be done without regrouping by sequence length, and the timing signal isn't shifted.
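A minimal numpy sketch of that left-padding idea (an assumption about how it could look, not code from this PR): left pad the token ids, and offset the timing signal by the same amount so position 0 still lands on each sequence's first real token.

```python
import numpy as np

def left_pad(batch_ids, pad_id=0):
    # Left pad variable-length id lists into one rectangular array.
    max_len = max(len(ids) for ids in batch_ids)
    padded = np.full((len(batch_ids), max_len), pad_id, dtype=np.int64)
    for row, ids in enumerate(batch_ids):
        padded[row, max_len - len(ids):] = ids
    return padded

def left_padded_timing_signal(batch_ids, timing_signal):
    # timing_signal: a [max_len, depth] positional-embedding table,
    # assumed at least as long as the longest sequence in the batch.
    # Place timing_signal[0:len] over each row's non-pad region so the
    # signal isn't shifted by the amount of left padding.
    max_len = timing_signal.shape[0]
    out = np.zeros((len(batch_ids), max_len, timing_signal.shape[1]),
                   dtype=timing_signal.dtype)
    for row, ids in enumerate(batch_ids):
        offset = max_len - len(ids)
        out[row, offset:] = timing_signal[:len(ids)]
    return out
```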
* Changed reuse val from true to tf.AUTO_REUSE in top to allow for proper weight initialization
* Fixed decoding from file for language modeling problems (when has_input=False)
* nan
* Fixed language modeling decoding from file (allowing variable batch size without padding and EOS insertion)
* Delete misc.xml
* Delete modules.xml
* Delete tensor2tensor.iml
* Delete vcs.xml
* Delete workspace.xml
PiperOrigin-RevId: 228809213
Allows for variable batch sizes; removes padding and EOS token insertion
(See: #1227)