[DataFrame] Implement where #1989

Merged · 9 commits · May 9, 2018
126 changes: 123 additions & 3 deletions python/ray/dataframe/dataframe.py
@@ -4464,9 +4464,105 @@ def remote_func(df):

    def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
              errors='raise', try_cast=False, raise_on_error=None):
        raise NotImplementedError(
            "To contribute to Pandas on Ray, please visit "
            "github.com/ray-project/ray.")
        """Replaces values not meeting condition with values in other.

        Args:
            cond: A condition to be met, can be callable, array-like or a
                DataFrame.
            other: The value, Series, or DataFrame of values to use where
                cond is False.
            inplace: Whether or not to operate in place.
            axis: The axis to apply over. Only valid when a Series is passed
                as other.
            level: The MultiIndex level to apply over.
            errors: Whether or not to raise errors. Does nothing in Pandas.
            try_cast: Try to cast the result back to the input type.
            raise_on_error: Whether to raise on invalid datatypes (deprecated).

        Returns:
            A new DataFrame with the replaced values.
        """

        inplace = validate_bool_kwarg(inplace, 'inplace')

        if isinstance(other, pd.Series) and axis is None:
            raise ValueError("Must specify axis=0 or 1")

        if level is not None:
            raise NotImplementedError("Multilevel Index not yet supported on "
                                      "Pandas on Ray.")

        axis = pd.DataFrame()._get_axis_number(axis) if axis is not None else 0

        cond = cond(self) if callable(cond) else cond

        if not isinstance(cond, DataFrame):
            if not hasattr(cond, 'shape'):
                cond = np.asanyarray(cond)
            if cond.shape != self.shape:
                raise ValueError("Array conditional must be same shape as "
                                 "self")
            cond = DataFrame(cond, index=self.index, columns=self.columns)

        zipped_partitions = self._copartition(cond, self.index)
        args = (False, axis, level, errors, try_cast, raise_on_error)

        if isinstance(other, DataFrame):
            other_zipped = (v for k, v in self._copartition(other,
                                                            self.index))

            new_partitions = [_where_helper.remote(k, v, next(other_zipped),
                                                   self.columns, cond.columns,
                                                   other.columns, *args)
                              for k, v in zipped_partitions]

        # Series has to be treated specially because we're operating on row
        # partitions from here on.
        elif isinstance(other, pd.Series):
            if axis == 0:
                # Pandas determines which index to use based on axis.
                other = other.reindex(self.index)
                other.index = pd.RangeIndex(len(other))

                # Since we're working on row partitions, we have to partition
                # the Series based on the partitioning of self (since both
                # self and cond are co-partitioned by self).
                other_builder = []
                for length in self._row_metadata._lengths:
                    other_builder.append(other[:length])
                    other = other[length:]
                    # Resetting the index here ensures that we apply each part
                    # to the correct row within the partitions.
                    other.index = pd.RangeIndex(len(other))

                other = (obj for obj in other_builder)

                new_partitions = [_where_helper.remote(k, v, next(other,
                                                                  pd.Series()),
                                                       self.columns,
                                                       cond.columns,
                                                       None, *args)
                                  for k, v in zipped_partitions]
            else:
                other = other.reindex(self.columns)
                new_partitions = [_where_helper.remote(k, v, other,
                                                       self.columns,
                                                       cond.columns,
                                                       None, *args)
                                  for k, v in zipped_partitions]

        else:
            new_partitions = [_where_helper.remote(k, v, other, self.columns,
                                                   cond.columns, None, *args)
                              for k, v in zipped_partitions]

        if inplace:
            self._update_inplace(row_partitions=new_partitions,
                                 row_metadata=self._row_metadata,
                                 col_metadata=self._col_metadata)
        else:
            return DataFrame(row_partitions=new_partitions,
                             row_metadata=self._row_metadata,
                             col_metadata=self._col_metadata)

    def xs(self, key, axis=0, level=None, drop_level=True):
        raise NotImplementedError(
@@ -5093,3 +5189,27 @@ def _merge_columns(left_columns, right_columns, *args):
    return pd.DataFrame(columns=left_columns, index=[0], dtype='uint8').merge(
        pd.DataFrame(columns=right_columns, index=[0], dtype='uint8'),
        *args).columns


@ray.remote
def _where_helper(left, cond, other, left_columns, cond_columns,
                  other_columns, *args):

    left = pd.concat(ray.get(left.tolist()), axis=1)
    # We have to reset the index and columns here because we are coming
    # from blocks and the axes are set according to the blocks. We have
    # already correctly copartitioned everything, so there are no
    # correctness problems with doing this.
    left.reset_index(inplace=True, drop=True)
    left.columns = left_columns

    cond = pd.concat(ray.get(cond.tolist()), axis=1)
    cond.reset_index(inplace=True, drop=True)
    cond.columns = cond_columns

    if isinstance(other, np.ndarray):
        other = pd.concat(ray.get(other.tolist()), axis=1)
        other.reset_index(inplace=True, drop=True)
        other.columns = other_columns

    return left.where(cond, other, *args)
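
For context on the Series / axis=0 branch above: other is split to match the row-partition lengths of self, and each partition then defers to pandas' own DataFrame.where. The following standalone pandas sketch (illustrative only, not part of the diff; the frame, the partition lengths, and all variable names are hypothetical, and no Ray is involved) mirrors that flow:

import numpy as np
import pandas as pd

df = pd.DataFrame(np.arange(12).reshape(6, 2), columns=['a', 'b'])
cond = df % 2 == 0
other = pd.Series(range(10, 70, 10))   # one replacement value per row
lengths = [2, 2, 2]                    # hypothetical row-partition lengths

results = []
start = 0
for length in lengths:
    # Slice out each "partition" and reset the indexes, as the loop above
    # does, so the Series lines up with the rows of its partition.
    part = df.iloc[start:start + length].reset_index(drop=True)
    cond_part = cond.iloc[start:start + length].reset_index(drop=True)
    other_part = other.iloc[start:start + length].reset_index(drop=True)
    # Each partition applies pandas' own where, just as _where_helper does.
    results.append(part.where(cond_part, other_part, axis=0))
    start += length

result = pd.concat(results, ignore_index=True)
result.index = df.index  # reattach the original index
assert result.equals(df.where(cond, other, axis=0))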
32 changes: 29 additions & 3 deletions python/ray/dataframe/test/test_dataframe.py
@@ -3053,10 +3053,36 @@ def test_var(ray_df, pandas_df):


def test_where():
    ray_df = create_test_dataframe()
    pandas_df = pd.DataFrame(np.random.randn(100, 10),
                             columns=list('abcdefghij'))
    ray_df = rdf.DataFrame(pandas_df)

    with pytest.raises(NotImplementedError):
        ray_df.where(None)
    pandas_cond_df = pandas_df % 5 < 2
    ray_cond_df = ray_df % 5 < 2

    pandas_result = pandas_df.where(pandas_cond_df, -pandas_df)
    ray_result = ray_df.where(ray_cond_df, -ray_df)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    other = pandas_df.loc[3]

    pandas_result = pandas_df.where(pandas_cond_df, other, axis=1)
    ray_result = ray_df.where(ray_cond_df, other, axis=1)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    other = pandas_df['e']

    pandas_result = pandas_df.where(pandas_cond_df, other, axis=0)
    ray_result = ray_df.where(ray_cond_df, other, axis=0)

    assert ray_df_equals_pandas(ray_result, pandas_result)

    pandas_result = pandas_df.where(pandas_df < 2, True)
    ray_result = ray_df.where(ray_df < 2, True)

    assert ray_df_equals_pandas(ray_result, pandas_result)


def test_xs():
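As a quick reference (not part of the diff) for the axis semantics the Series cases in test_where exercise, in plain pandas with illustrative data:

import numpy as np
import pandas as pd

df = pd.DataFrame(np.arange(9).reshape(3, 3), columns=list('abc'))
cond = df > 4

row = df.loc[1]   # a row, like pandas_df.loc[3]; aligns with the columns
col = df['b']     # a column, like pandas_df['e']; aligns with the index

by_columns = df.where(cond, row, axis=1)  # (i, j) falls back to row[j]
by_rows = df.where(cond, col, axis=0)     # (i, j) falls back to col[i]
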
6 changes: 1 addition & 5 deletions python/ray/dataframe/utils.py
@@ -107,11 +107,7 @@ def to_pandas(df):
    Returns:
        A new pandas DataFrame.
    """
    if df._row_partitions is not None:
        pd_df = pd.concat(ray.get(df._row_partitions))
    else:
        pd_df = pd.concat(ray.get(df._col_partitions),
                          axis=1)
    pd_df = pd.concat(ray.get(df._row_partitions), copy=False)
    pd_df.index = df.index
    pd_df.columns = df.columns
    return pd_df
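
The simplified to_pandas above always stitches the row partitions back together and then reattaches the original axes. A minimal local sketch of the same idea (plain DataFrames standing in for Ray object refs; the data and column names are hypothetical):

import pandas as pd

# Hypothetical row partitions of a 4x2 frame.
row_partitions = [
    pd.DataFrame([[1, 2], [3, 4]]),
    pd.DataFrame([[5, 6], [7, 8]]),
]

pd_df = pd.concat(row_partitions, copy=False)
pd_df.index = pd.RangeIndex(len(pd_df))  # stands in for df.index
pd_df.columns = ['a', 'b']               # stands in for df.columns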