Skip to content
This repository was archived by the owner on Nov 22, 2022. It is now read-only.

fix BlockShardedTSVDataSource #832

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions pytext/data/sources/tsv.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,18 @@ class BlockShardedTSV:
"""

def __init__(
self, file, field_names=None, delimiter="\t", block_id=0, num_blocks=1
self,
file,
field_names=None,
delimiter="\t",
quoted=False,
block_id=0,
num_blocks=1,
):
self.file = file
self.field_names = field_names
self.delimiter = delimiter
self.quoted = quoted
self.block_id = block_id
self.num_blocks = num_blocks
csv.field_size_limit(sys.maxsize)
Expand All @@ -226,7 +233,7 @@ def __iter__(self):
(line.replace("\0", "") for line in iter(self.file.readline, "")),
fieldnames=self.field_names,
delimiter=self.delimiter,
quoting=csv.QUOTE_NONE,
quoting=csv.QUOTE_MINIMAL if self.quoted else csv.QUOTE_NONE,
)
# iterate until we're at the end of segment
for line in reader:
Expand All @@ -244,14 +251,17 @@ def __init__(self, rank=0, world_size=1, **kwargs):
# weird python syntax to call init of ShardedDataSource
super(TSVDataSource, self).__init__(schema=self.schema)

def _init_tsv(self, field_names, delimiter, train_file, test_file, eval_file):
def _init_tsv(
self, field_names, delimiter, train_file, test_file, eval_file, quoted
):
def make_tsv(file, rank=0, world_size=1):
return BlockShardedTSV(
file,
field_names=field_names,
delimiter=delimiter,
block_id=rank,
num_blocks=world_size,
quoted=quoted,
)

self._train_tsv = (
Expand All @@ -260,7 +270,7 @@ def make_tsv(file, rank=0, world_size=1):
self._test_tsv = make_tsv(test_file) if test_file else []
self._eval_tsv = make_tsv(eval_file) if eval_file else []
self._train_unsharded = (
TSV(train_file, field_names=field_names, delimiter=delimiter)
TSV(train_file, field_names=field_names, delimiter=delimiter, quoted=quoted)
if train_file
else []
)
Expand Down
47 changes: 46 additions & 1 deletion pytext/data/test/tsv_data_source_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
from typing import List

from pytext.data.sources.data_source import SafeFileWrapper
from pytext.data.sources.tsv import SessionTSVDataSource, TSVDataSource
from pytext.data.sources.tsv import (
BlockShardedTSVDataSource,
SessionTSVDataSource,
TSVDataSource,
)
from pytext.utils.test import import_tests_module


Expand Down Expand Up @@ -151,3 +155,44 @@ def test_read_session_data(self):
self.assertEqual(["int31", "int32", "int33"], example["intent"])
self.assertEqual(["g31", "g32", "g33"], example["goals"])
self.assertEqual(["0", "1", "1"], example["label"])


class BlockShardedTSVDataSourceTest(unittest.TestCase):
def test_quoting(self):
"""
The text column of the first row of this file opens a quote but
does not close it.
"""
data_source = BlockShardedTSVDataSource(
train_file=SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")),
test_file=None,
eval_file=None,
field_names=["label", "text"],
schema={"text": str, "label": str},
)

data = list(data_source.train_unsharded)
self.assertEqual(4, len(data))

data = list(data_source.train)
self.assertEqual(4, len(data))

def test_bad_quoting(self):
"""
The text column of the first row of this file opens a quote but
does not close it.
"""
data_source = BlockShardedTSVDataSource(
train_file=SafeFileWrapper(tests_module.test_file("test_tsv_quoting.tsv")),
test_file=None,
eval_file=None,
field_names=["label", "text"],
schema={"text": str, "label": str},
quoted=True,
)

data = list(data_source.train_unsharded)
self.assertEqual(1, len(data))

data = list(data_source.train)
self.assertEqual(1, len(data))