Skip to content

Commit

Permalink
[WIP] Validating input_col for certain datapipes (pytorch#80267)
Browse files Browse the repository at this point in the history
Follow up from pytorch#79344.

Currently WIP due to multiple test failures.

Waiting for pytorch#80140 to land
Pull Request resolved: pytorch#80267
Approved by: https://github.com/ejguan
  • Loading branch information
bushshrub authored and pytorchmergebot committed Aug 24, 2022
1 parent 30a5583 commit 5c49c7b
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 4 deletions.
109 changes: 107 additions & 2 deletions test/test_datapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,24 @@ def fn_n1(d0, d1):
def fn_nn(d0, d1):
return -d0, -d1, d0 + d1

def fn_n1_def(d0, d1=1):
return d0 + d1

def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1

def fn_n1_pos(d0, d1, *args):
return d0 + d1

def fn_n1_sep_pos(d0, *args, d1):
return d0 + d1

def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1

p_fn_n1 = partial(fn_n1, d1=1)
p_fn_cmplx = partial(fn_cmplx, d2=2)

def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
for constr in (list, tuple):
datapipe = dp.iter.IterableWrapper([constr((0, 1, 2)), constr((3, 4, 5)), constr((6, 7, 8))])
Expand All @@ -1231,14 +1249,33 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
self.assertEqual(list(res_dp), list(ref_dp))
# Reset
self.assertEqual(list(res_dp), list(ref_dp))
_helper(lambda data: data, fn_n1_def, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), fn_n1_def, [0, 1], 2)
_helper(lambda data: data, p_fn_n1, 0, 1)
_helper(lambda data: data, p_fn_cmplx, 0, 1)
_helper(lambda data: (data[0], data[1], data[0] + data[1]), p_fn_cmplx, [0, 1], 2)
_helper(lambda data: (data[0] + data[1], ), fn_n1_pos, [0, 1, 2])

# Replacing with one input column and default output column
_helper(lambda data: (data[0], -data[1], data[2]), fn_11, 1)
_helper(lambda data: (data[0], (-data[1], data[1]), data[2]), fn_1n, 1)
# The index of input column is out of range
_helper(None, fn_1n, 3, error=IndexError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, 1, error=TypeError)
_helper(None, fn_n1, 1, error=ValueError)
_helper(None, fn_n1, [0, 1, 2], error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, 0, error=ValueError)
_helper(None, lambda d0, d1: d0 + d1, [0, 1, 2], error=ValueError)
_helper(None, fn_cmplx, 0, 1, ValueError)
_helper(None, fn_n1_pos, 1, error=ValueError)
_helper(None, fn_n1_def, [0, 1, 2], 1, error=ValueError)
_helper(None, p_fn_n1, [0, 1], error=ValueError)
_helper(None, fn_1n, [1, 2], error=ValueError)
# _helper(None, p_fn_cmplx, [0, 1, 2], error=ValueError)
_helper(None, fn_n1_sep_pos, [0, 1, 2], error=ValueError)
# Fn has keyword-only arguments
_helper(None, fn_n1_kwargs, 1, error=ValueError)
_helper(None, fn_cmplx, [0, 1], 2, ValueError)

# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: (data[1], data[2] + data[0]), fn_n1, [2, 0])
Expand Down Expand Up @@ -1278,6 +1315,28 @@ def fn_n1(d0, d1):
def fn_nn(d0, d1):
return -d0, -d1, d0 + d1

def fn_n1_def(d0, d1=1):
return d0 + d1

p_fn_n1 = partial(fn_n1, d1=1)

def fn_n1_pos(d0, d1, *args):
return d0 + d1

def fn_n1_kwargs(d0, d1, **kwargs):
return d0 + d1

def fn_kwonly(*, d0, d1):
return d0 + d1

def fn_has_nondefault_kwonly(d0, *, d1):
return d0 + d1

def fn_cmplx(d0, d1=1, *args, d2, **kwargs):
return d0 + d1

p_fn_cmplx = partial(fn_cmplx, d2=2)

