Merged

Changes from all commits · 29 commits
e7e78ce
Added type checking and changed how variables were read in from kwargs
williamma12 Aug 29, 2018
d686848
Merge branch 'rewrite_backend' into rewrite_backend
devin-petersohn Aug 30, 2018
cf9b05d
Updated sample to new architecture
williamma12 Aug 30, 2018
2b94c77
Fixed merge conflict
williamma12 Aug 30, 2018
8724956
Made test_sample more rigorous
williamma12 Aug 31, 2018
2ad1c3b
Removed 'default=' from kwargs.get's
williamma12 Aug 31, 2018
502bd0d
Updated eval to the new backend
williamma12 Aug 31, 2018
992a8e2
Added two more tests for eval
williamma12 Aug 31, 2018
7cbb17a
Updated memory_usage to new backend
williamma12 Sep 1, 2018
b144b3d
Updated info and memory_usage to the new backend
williamma12 Sep 2, 2018
4e7adda
Updated info and memory_usage to be standalone tests and updated the …
williamma12 Sep 2, 2018
8a69de5
Updated info to do only one pass
williamma12 Sep 2, 2018
0ea925f
Updated info to do everything in one run with DataFrame
williamma12 Sep 2, 2018
8a7b320
Update info to do everything in one run with Series
williamma12 Sep 2, 2018
8585f8f
Updated info to do everything in one run with DataFrame
williamma12 Sep 2, 2018
e273288
Updated to get everything working and moved appropriate parts to Data…
williamma12 Sep 2, 2018
9d0f224
Fixed merge conflicts
williamma12 Sep 2, 2018
5d52f7f
Removed extraneous print statement
williamma12 Sep 6, 2018
7571b3d
Moved dtypes stuff to data manager
williamma12 Sep 6, 2018
e70aaec
Fixed calculating dtypes to only doing a full_reduce instead of map_f…
williamma12 Sep 6, 2018
2d3a0a5
Merge branch 'data_manager_dtypes' into rewrite_backend
williamma12 Sep 6, 2018
59d6c41
Updated astype to new backend
williamma12 Sep 7, 2018
f53917f
Updated astype to new backend
williamma12 Sep 7, 2018
e1b01fc
Updated ftypes to new backend
williamma12 Sep 7, 2018
3c28c8f
Added dtypes argument to map_partitions
williamma12 Sep 7, 2018
f6a8ed6
Merge branch 'rewrite_backend' into rewrite_backend
williamma12 Sep 7, 2018
d71874a
Fixing dtypes
devin-petersohn Sep 7, 2018
09512dc
Merge branch 'dtypes_fix' into rewrite_backend
devin-petersohn Sep 7, 2018
5bf2517
Cleaning up dtype and merge issues
devin-petersohn Sep 7, 2018
156 changes: 120 additions & 36 deletions modin/data_management/data_manager.py
@@ -5,7 +5,8 @@
import numpy as np
import pandas
from pandas.compat import string_types
from pandas.core.dtypes.common import is_list_like
from pandas.core.dtypes.cast import find_common_type
from pandas.core.dtypes.common import (_get_dtype_from_object, is_list_like)

from .partitioning.partition_collections import BlockPartitions, RayBlockPartitions
from .partitioning.remote_partition import RayRemotePartition
@@ -16,11 +17,32 @@ class PandasDataManager(object):
with a Pandas backend. This logic is specific to Pandas.
"""

def __init__(self, block_partitions_object, index, columns):
def __init__(self, block_partitions_object, index, columns, dtypes=None):
assert isinstance(block_partitions_object, BlockPartitions)
self.data = block_partitions_object
self.index = index
self.columns = columns
if dtypes is not None:
self._dtype_cache = dtypes

# dtypes
_dtype_cache = None

def _get_dtype(self):
if self._dtype_cache is None:
map_func = lambda df: df.dtypes

def func(row):
return find_common_type(row.values)

self._dtype_cache = self.data.full_reduce(map_func, lambda df: df.apply(func, axis=0), 0)
self._dtype_cache.index = self.columns
return self._dtype_cache

def _set_dtype(self, dtypes):
self._dtype_cache = dtypes

dtypes = property(_get_dtype, _set_dtype)

# Index and columns objects
# These objects are currently not distributed.
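Note: the dtypes cache above is a map step (each block reports its own dtypes) followed by a reduce step that folds each column's dtypes into one common dtype. A minimal sketch of the same pattern on plain pandas frames standing in for remote partitions, assuming pandas's internal find_common_type helper:

import pandas
from pandas.core.dtypes.cast import find_common_type

# Two row-wise blocks standing in for remote partitions.
block1 = pandas.DataFrame({"a": [1, 2], "b": [1.0, 2.0]})
block2 = pandas.DataFrame({"a": [3.5, 4.5], "b": [3.0, 4.0]})

# Map step: each block reports its per-column dtypes (one row per block).
mapped = pandas.DataFrame([block.dtypes for block in (block1, block2)])

# Reduce step: fold each column of dtypes into a single common dtype.
dtypes = mapped.apply(lambda col: find_common_type(list(col.values)), axis=0)
print(dtypes)  # "a" widens int64/float64 -> float64; "b" stays float64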
@@ -53,7 +75,7 @@ def _set_columns(self, new_columns):
columns = property(_get_columns, _set_columns)
index = property(_get_index, _set_index)

# END Index and columns objects
# END Index, columns, and dtypes objects

def compute_index(self, axis, data_object, compute_diff=True):
"""Computes the index after a number of rows have been removed.
@@ -99,12 +121,12 @@ def _prepare_method(self, pandas_func, **kwargs):
def add_prefix(self, prefix):
cls = type(self)
new_column_names = self.columns.map(lambda x: str(prefix) + str(x))
return cls(self.data, self.index, new_column_names)
return cls(self.data, self.index, new_column_names, self.dtypes)

def add_suffix(self, suffix):
cls = type(self)
new_column_names = self.columns.map(lambda x: str(x) + str(suffix))
return cls(self.data, self.index, new_column_names)
return cls(self.data, self.index, new_column_names, self.dtypes)
# END Metadata modification methods

# Copy
@@ -113,7 +135,7 @@ def add_suffix(self):
# to prevent that.
def copy(self):
cls = type(self)
return cls(self.data.copy(), self.index.copy(), self.columns.copy())
return cls(self.data.copy(), self.index.copy(), self.columns.copy(), self.dtypes.copy())

# Append/Concat/Join (Not Merge)
# The append/concat/join operations should ideally never trigger remote
@@ -446,7 +468,8 @@ def reindex_builer(df, axis, old_labels, new_labels, **kwargs):
# Additionally this operation is often followed by an operation that
# assumes identical partitioning. Internally, we *may* change the
# partitioning during a map across a full axis.
return cls(self.map_across_full_axis(axis, func), new_index, new_columns)
new_data = self.map_across_full_axis(axis, func)
return cls(new_data, new_index, new_columns)

def reset_index(self, **kwargs):
cls = type(self)
Expand All @@ -461,7 +484,7 @@ def reset_index(self, **kwargs):
else:
# The copies here are to ensure that we do not give references to
# this object for the purposes of updates.
return cls(self.data.copy(), new_index, self.columns.copy())
return cls(self.data.copy(), new_index, self.columns.copy(), self.dtypes.copy())
# END Reindex/reset_index

# Transpose
@@ -548,45 +571,51 @@ def sum(self, **kwargs):

# Map partitions operations
# These operations are operations that apply a function to every partition.
def map_partitions(self, func):
def map_partitions(self, func, new_dtypes=None):
cls = type(self)
return cls(self.data.map_across_blocks(func), self.index, self.columns)
return cls(self.data.map_across_blocks(func), self.index, self.columns, new_dtypes)

def abs(self):
func = self._prepare_method(pandas.DataFrame.abs)
return self.map_partitions(func)
new_dtypes = pandas.Series([np.dtype('float64') for _ in self.columns], index=self.columns)
return self.map_partitions(func, new_dtypes=new_dtypes)

def applymap(self, func):
remote_func = self._prepare_method(pandas.DataFrame.applymap, func=func)
return self.map_partitions(remote_func)

def isin(self, **kwargs):
func = self._prepare_method(pandas.DataFrame.isin, **kwargs)
return self.map_partitions(func)
new_dtypes = pandas.Series([np.dtype('bool') for _ in self.columns], index=self.columns)
return self.map_partitions(func, new_dtypes=new_dtypes)

def isna(self):
func = self._prepare_method(pandas.DataFrame.isna)
return self.map_partitions(func)
new_dtypes = pandas.Series([np.dtype('bool') for _ in self.columns], index=self.columns)
return self.map_partitions(func, new_dtypes=new_dtypes)

def isnull(self):
func = self._prepare_method(pandas.DataFrame.isnull)
return self.map_partitions(func)
new_dtypes = pandas.Series([np.dtype('bool') for _ in self.columns], index=self.columns)
return self.map_partitions(func, new_dtypes=new_dtypes)

def negative(self, **kwargs):
func = self._prepare_method(pandas.DataFrame.__neg__, **kwargs)
return self.map_partitions(func)

def notna(self):
func = self._prepare_method(pandas.DataFrame.notna)
return self.map_partitions(func)
new_dtypes = pandas.Series([np.dtype('bool') for _ in self.columns], index=self.columns)
return self.map_partitions(func, new_dtypes=new_dtypes)

def notnull(self):
func = self._prepare_method(pandas.DataFrame.notnull)
return self.map_partitions(func)
new_dtypes = pandas.Series([np.dtype('bool') for _ in self.columns], index=self.columns)
return self.map_partitions(func, new_dtypes=new_dtypes)

def round(self, **kwargs):
func = self._prepare_method(pandas.DataFrame.round, **kwargs)
return self.map_partitions(func)
return self.map_partitions(func, new_dtypes=self.dtypes.copy())
# END Map partitions operations

# Column/Row partitions reduce operations
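Note: passing new_dtypes above lets elementwise operations declare their result dtypes up front, so a later .dtypes access never triggers remote work. The reason this is safe for the predicate-style methods, checked with plain pandas:

import numpy as np
import pandas

df = pandas.DataFrame({"a": [1, None], "b": [2.5, 3.5]})

# isna/isnull/notna/notnull/isin always produce booleans regardless of
# the input dtypes, so the result dtypes are known without computing.
expected = pandas.Series([np.dtype('bool')] * len(df.columns), index=df.columns)
assert df.isna().dtypes.equals(expected)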
@@ -773,7 +802,7 @@ def query_builder(df, **kwargs):
# Query removes rows, so we need to update the index
new_index = self.compute_index(0, new_data, True)

return cls(new_data, new_index, self.columns)
return cls(new_data, new_index, self.columns, self.dtypes)

def eval(self, expr, **kwargs):
cls = type(self)
@@ -821,7 +850,7 @@ def _cumulative_builder(self, func, **kwargs):
axis = kwargs.get("axis", 0)
func = self._prepare_method(func, **kwargs)
new_data = self.map_across_full_axis(axis, func)
return cls(new_data, self.index, self.columns)
return cls(new_data, self.index, self.columns, self.dtypes)

def cumsum(self, **kwargs):
return self._cumulative_builder(pandas.DataFrame.cumsum, **kwargs)
@@ -895,7 +924,7 @@ def mode(self, **kwargs):
# We build these intermediate objects to avoid depending directly on
# the underlying implementation.
final_data = cls(new_data, new_index, new_columns).map_across_full_axis(axis, lambda df: df.reindex(axis=axis, labels=final_labels))
return cls(final_data, new_index, new_columns)
return cls(final_data, new_index, new_columns, self.dtypes)

def fillna(self, **kwargs):
cls = type(self)
@@ -945,8 +974,8 @@ def rank(self, **kwargs):
new_columns = self.compute_index(1, new_data, True)
else:
new_columns = self.columns

return cls(new_data, self.index, new_columns)
new_dtypes = pandas.Series([np.float64 for _ in new_columns], index=new_columns)
return cls(new_data, self.index, new_columns, new_dtypes)

def diff(self, **kwargs):
cls = type(self)
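Note: rank above hard-codes float64 result dtypes, which matches pandas: ranks come back as floats even for integer input. A quick check:

import pandas

df = pandas.DataFrame({"a": [3, 1, 2]})
print(df.rank().dtypes)  # a    float64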
@@ -970,40 +999,40 @@ def head(self, n):
# ensure that we extract the correct data on each node. The index
# on a transposed manager is already set to the correct value, so
# we need to only take the head of that instead of re-transposing.
result = cls(self.data.transpose().take(1, n).transpose(), self.index[:n], self.columns)
result = cls(self.data.transpose().take(1, n).transpose(), self.index[:n], self.columns, self.dtypes)
result._is_transposed = True
else:
result = cls(self.data.take(0, n), self.index[:n], self.columns)
result = cls(self.data.take(0, n), self.index[:n], self.columns, self.dtypes)
return result

def tail(self, n):
cls = type(self)
# See head for an explanation of the transposed behavior
if self._is_transposed:
result = cls(self.data.transpose().take(1, -n).transpose(), self.index[-n:], self.columns)
result = cls(self.data.transpose().take(1, -n).transpose(), self.index[-n:], self.columns, self.dtypes)
result._is_transposed = True
else:
result = cls(self.data.take(0, -n), self.index[-n:], self.columns)
result = cls(self.data.take(0, -n), self.index[-n:], self.columns, self.dtypes)
return result

def front(self, n):
cls = type(self)
# See head for an explanation of the transposed behavior
if self._is_transposed:
result = cls(self.data.transpose().take(0, n).transpose(), self.index, self.columns[:n])
result = cls(self.data.transpose().take(0, n).transpose(), self.index, self.columns[:n], self.dtypes[:n])
result._is_transposed = True
else:
result = cls(self.data.take(1, n), self.index, self.columns[:n])
result = cls(self.data.take(1, n), self.index, self.columns[:n], self.dtypes[:n])
return result

def back(self, n):
cls = type(self)
# See head for an explanation of the transposed behavior
if self._is_transposed:
result = cls(self.data.transpose().take(0, -n).transpose(), self.index, self.columns[-n:])
result = cls(self.data.transpose().take(0, -n).transpose(), self.index, self.columns[-n:], self.dtypes[-n:])
result._is_transposed = True
else:
result = cls(self.data.take(1, -n), self.index, self.columns[-n:])
result = cls(self.data.take(1, -n), self.index, self.columns[-n:], self.dtypes[-n:])
return result
# End Head/Tail/Front/Back
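Note: the transposed branches above lean on the identity that the first n rows of a frame are the first n columns of its transpose. A sanity check in plain pandas:

import pandas

df = pandas.DataFrame({"a": range(5), "b": range(5, 10)})
n = 2
# take(1, n) on transposed data ~ take the first n columns, transpose back.
assert df.T.iloc[:, :n].T.equals(df.head(n))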

@@ -1026,13 +1055,14 @@ def to_pandas(self):
def from_pandas(cls, df, block_partitions_cls):
new_index = df.index
new_columns = df.columns
new_dtypes = df.dtypes

# Set the columns to RangeIndex for memory efficiency
df.index = pandas.RangeIndex(len(df.index))
df.columns = pandas.RangeIndex(len(df.columns))
new_data = block_partitions_cls.from_pandas(df)

return cls(new_data, new_index, new_columns)
return cls(new_data, new_index, new_columns, new_dtypes)

# __getitem__ methods
def getitem_single_key(self, key):
@@ -1062,7 +1092,8 @@ def getitem(df, internal_indices=[]):
# We can't just set the columns to key here because there may be
# multiple instances of a key.
new_columns = self.columns[numeric_indices]
return cls(result, self.index, new_columns)
new_dtypes = self.dtypes[numeric_indices]
return cls(result, self.index, new_columns, new_dtypes)

def getitem_row_array(self, key):
cls = type(self)
@@ -1076,7 +1107,7 @@ def getitem(df, internal_indices=[]):
# We can't just set the index to key here because there may be multiple
# instances of a key.
new_index = self.index[numeric_indices]
return cls(result, new_index, self.columns)
return cls(result, new_index, self.columns, self.dtypes)
# END __getitem__ methods

# __delitem__ and drop
@@ -1102,6 +1133,7 @@ def delitem(df, internal_indices=[]):

if columns is None:
new_columns = self.columns
new_dtypes = self.dtypes
else:
def delitem(df, internal_indices=[]):
return df.drop(columns=df.columns[internal_indices])
@@ -1111,7 +1143,9 @@ def delitem(df, internal_indices=[]):
# We can't use self.columns.drop with duplicate keys because in Pandas
# it throws an error.
new_columns = [self.columns[i] for i in range(len(self.columns)) if i not in numeric_indices]
return cls(new_data, new_index, new_columns)
new_dtypes = self.dtypes.drop(columns)

return cls(new_data, new_index, new_columns, new_dtypes)
# END __delitem__ and drop

# Insert
@@ -1129,9 +1163,54 @@ def insert(df, internal_indices=[]):

new_data = self.data.apply_func_to_select_indices_along_full_axis(0, insert, loc, keep_remaining=True)
new_columns = self.columns.insert(loc, column)
return cls(new_data, self.index, new_columns)

# Because a Pandas Series does not allow insert, we make a DataFrame
# and insert the new dtype that way.
temp_dtypes = pandas.DataFrame(self.dtypes).T
temp_dtypes.insert(loc, column, _get_dtype_from_object(value))
new_dtypes = temp_dtypes.iloc[0]

return cls(new_data, self.index, new_columns, new_dtypes)
# END Insert

# astype
# This method changes the types of select columns to the new dtype.
def astype(self, col_dtypes, errors='raise', **kwargs):
cls = type(self)

# Group the indices to update together and create the new dtypes series
dtype_indices = dict()
columns = col_dtypes.keys()
new_dtypes = self.dtypes.copy()

numeric_indices = list(self.columns.get_indexer_for(columns))

for i, column in enumerate(columns):
dtype = col_dtypes[column]
if dtype in dtype_indices.keys():
dtype_indices[dtype].append(numeric_indices[i])
else:
dtype_indices[dtype] = [numeric_indices[i]]
new_dtype = np.dtype(dtype)
if dtype != np.int32 and new_dtype == np.int32:
new_dtype = np.dtype('int64')
elif dtype != np.float32 and new_dtype == np.float32:
new_dtype = np.dtype('float64')
new_dtypes[column] = new_dtype

new_data = self.data
for dtype in dtype_indices.keys():

def astype(df, internal_indices=[]):
block_dtypes = dict()
for ind in internal_indices:
block_dtypes[df.columns[ind]] = dtype
return df.astype(block_dtypes)

# Chain the conversions so each dtype group is applied to the result
# of the previous one rather than to the original data.
new_data = new_data.apply_func_to_select_indices(0, astype, dtype_indices[dtype], keep_remaining=True)

return cls(new_data, self.index, self.columns, new_dtypes)
# END astype

# UDF (apply and agg) methods
# There is a wide range of behaviors that are supported, so a lot of the
# logic can get a bit convoluted.
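Note: two tricks above are worth unpacking. First, pandas.Series has no insert(), so the new column's dtype is spliced in through a one-row DataFrame; a minimal sketch:

import numpy as np
import pandas

dtypes = pandas.Series([np.dtype('int64'), np.dtype('float64')], index=['a', 'b'])
temp = pandas.DataFrame(dtypes).T        # one row; the labels become columns
temp.insert(1, 'c', np.dtype('bool'))    # positional insert via DataFrame
new_dtypes = temp.iloc[0]
print(list(new_dtypes.index))            # ['a', 'c', 'b']

Second, astype groups the target columns by destination dtype so each group converts in one pass; the same idea sketched on a plain pandas frame standing in for the partitioned data:

import pandas

df = pandas.DataFrame({'a': [1, 2], 'b': [3, 4], 'c': [5.0, 6.0]})
col_dtypes = {'a': 'float64', 'b': 'float64', 'c': 'int64'}

# Group column names by the dtype they are being cast to.
by_dtype = {}
for col, dtype in col_dtypes.items():
    by_dtype.setdefault(dtype, []).append(col)

# One astype call per dtype group rather than one per column.
for dtype, cols in by_dtype.items():
    df = df.astype({col: dtype for col in cols})
print(df.dtypes)  # a, b -> float64; c -> int64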
@@ -1161,6 +1240,11 @@ def _post_process_apply(self, result_data, axis):
# this logic here.
if len(columns) == 0:
series_result = result_data.to_pandas(False)
if not axis and len(series_result) == len(self.columns) and len(index) != len(series_result):
index = self.columns
elif axis and len(series_result) == len(self.index) and len(index) != len(series_result):
index = self.index

series_result.index = index
return series_result
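Note: the added length checks above handle a full-axis apply that collapses to a single Series: with axis=0 the surviving labels are the columns, with axis=1 the index. The pandas behavior being matched:

import pandas

df = pandas.DataFrame({"a": [1, 2], "b": [3, 4]})

# Reducing down axis 0 leaves one value per column ...
assert list(df.apply(lambda col: col.sum(), axis=0).index) == list(df.columns)
# ... and reducing across axis 1 leaves one value per row.
assert list(df.apply(lambda row: row.sum(), axis=1).index) == list(df.index)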

@@ -1220,7 +1304,7 @@ def callable_apply_builder(df, func, axis, index, *args, **kwargs):
result_data = self.map_across_full_axis(axis, func_prepared)

return self._post_process_apply(result_data, axis)
#END UDF
# END UDF

# Manual Partitioning methods (e.g. merge, groupby)
# These methods require some sort of manual partitioning due to their