Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Add to_unicode_utf8() to text_encoder.py #1321

Merged
merged 1 commit into from
Jan 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/cnn_dailymail.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import os
import random
import tarfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -157,10 +156,7 @@ def fix_run_on_sents(line):
summary = []
reading_highlights = False
for line in tf.gfile.Open(story_file, "rb"):
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
line = fix_run_on_sents(line)
if not line:
continue
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/cola.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -83,10 +82,7 @@ def _maybe_download_corpora(self, tmp_dir):

def example_generator(self, filename):
for line in tf.gfile.Open(filename, "rb"):
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
_, label, _, sent = line.split("\t")
yield {
"inputs": sent,
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/mrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from __future__ import print_function

import os
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -95,10 +94,7 @@ def download_file(tdir, filepath, url):
def example_generator(self, filename, dev_ids, dataset_split):
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
l, id1, id2, s1, s2 = line.split("\t")
is_dev = [id1, id2] in dev_ids
if dataset_split == problem.DatasetSplit.TRAIN and is_dev:
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/multinli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import problem
Expand Down Expand Up @@ -87,10 +86,7 @@ def example_generator(self, filename):
label_list = self.class_labels(data_dir=None)
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
split_line = line.split("\t")
# Works for both splits even though dev has some extra human labels.
s1, s2 = split_line[8:10]
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/qnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -85,10 +84,7 @@ def example_generator(self, filename):
label_list = self.class_labels(data_dir=None)
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
_, s1, s2, l = line.split("\t")
inputs = [s1, s2]
l = label_list.index(l)
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/quora_qpairs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -84,10 +83,7 @@ def example_generator(self, filename):
skipped = 0
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
split_line = line.split("\t")
if len(split_line) < 6:
skipped += 1
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/rte.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -85,10 +84,7 @@ def example_generator(self, filename):
label_list = self.class_labels(data_dir=None)
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
_, s1, s2, l = line.split("\t")
inputs = [s1, s2]
l = label_list.index(l)
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/scitail.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import problem
Expand Down Expand Up @@ -83,10 +82,7 @@ def _maybe_download_corpora(self, tmp_dir):
def example_generator(self, filename):
label_list = self.class_labels(data_dir=None)
for line in tf.gfile.Open(filename, "rb"):
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
split_line = line.split("\t")
s1, s2 = split_line[:2]
l = label_list.index(split_line[2])
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/sst_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -84,10 +83,7 @@ def _maybe_download_corpora(self, tmp_dir):
def example_generator(self, filename):
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
sent, label = line.split("\t")
yield {
"inputs": sent,
Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/stanford_nli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import problem
Expand Down Expand Up @@ -84,10 +83,7 @@ def example_generator(self, filename):
label_list = self.class_labels(data_dir=None)
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
split_line = line.split("\t")
# Works for both splits even though dev has some extra human labels.
s1, s2 = split_line[5:7]
Expand Down
4 changes: 4 additions & 0 deletions tensor2tensor/data_generators/text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def to_unicode_ignore_errors(s):
return to_unicode(s, ignore_errors=True)


def to_unicode_utf8(s):
return unicode(s, "utf-8") if six.PY2 else s.decode("utf-8")


def strip_ids(ids, ids_to_strip):
"""Strip ids_to_strip from the end ids."""
ids = list(ids)
Expand Down
10 changes: 2 additions & 8 deletions tensor2tensor/data_generators/wiki_revision_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,12 @@
import re
import subprocess

import six

from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import text_encoder

import tensorflow as tf


def to_unicode(s):
return unicode(s, "utf-8") if six.PY2 else s.decode("utf-8")


def include_revision(revision_num, skip_factor=1.1):
"""Decide whether to include a revision.

Expand Down Expand Up @@ -118,7 +112,7 @@ def get_title(page):
assert start_pos != -1
assert end_pos != -1
start_pos += len("<title>")
return to_unicode(page[start_pos:end_pos])
return text_encoder.to_unicode_utf8(page[start_pos:end_pos])


def get_id(page):
Expand Down Expand Up @@ -257,7 +251,7 @@ def get_text(revision, strip=True):
ret = revision[end_tag_pos:end_pos]
if strip:
ret = strip_text(ret)
ret = to_unicode(ret)
ret = text_encoder.to_unicode_utf8(ret)
return ret


Expand Down
6 changes: 1 addition & 5 deletions tensor2tensor/data_generators/wnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import os
import zipfile
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_encoder
Expand Down Expand Up @@ -88,10 +87,7 @@ def _maybe_download_corpora(self, tmp_dir):
def example_generator(self, filename):
for idx, line in enumerate(tf.gfile.Open(filename, "rb")):
if idx == 0: continue # skip header
if six.PY2:
line = unicode(line.strip(), "utf-8")
else:
line = line.strip().decode("utf-8")
line = text_encoder.to_unicode_utf8(line.strip())
_, s1, s2, l = line.split("\t")
inputs = [s1, s2]
yield {
Expand Down