# Prevent modification in-place to support resetting
def _dict_update(data, newdata, remove_idx=None):
_data = dict(data)
Expand All @@ -1304,13 +1363,33 @@ def _helper(ref_fn, fn, input_col=None, output_col=None, error=None):
# Reset
self.assertEqual(list(res_dp), list(ref_dp))

_helper(lambda data: data, fn_n1_def, 'x', 'y')
_helper(lambda data: data, p_fn_n1, 'x', 'y')
_helper(lambda data: data, p_fn_cmplx, 'x', 'y')
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}),
p_fn_cmplx, ["x", "y", "z"], "z")

_helper(lambda data: _dict_update(data, {"z": data["x"] + data["y"]}), fn_n1_def, ['x', 'y'], 'z')

_helper(None, fn_n1_pos, 'x', error=ValueError)
_helper(None, fn_n1_kwargs, 'x', error=ValueError)
# non-default kw-only args
_helper(None, fn_kwonly, ['x', 'y'], error=ValueError)
_helper(None, fn_has_nondefault_kwonly, ['x', 'y'], error=ValueError)
_helper(None, fn_cmplx, ['x', 'y'], error=ValueError)


# Replacing with one input column and default output column
_helper(lambda data: _dict_update(data, {"y": -data["y"]}), fn_11, "y")
_helper(lambda data: _dict_update(data, {"y": (-data["y"], data["y"])}), fn_1n, "y")
# The key of input column is not in dict
_helper(None, fn_1n, "a", error=KeyError)
# Unmatched input columns with fn arguments
_helper(None, fn_n1, "y", error=TypeError)
_helper(None, fn_n1, "y", error=ValueError)
_helper(None, fn_1n, ["x", "y"], error=ValueError)
_helper(None, fn_n1_def, ["x", "y", "z"], error=ValueError)
_helper(None, p_fn_n1, ["x", "y"], error=ValueError)
_helper(None, fn_n1_kwargs, ["x", "y", "z"], error=ValueError)
# Replacing with multiple input columns and default output column (the left-most input column)
_helper(lambda data: _dict_update(data, {"z": data["x"] + data["z"]}, ["x"]), fn_n1, ["z", "x"])
_helper(lambda data: _dict_update(
Expand Down Expand Up @@ -1508,6 +1587,32 @@ def _mul_filter_fn(a, b):
input_col_2_dp = tuple_input_ds.filter(_mul_filter_fn, input_col=[0, 2])
self.assertEqual(list(input_col_2_dp), [(d - 1, d, d + 1) for d in range(5)])

# invalid input col
with self.assertRaises(ValueError):
tuple_input_ds.filter(_mul_filter_fn, input_col=0)

p_mul_filter_fn = partial(_mul_filter_fn, b=1)
out = tuple_input_ds.filter(p_mul_filter_fn, input_col=0)
self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])

def _mul_filter_fn_with_defaults(a, b=1):
return a + b < 10

out = tuple_input_ds.filter(_mul_filter_fn_with_defaults, input_col=0)
self.assertEqual(list(out), [(d - 1, d, d + 1) for d in range(10)])

def _mul_filter_fn_with_kw_only(*, a, b):
return a + b < 10

with self.assertRaises(ValueError):
tuple_input_ds.filter(_mul_filter_fn_with_kw_only, input_col=0)

def _mul_filter_fn_with_kw_only_1_default(*, a, b=1):
return a + b < 10

with self.assertRaises(ValueError):
tuple_input_ds.filter(_mul_filter_fn_with_kw_only_1_default, input_col=0)

# __len__ Test: DataPipe has no valid len
with self.assertRaisesRegex(TypeError, r"has no len"):
len(filter_dp)
Expand Down
4 changes: 3 additions & 1 deletion torch/utils/data/datapipes/iter/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from torch.utils.data._utils.collate import default_collate
from torch.utils.data.datapipes.dataframe import dataframe_wrapper as df_wrapper
from torch.utils.data.datapipes.datapipe import IterDataPipe
from torch.utils.data.datapipes.utils.common import _check_unpickable_fn
from torch.utils.data.datapipes.utils.common import (_check_unpickable_fn,
validate_input_col)

