diff --git a/python/ray/dataframe/index_metadata.py b/python/ray/dataframe/index_metadata.py
index 50ba251a8ff9..63d96202ff9c 100644
--- a/python/ray/dataframe/index_metadata.py
+++ b/python/ray/dataframe/index_metadata.py
@@ -103,8 +103,6 @@ def _get_index(self):
         _IndexMetadata constructor for more details)
         """
         if isinstance(self._coord_df_cache, ray.local_scheduler.ObjectID):
-            if self._index_cache is None:
-                self._index_cache = pd.RangeIndex(len(self))
             return self._index_cache
         else:
             return self._coord_df_cache.index
@@ -128,6 +126,37 @@ def _set_index(self, new_index):
 
     index = property(_get_index, _set_index)
 
+    def _get_index_cache(self):
+        """Get the cached Index object, which may sometimes be an OID.
+
+        This will ray.get the Index object out of the Ray store lazily, such
+        that it is not grabbed until it is needed in the driver. This layer of
+        abstraction is important for allowing this object to be instantiated
+        with a remote Index object.
+
+        Returns:
+            The Index object in _index_cache.
+        """
+        if self._index_cache_validator is None:
+            self._index_cache_validator = pd.RangeIndex(len(self))
+        elif isinstance(self._index_cache_validator,
+                        ray.local_scheduler.ObjectID):
+            self._index_cache_validator = ray.get(self._index_cache_validator)
+
+        return self._index_cache_validator
+
+    def _set_index_cache(self, new_index):
+        """Sets the new index cache.
+
+        Args:
+            new_index: The Index to set the _index_cache to.
+        """
+        self._index_cache_validator = new_index
+
+    # _index_cache_validator is an extra layer of abstraction to allow the
+    # cache to accept ObjectIDs and ray.get them when needed.
+    _index_cache = property(_get_index_cache, _set_index_cache)
+
     def coords_of(self, key):
         """Returns the coordinates (partition, index_within_partition) of the
         provided key in the index. Can be called on its own or implicitly
diff --git a/python/ray/dataframe/utils.py b/python/ray/dataframe/utils.py
index 4c08490b9b6a..78d728f69023 100644
--- a/python/ray/dataframe/utils.py
+++ b/python/ray/dataframe/utils.py
@@ -112,7 +112,6 @@ def to_pandas(df):
     else:
         pd_df = pd.concat(ray.get(df._col_partitions), axis=1)
 
-    print(df.columns)
     pd_df.index = df.index
     pd_df.columns = df.columns
     return pd_df