Skip to content

Commit

Permalink
[DataFrame] Implement where (ray-project#1989)
Browse files Browse the repository at this point in the history
  • Loading branch information
devin-petersohn authored and robertnishihara committed May 9, 2018
1 parent d2c193e commit 72a3a6c
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 11 deletions.
126 changes: 123 additions & 3 deletions python/ray/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4464,9 +4464,105 @@ def remote_func(df):

def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
          errors='raise', try_cast=False, raise_on_error=None):
    """Replaces values not meeting condition with values in other.

    Args:
        cond: A condition to be met, can be callable, array-like or a
            DataFrame.
        other: A value or DataFrame of values to use for setting this.
        inplace: Whether or not to operate inplace.
        axis: The axis to apply over. Only valid when a Series is passed
            as other.
        level: The MultiLevel index level to apply over.
        errors: Whether or not to raise errors. Does nothing in Pandas.
        try_cast: Try to cast the result back to the input type.
        raise_on_error: Whether to raise invalid datatypes (deprecated).

    Returns:
        A new DataFrame with the replaced values, or None when
        inplace is True (self is updated instead).

    Raises:
        ValueError: If other is a Series and axis is None, or if a
            non-DataFrame cond does not have the same shape as self.
        NotImplementedError: If level is not None.
    """
    inplace = validate_bool_kwarg(inplace, 'inplace')

    if isinstance(other, pd.Series) and axis is None:
        raise ValueError("Must specify axis=0 or 1")

    if level is not None:
        raise NotImplementedError("Multilevel Index not yet supported on "
                                  "Pandas on Ray.")

    # Normalize axis names ('index'/'columns') to 0/1; default to 0.
    axis = pd.DataFrame()._get_axis_number(axis) if axis is not None else 0

    cond = cond(self) if callable(cond) else cond

    if not isinstance(cond, DataFrame):
        if not hasattr(cond, 'shape'):
            cond = np.asanyarray(cond)
        if cond.shape != self.shape:
            raise ValueError("Array conditional must be same shape as "
                             "self")
        cond = DataFrame(cond, index=self.index, columns=self.columns)

    zipped_partitions = self._copartition(cond, self.index)
    # The per-partition pandas call always runs with inplace=False; the
    # inplace request is honored once, at the end of this method.
    args = (False, axis, level, errors, try_cast, raise_on_error)

    if isinstance(other, DataFrame):
        other_zipped = (v for k, v in self._copartition(other,
                                                        self.index))

        new_partitions = [_where_helper.remote(k, v, next(other_zipped),
                                               self.columns, cond.columns,
                                               other.columns, *args)
                          for k, v in zipped_partitions]

    # Series has to be treated specially because we're operating on row
    # partitions from here on.
    elif isinstance(other, pd.Series):
        if axis == 0:
            # Pandas determines which index to use based on axis.
            other = other.reindex(self.index)
            other.index = pd.RangeIndex(len(other))

            # Since we're working on row partitions, we have to partition
            # the Series based on the partitioning of self (since both
            # self and cond are co-partitioned by self).
            other_builder = []
            for length in self._row_metadata._lengths:
                other_builder.append(other[:length])
                other = other[length:]
                # Resetting the index here ensures that we apply each part
                # to the correct row within the partitions.
                other.index = pd.RangeIndex(len(other))

            other = (obj for obj in other_builder)

            # An empty Series is the fallback if there are more
            # partitions than Series chunks.
            new_partitions = [_where_helper.remote(k, v, next(other,
                                                              pd.Series()),
                                                   self.columns,
                                                   cond.columns,
                                                   None, *args)
                              for k, v in zipped_partitions]
        else:
            # With axis=1 the Series lines up with the columns, so every
            # row partition receives the same reindexed Series.
            other = other.reindex(self.columns)
            new_partitions = [_where_helper.remote(k, v, other,
                                                   self.columns,
                                                   cond.columns,
                                                   None, *args)
                              for k, v in zipped_partitions]

    else:
        new_partitions = [_where_helper.remote(k, v, other, self.columns,
                                               cond.columns, None, *args)
                          for k, v in zipped_partitions]

    if inplace:
        self._update_inplace(row_partitions=new_partitions,
                             row_metadata=self._row_metadata,
                             col_metadata=self._col_metadata)
    else:
        return DataFrame(row_partitions=new_partitions,
                         row_metadata=self._row_metadata,
                         col_metadata=self._col_metadata)