__all__ = [
"CollatorIterDataPipe",
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
raise ValueError("`output_col` must be a single-element list or tuple")
output_col = output_col[0]
self.output_col = output_col
validate_input_col(fn, input_col)

def _apply_fn(self, data):
if self.input_col is None and self.output_col is None:
Expand Down
2 changes: 2 additions & 0 deletions torch/utils/data/datapipes/iter/selecting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
_check_unpickable_fn,
_deprecation_warning,
StreamWrapper,
validate_input_col
)


Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
self.drop_empty_batches = drop_empty_batches

self.input_col = input_col
validate_input_col(filter_fn, input_col)

def _apply_filter_fn(self, data) -> bool:
if self.input_col is None:
Expand Down
82 changes: 81 additions & 1 deletion torch/utils/data/datapipes/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.utils.data._utils.serialization import DILL_AVAILABLE

__all__ = [
"validate_input_col",
"StreamWrapper",
"get_file_binaries_from_pathnames",
"get_file_pathnames_from_root",
Expand All @@ -20,6 +21,86 @@
]


def validate_input_col(fn: Callable, input_col: Optional[Union[int, tuple, list]]):
"""
Checks that function used in a callable datapipe works with the input column
This simply ensures that the number of positional arguments matches the size
of the input column. The function must not contain any non-default
keyword-only arguments.
Examples:
>>> def f(a, b, *, c=1):
>>> return a + b + c
>>> def f_def(a, b=1, *, c=1):
>>> return a + b + c
>>> assert validate_input_col(f, [1, 2])
>>> assert validate_input_col(f_def, 1)
>>> assert validate_input_col(f_def, [1, 2])
Notes:
If the function contains variable positional (`inspect.VAR_POSITIONAL`) arguments,
for example, f(a, *args), the validator will accept any size of input column
greater than or equal to the number of positional arguments.
(in this case, 1).
Args:
fn: The function to check.
input_col: The input column to check.
Raises:
ValueError: If the function is not compatible with the input column.
"""
sig = inspect.signature(fn)
if isinstance(input_col, (list, tuple)):
input_col_size = len(input_col)
else:
input_col_size = 1

fn_name = str(fn)

pos = []
var_positional = False
non_default_kw_only = []

for p in sig.parameters.values():
if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
pos.append(p)
elif p.kind is inspect.Parameter.VAR_POSITIONAL:
var_positional = True
elif p.kind is inspect.Parameter.KEYWORD_ONLY:
if p.default is p.empty:
non_default_kw_only.append(p)
else:
continue

if len(non_default_kw_only) > 0:
raise ValueError(
f"The function {fn_name} takes {len(non_default_kw_only)} "
f"non-default keyword-only parameters, which is not allowed."
)

if len(sig.parameters) < input_col_size:
if not var_positional:
raise ValueError(
f"The function {fn_name} takes {len(sig.parameters)} "
f"parameters, but {input_col_size} are required."
)
else:
if len(pos) > input_col_size:
if any(p.default is p.empty for p in pos[input_col_size:]):
raise ValueError(
f"The function {fn_name} takes {len(pos)} "
f"positional parameters, but {input_col_size} are required."
)
elif len(pos) < input_col_size:
if not var_positional:
raise ValueError(
f"The function {fn_name} takes {len(pos)} "
f"positional parameters, but {input_col_size} are required."
)


def _is_local_fn(fn):
# Functions or Methods
if hasattr(fn, "__code__"):
Expand All @@ -33,7 +114,6 @@ def _is_local_fn(fn):
return "<locals>" in fn_type.__qualname__
return False


def _check_unpickable_fn(fn: Callable):
"""
Checks function is pickable or not. If it is a lambda or local function, a UserWarning
Expand Down

0 comments on commit 5c49c7b

Please sign in to comment.