diff --git a/tensorflow_datasets/core/features/text/text_encoder.py b/tensorflow_datasets/core/features/text/text_encoder.py
index 6977d7014bb..6f853e96dd6 100644
--- a/tensorflow_datasets/core/features/text/text_encoder.py
+++ b/tensorflow_datasets/core/features/text/text_encoder.py
@@ -230,7 +230,9 @@ def __init__(self,
                oov_buckets=1,
                oov_token="UNK",
                lowercase=False,
-               tokenizer=None):
+               tokenizer=None,
+               strip_vocab=True,
+               decode_token_separator=" "):
     """Constructs a TokenTextEncoder.
 
     To load from a file saved with `TokenTextEncoder.save_to_file`, use
@@ -244,8 +246,14 @@ def __init__(self,
       lowercase: `bool`, whether to make all text and tokens lowercase.
       tokenizer: `Tokenizer`, responsible for converting incoming text into
         a list of tokens.
+      strip_vocab: `bool`, whether to strip whitespace from the beginning and
+        end of elements of `vocab_list`.
+      decode_token_separator: `str`, the string used to separate tokens when
+        decoding.
     """
-    self._vocab_list = [tf.compat.as_text(el).strip() for el in vocab_list]
+    self._vocab_list = [tf.compat.as_text(el) for el in vocab_list]
+    if strip_vocab:
+      self._vocab_list = [el.strip() for el in self._vocab_list]
     self._lowercase = lowercase
     if self._lowercase:
       self._vocab_list = [t.lower() for t in self._vocab_list]
@@ -261,6 +269,8 @@ def __init__(self,
     self._tokenizer = (tokenizer or Tokenizer(reserved_tokens=reserved_tokens))
     self._user_defined_tokenizer = tokenizer
 
+    self._decode_token_separator = decode_token_separator
+
   def encode(self, s):
     s = tf.compat.as_text(s)
     if self.lowercase:
@@ -286,7 +296,7 @@ def decode(self, ids):
         tokens.append(self._vocab_list[int_id])
       else:
         tokens.append(self._oov_token)
-    return " ".join(tokens)
+    return self._decode_token_separator.join(tokens)
 
   @property
   def vocab_size(self):
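
A minimal usage sketch of the two new constructor arguments, assuming the public `tfds.features.text.TokenTextEncoder` API; the character-level vocabulary and the `CharTokenizer` helper are illustrative assumptions, not part of this change:

```python
import tensorflow_datasets as tfds


class CharTokenizer(object):
  """Hypothetical tokenizer that splits a string into single characters."""

  def tokenize(self, s):
    return list(s)


# With strip_vocab=False the single-space entry survives as a real vocab token
# (previously it would have been stripped to an empty string), and
# decode_token_separator="" joins tokens back without inserting spaces.
encoder = tfds.features.text.TokenTextEncoder(
    vocab_list=["h", "e", "l", "o", " ", "w", "r", "d"],
    strip_vocab=False,
    decode_token_separator="",
    tokenizer=CharTokenizer())

ids = encoder.encode("hello world")
assert encoder.decode(ids) == "hello world"
```

With the defaults (`strip_vocab=True`, `decode_token_separator=" "`) the encoder behaves exactly as before the change, so existing callers should be unaffected.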