def xs(self, key, axis=0, level=None, drop_level=True):
raise NotImplementedError(
Expand Down Expand Up @@ -5093,3 +5189,27 @@ def _merge_columns(left_columns, right_columns, *args):
return pd.DataFrame(columns=left_columns, index=[0], dtype='uint8').merge(
pd.DataFrame(columns=right_columns, index=[0], dtype='uint8'),
*args).columns


@ray.remote
def _where_helper(left, cond, other, left_columns, cond_columns,
                  other_columns, *args):
    """Runs pandas ``where`` on a single co-partitioned row partition.

    Args:
        left: The blocks of the row partition being filtered; must support
            ``.tolist()`` yielding values ``ray.get`` accepts.
        cond: The blocks of the matching condition partition, in the same
            form as left.
        other: The replacement values. Treated as blocks (same form as
            left) when other_columns is given; otherwise handed to pandas
            unchanged.
        left_columns: The column labels to set on the assembled left frame.
        cond_columns: The column labels to set on the assembled cond frame.
        other_columns: The column labels for other when it is a set of
            blocks, or None when other is a plain value.
        *args: Remaining positional arguments forwarded to
            ``pandas.DataFrame.where`` after cond and other.

    Returns:
        The pandas DataFrame produced by ``left.where(cond, other, *args)``.
    """
    left = pd.concat(ray.get(left.tolist()), axis=1)
    # We have to reset the index and columns here because we are coming
    # from blocks and the axes are set according to the blocks. We have
    # already correctly copartitioned everything, so there's no
    # correctness problems with doing this.
    left.reset_index(inplace=True, drop=True)
    left.columns = left_columns

    cond = pd.concat(ray.get(cond.tolist()), axis=1)
    cond.reset_index(inplace=True, drop=True)
    cond.columns = cond_columns

    # Column labels are supplied only when other is a set of blocks.
    # Checking for np.ndarray instead would misread a user-supplied array
    # of plain values as object IDs and pass it to ray.get.
    if other_columns is not None:
        other = pd.concat(ray.get(other.tolist()), axis=1)
        other.reset_index(inplace=True, drop=True)
        other.columns = other_columns

    return left.where(cond, other, *args)
32 changes: 29 additions & 3 deletions python/ray/dataframe/test/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3053,10 +3053,36 @@ def test_var(ray_df, pandas_df):


def test_where():
    """Checks DataFrame.where against pandas for each kind of ``other``."""
    pandas_df = pd.DataFrame(np.random.randn(100, 10),
                             columns=list('abcdefghij'))
    ray_df = rdf.DataFrame(pandas_df)

    pandas_cond_df = pandas_df % 5 < 2
    ray_cond_df = ray_df % 5 < 2

    # DataFrame as other.
    pandas_result = pandas_df.where(pandas_cond_df, -pandas_df)
    ray_result = ray_df.where(ray_cond_df, -ray_df)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    # Series as other, aligned with the columns (axis=1).
    other = pandas_df.loc[3]

    pandas_result = pandas_df.where(pandas_cond_df, other, axis=1)
    ray_result = ray_df.where(ray_cond_df, other, axis=1)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    # Series as other, aligned with the index (axis=0).
    other = pandas_df['e']

    pandas_result = pandas_df.where(pandas_cond_df, other, axis=0)
    ray_result = ray_df.where(ray_cond_df, other, axis=0)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    # Scalar as other.
    pandas_result = pandas_df.where(pandas_df < 2, True)
    ray_result = ray_df.where(ray_df < 2, True)

    assert ray_df_equals_pandas(ray_result, pandas_result)


def test_xs():
Expand Down
6 changes: 1 addition & 5 deletions python/ray/dataframe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,7 @@ def to_pandas(df):
Returns:
A new pandas DataFrame.
"""
if df._row_partitions is not None:
pd_df = pd.concat(ray.get(df._row_partitions))
else:
pd_df = pd.concat(ray.get(df._col_partitions),
axis=1)
pd_df = pd.concat(ray.get(df._row_partitions), copy=False)
pd_df.index = df.index
pd_df.columns = df.columns
return pd_df
Expand Down

0 comments on commit 72a3a6c

Please sign in to comment.