Skip to content

Commit

Permalink
Expanding DataPipe to support DataFrames (pytorch#71931)
Browse files Browse the repository at this point in the history
Differential Revision: [D37500516](https://our.internmc.facebook.com/intern/diff/D37500516)
Pull Request resolved: pytorch#71931
Approved by: https://github.com/ejguan
  • Loading branch information
VitalyFedyunin authored and pytorchmergebot committed Jul 8, 2022
1 parent 79a502f commit bcab525
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 55 deletions.
46 changes: 40 additions & 6 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,6 @@ def operations(df):
self.compare_capture_and_eager(operations)


@skipIf(True, "Fix DataFramePipes Tests")
class TestDataFramesPipes(TestCase):
"""
Most of test will fail if pandas instaled, but no dill available.
Expand All @@ -520,7 +519,9 @@ def test_capture(self):
dp_numbers = self._get_datapipe().map(lambda x: (x[0], x[1], x[1] + 3 * x[0]))
df_numbers = self._get_dataframes_pipe()
df_numbers['k'] = df_numbers['j'] + df_numbers.i * 3
self.assertEqual(list(dp_numbers), list(df_numbers))
expected = list(dp_numbers)
actual = list(df_numbers)
self.assertEquals(expected, actual)

@skipIfNoDataFrames
@skipIfNoDill
Expand All @@ -531,7 +532,7 @@ def test_shuffle(self):
dp_numbers = self._get_datapipe(range=1000)
df_result = [tuple(item) for item in df_numbers]
self.assertNotEqual(list(dp_numbers), df_result)
self.assertEqual(list(dp_numbers), sorted(df_result))
self.assertEquals(list(dp_numbers), sorted(df_result))

@skipIfNoDataFrames
@skipIfNoDill
Expand All @@ -541,20 +542,53 @@ def test_batch(self):
last_batch = df_numbers_list[-1]
self.assertEqual(4, len(last_batch))
unpacked_batch = [tuple(row) for row in last_batch]
self.assertEqual([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)
self.assertEquals([(96, 0), (97, 1), (98, 2), (99, 0)], unpacked_batch)

@skipIfNoDataFrames
@skipIfNoDill
def test_unbatch(self):
df_numbers = self._get_dataframes_pipe(range=100).batch(8).batch(3)
dp_numbers = self._get_datapipe(range=100)
self.assertEqual(list(dp_numbers), list(df_numbers.unbatch(2)))
self.assertEquals(list(dp_numbers), list(df_numbers.unbatch(2)))

@skipIfNoDataFrames
@skipIfNoDill
def test_filter(self):
df_numbers = self._get_dataframes_pipe(range=10).filter(lambda x: x.i > 5)
self.assertEqual([(6, 0), (7, 1), (8, 2), (9, 0)], list(df_numbers))
actual = list(df_numbers)
self.assertEquals([(6, 0), (7, 1), (8, 2), (9, 0)], actual)

@skipIfNoDataFrames
@skipIfNoDill
def test_collate(self):
def collate_i(column):
return column.sum()

def collate_j(column):
return column.prod()
df_numbers = self._get_dataframes_pipe(range=30).batch(3)
df_numbers = df_numbers.collate({'j': collate_j, 'i': collate_i})

expected_i = [3,
12,
21,
30,
39,
48,
57,
66,
75,
84, ]

actual_i = []
for i, j in df_numbers:
actual_i.append(i)
self.assertEqual(expected_i, actual_i)

actual_i = []
for item in df_numbers:
actual_i.append(item.i)
self.assertEqual(expected_i, actual_i)


class IDP_NoLen(IterDataPipe):
Expand Down
18 changes: 16 additions & 2 deletions torch/utils/data/datapipes/dataframe/dataframe_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_pandas = None
_WITH_PANDAS = None


def _try_import_pandas() -> bool:
try:
import pandas # type: ignore[import]
Expand All @@ -18,6 +19,7 @@ def _with_pandas() -> bool:
_WITH_PANDAS = _try_import_pandas()
return _WITH_PANDAS


class PandasWrapper:
@classmethod
def create_dataframe(cls, data, columns):
Expand All @@ -41,7 +43,7 @@ def is_column(cls, data):
def iterate(cls, data):
if not _with_pandas():
raise Exception("DataFrames prototype requires pandas to function")
for d in data:
for d in data.itertuples(index=False):
yield d

@classmethod
Expand All @@ -54,18 +56,25 @@ def concat(cls, buffer):
def get_item(cls, data, idx):
if not _with_pandas():
raise Exception("DataFrames prototype requires pandas to function")
return data[idx : idx + 1]
return data[idx: idx + 1]

@classmethod
def get_len(cls, df):
if not _with_pandas():
raise Exception("DataFrames prototype requires pandas to function")
return len(df.index)

@classmethod
def get_columns(cls, df):
if not _with_pandas():
raise Exception("DataFrames prototype requires pandas to function")
return list(df.columns.values.tolist())


# When you build own implementation just override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class)
default_wrapper = PandasWrapper


def get_df_wrapper():
return default_wrapper

Expand All @@ -85,6 +94,11 @@ def is_dataframe(data):
return wrapper.is_dataframe(data)


def get_columns(data):
wrapper = get_df_wrapper()
return wrapper.get_columns(data)


def is_column(data):
wrapper = get_df_wrapper()
return wrapper.is_column(data)
Expand Down
Loading

0 comments on commit bcab525

Please sign in to comment.