# Add struct column support to cudf_helpers (nv-morpheus#1538)
+ Update `make_table_from_table_info_data` in `cudf_helpers` to build a `table_metadata` object, which is used to add support for struct columns.
+ Update `DatasetManager.df_equal` to support list columns. It now uses pandas `equals` when `val_to_check` is a DataFrame or Series.
+ Add a new jsonlines file used to build a test DataFrame that includes struct and list columns.
+ Update MessageMeta and ControlMessage tests to use new test data.

Closes nv-morpheus#1527 
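
The sketch below is a hypothetical illustration (not part of this PR) of the kind of DataFrame the change is meant to handle: a cudf DataFrame with a struct column and a list column wrapped in a `MessageMeta`. Column names and values are made up; preserving the struct field names end-to-end assumes the C++ `cudf_helpers` path is in use.

```python
import cudf
from morpheus.messages import MessageMeta

# A DataFrame with a struct column ("user") and a list column ("tags"),
# the kind of data this change lets MessageMeta carry through the C++ bindings.
df = cudf.DataFrame({
    "id": [1, 2],
    "user": [{"name": "a", "age": 27}, {"name": "b", "age": 26}],
    "tags": [["x", "y"], ["z"]],
})

meta = MessageMeta(df)
copied = meta.copy_dataframe()

# With this change the struct field names ("name", "age") should survive the
# round trip instead of being dropped when the table is rebuilt from libcudf.
print(copied["user"].dtype)
```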

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - Eli Fajardo (https://github.com/efajardo-nv)
  - Michael Demoret (https://github.com/mdemoret-nv)
  - Anuradha Karuppiah (https://github.com/AnuradhaKaruppiah)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: nv-morpheus#1538
efajardo-nv authored Jun 18, 2024
1 parent 0b284b5 commit 46f842d
Showing 8 changed files with 464 additions and 46 deletions.
84 changes: 82 additions & 2 deletions morpheus/_lib/cudf_helpers.pyx
@@ -14,12 +14,15 @@
# limitations under the License.

import cudf
from cudf.core.dtypes import StructDtype

from libcpp.string cimport string
from libcpp.utility cimport move
from libcpp.vector cimport vector

from cudf._lib.column cimport Column
from cudf._lib.cpp.io.types cimport column_name_info
from cudf._lib.cpp.io.types cimport table_metadata
from cudf._lib.cpp.io.types cimport table_with_metadata
from cudf._lib.cpp.table.table_view cimport table_view
from cudf._lib.cpp.types cimport size_type
@@ -52,10 +55,27 @@ cdef public api:

data, index = data_from_unique_ptr(move(table.tbl), column_names=column_names, index_names=index_names)

return cudf.DataFrame._from_data(data, index)
df = cudf.DataFrame._from_data(data, index)

# Update the struct field names after the DataFrame is created
update_struct_field_names(df, table.metadata.schema_info)

return df

object make_table_from_table_info_data(TableInfoData table_info, object owner):

cdef table_metadata tbl_meta

num_index_cols_meta = 0
cdef column_name_info child_info
for i, name in enumerate(owner._column_names, num_index_cols_meta):
child_info.name = name.encode()
tbl_meta.schema_info.push_back(child_info)
_set_col_children_metadata(
owner[name]._column,
tbl_meta.schema_info[i]
)

index_names = None

if (table_info.index_names.size() > 0):
@@ -89,7 +109,11 @@ cdef public api:
import traceback
print("error while converting libcudf table to cudf dataframe:", traceback.format_exc())

return cudf.DataFrame._from_data(data, index)
df = cudf.DataFrame._from_data(data, index)

update_struct_field_names(df, tbl_meta.schema_info)

return df


TableInfoData make_table_info_data_from_table(object table):
@@ -173,3 +197,59 @@ cdef public api:
source_column_idx += 1

return dict(zip(column_names, data_columns)), index

cdef _set_col_children_metadata(Column col,
column_name_info& col_meta):
cdef column_name_info child_info
if isinstance(col.dtype, cudf.StructDtype):
for i, (child_col, name) in enumerate(
zip(col.children, list(col.dtype.fields))
):
child_info.name = name.encode()
col_meta.children.push_back(child_info)
_set_col_children_metadata(
child_col, col_meta.children[i]
)
elif isinstance(col.dtype, cudf.ListDtype):
for i, child_col in enumerate(col.children):
col_meta.children.push_back(child_info)
_set_col_children_metadata(
child_col, col_meta.children[i]
)
else:
return

cdef update_struct_field_names(
table,
vector[column_name_info]& schema_info
):
for i, (name, col) in enumerate(table._data.items()):
table._data[name] = update_column_struct_field_names(
col, schema_info[i]
)


cdef Column update_column_struct_field_names(
Column col,
column_name_info& info
):
cdef vector[string] field_names

if col.dtype != "object" and col.children:
children = list(col.children)
for i, child in enumerate(children):
children[i] = update_column_struct_field_names(
child,
info.children[i]
)
col.set_base_children(tuple(children))

if isinstance(col.dtype, StructDtype):
field_names.reserve(len(col.base_children))
for i in range(info.children.size()):
field_names.push_back(info.children[i].name)
col = col._rename_fields(
field_names
)

return col
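
A rough pure-Python sketch of the recursive walk `_set_col_children_metadata` performs above: it collects nested struct field names (and unnamed list children) so they can be reattached after the table round-trips through libcudf. The helper name `collect_field_names` is hypothetical, and the list handling is simplified relative to the Cython code, which walks every physical child column.

```python
import cudf


def collect_field_names(dtype):
    """Nested (name, children) pairs mirroring the column_name_info tree built in Cython."""
    if isinstance(dtype, cudf.StructDtype):
        return [(name, collect_field_names(child)) for name, child in dtype.fields.items()]
    if isinstance(dtype, cudf.ListDtype):
        # A list column contributes an unnamed child describing its element type
        return [(None, collect_field_names(dtype.element_type))]
    return []


df = cudf.DataFrame({"s": [{"a": {"x": 1}, "b": [2, 3]}]})
print(collect_field_names(df["s"].dtype))
# Expected shape: [('a', [('x', [])]), ('b', [(None, [])])]
```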
2 changes: 2 additions & 0 deletions morpheus/_lib/cudf_helpers/__init__.pyi
@@ -1,9 +1,11 @@
from __future__ import annotations
import morpheus._lib.cudf_helpers
import typing
from cudf.core.dtypes import StructDtype
import cudf

__all__ = [
"StructDtype",
"cudf"
]

181 changes: 181 additions & 0 deletions morpheus/messages/message_meta.py
@@ -18,6 +18,8 @@
import typing
import warnings

import cupy as cp
import numpy as np
import pandas as pd

import cudf
@@ -90,6 +92,18 @@ def __init__(self, df: DataFrameType) -> None:
self._mutex = threading.RLock()
self._df = df

def _get_col_indexers(self, df, columns: typing.Union[None, str, typing.List[str]] = None):

if (columns is None):
columns = df.columns.to_list()
elif (isinstance(columns, str)):
# Convert a single string into a list so all versions return tables, not series
columns = [columns]

column_indexer = df.columns.get_indexer_for(columns)

return column_indexer

@property
def df(self) -> DataFrameType:
msg = ("Warning the df property returns a copy, please use the copy_dataframe method or the mutable_dataframe "
@@ -201,6 +215,173 @@ def get_meta_range(self,
# If it's a str or list, this is the same
return self._df.loc[idx, columns]

@typing.overload
def get_data(self) -> cudf.DataFrame:
...

@typing.overload
def get_data(self, columns: str) -> cudf.Series:
...

@typing.overload
def get_data(self, columns: typing.List[str]) -> cudf.DataFrame:
...

def get_data(self, columns: typing.Union[None, str, typing.List[str]] = None):
"""
Return column values from the underlying DataFrame.
Parameters
----------
columns : typing.Union[None, str, typing.List[str]]
Input column names. Returns all columns if `None` is specified. When a string is passed, a `Series` is
returned. Otherwise, a `DataFrame` is returned.
Returns
-------
Series or DataFrame
Column values from the dataframe.
"""

with self.mutable_dataframe() as df:
column_indexer = self._get_col_indexers(df, columns=columns)

if (-1 in column_indexer):
missing_columns = [columns[i] for i, index_value in enumerate(column_indexer) if index_value == -1]
raise KeyError(f"Requested columns {missing_columns} does not exist in the dataframe")

if (isinstance(columns, str) and len(column_indexer) == 1):
# Make sure to return a series for a single column
column_indexer = column_indexer[0]

return df.iloc[:, column_indexer]

def set_data(self, columns: typing.Union[None, str, typing.List[str]], value):
"""
Set column values to the underlying DataFrame.
Parameters
----------
columns : typing.Union[None, str, typing.List[str]]
Input column names. Sets the value for the corresponding column names. If `None` is specified, all columns
will be used. If the column does not exist, a new one will be created.
value : Any
Value to apply to the specified columns. If a single value is passed, it will be broadcast to all rows. If a
`Series` or `DataFrame` is passed, rows will be matched by index.
"""

# Get exclusive access to the dataframe
with self.mutable_dataframe() as df:
# First try to set the values on just our slice if the columns exist
column_indexer = self._get_col_indexers(df, columns=columns)

# Check if the value is a cupy array and we have a pandas dataframe, convert to numpy
if (isinstance(value, cp.ndarray) and isinstance(df, pd.DataFrame)):
value = value.get()

# Check to see if we are adding a column. If so, we need to use df.loc instead of df.iloc
if (-1 not in column_indexer):

# If we only have one column, convert it to a series (broadcasts work with more types on a series)
if (len(column_indexer) == 1):
column_indexer = column_indexer[0]

try:
# Now update the slice
df.iloc[:, column_indexer] = value
except (ValueError, TypeError):
# Try this as a fallback. Works better for strings. See issue #286
df[columns].iloc[:] = value

else:
# Columns should never be empty if we get here
assert columns is not None

# cudf is really bad at adding new columns
if (isinstance(df, cudf.DataFrame)):

# TODO(morpheus#1487): This logic no longer works in CUDF 24.04.
# We should find a way to re-enable the no-dropped-index path as
# that should be more performant than dropping the index.
# # saved_index = None

# # # Check to see if we can use slices
# # if (not (df.index.is_unique and
# # (df.index.is_monotonic_increasing or df.index.is_monotonic_decreasing))):
# # # Save the index and reset
# # saved_index = df.index
# # df.reset_index(drop=True, inplace=True)

# # # Perform the update via slices
# # df.loc[df.index[row_indexer], columns] = value

# # # Reset the index if we changed it
# # if (saved_index is not None):
# # df.set_index(saved_index, inplace=True)

saved_index = df.index
df.reset_index(drop=True, inplace=True)
df.loc[df.index[:], columns] = value
df.set_index(saved_index, inplace=True)
else:
# Now set the slice
df.loc[:, columns] = value

def get_slice(self, start, stop):
"""
Returns a new MessageMeta with only the rows specified by start/stop.
Parameters
----------
start : int
Start offset address.
stop : int
Stop offset address.
Returns
-------
`MessageMeta`
A new `MessageMeta` with sliced offset and count.
"""

with self.mutable_dataframe() as df:
return MessageMeta(df.iloc[start:stop])

def _ranges_to_mask(self, df, ranges):
if isinstance(df, cudf.DataFrame):
zeros_fn = cp.zeros
else:
zeros_fn = np.zeros

mask = zeros_fn(len(df), bool)

for range_ in ranges:
mask[range_[0]:range_[1]] = True

return mask

def copy_ranges(self, ranges: typing.List[typing.Tuple[int, int]]):
"""
Perform a copy of the current message instance for the given `ranges` of rows.
Parameters
----------
ranges : typing.List[typing.Tuple[int, int]]
Rows to include in the copy in the form of `[(`start_row`, `stop_row`),...]`
The `stop_row` isn't included. For example to copy rows 1-2 & 5-7 `ranges=[(1, 3), (5, 8)]`
Returns
-------
`MessageMeta`
A new `MessageMeta` with only the rows specified by `ranges`.
"""

with self.mutable_dataframe() as df:
mask = self._ranges_to_mask(df, ranges=ranges)
return MessageMeta(df.loc[mask, :])


@dataclasses.dataclass(init=False)
class UserMessageMeta(MessageMeta, cpp_class=None):
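
A minimal usage sketch of the `MessageMeta` accessors added above, based only on the signatures and docstrings shown in this diff; the data is made up.

```python
import cudf
from morpheus.messages import MessageMeta

meta = MessageMeta(cudf.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]}))

series_a = meta.get_data("a")           # a single column name returns a Series
frame_ab = meta.get_data(["a", "b"])    # a list of names returns a DataFrame

meta.set_data("a", 0)                   # a scalar is broadcast to every row of "a"

head = meta.get_slice(0, 2)             # new MessageMeta with rows [0, 2)
subset = meta.copy_ranges([(0, 1), (2, 3)])  # new MessageMeta with rows 0 and 2
```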
3 changes: 3 additions & 0 deletions tests/_utils/dataset_manager.py
@@ -215,6 +215,9 @@ def df_equal(cls, df_to_check: typing.Union[pd.DataFrame, cdf.DataFrame], val_to
else:
val_to_check = cls._value_as_pandas(val_to_check, assert_is_pandas=False)

if (isinstance(val_to_check, (pd.DataFrame, pd.Series))):
return df_to_check.equals(val_to_check)

bool_df = df_to_check == val_to_check

return bool(bool_df.all(axis=None))
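
An illustrative note on why `df_equal` now falls back to pandas `equals` when the expected value is a DataFrame or Series: once a frame contains list-valued cells (as the new jsonlines test data does), whole-frame comparison is more robust than the previous element-wise `==` reduction. Column names below are made up.

```python
import pandas as pd

left = pd.DataFrame({"id": [1, 2], "tags": [["x", "y"], ["z"]]})
right = pd.DataFrame({"id": [1, 2], "tags": [["x", "y"], ["z"]]})

# DataFrame.equals compares the frames as a whole, including object-dtype
# columns whose cells are lists.
print(left.equals(right))  # True
```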
10 changes: 9 additions & 1 deletion tests/llm/nodes/test_llm_retriever_node_pipe.py
@@ -103,7 +103,15 @@ def test_pipeline_with_milvus(config: Config,
values = {'prompt': ["prompt1", "prompt2"]}
input_df = cudf.DataFrame(values)
expected_df = input_df.copy(deep=True)
expected_df["response"] = [[{'0': 27, '1': 2}, {'0': 26, '1': 1}], [{'0': 27, '1': 2}, {'0': 26, '1': 1}]]
expected_df["response"] = [[{
'age': 27, 'id': 2
}, {
'age': 26, 'id': 1
}], [{
'age': 27, 'id': 2
}, {
'age': 26, 'id': 1
}]]

task_payload = {"task_type": "llm_engine", "task_dict": {"input_keys": sorted(values.keys())}}

16 changes: 5 additions & 11 deletions tests/messages/test_control_message.py
@@ -21,10 +21,8 @@
import cupy as cp
import pytest

import cudf

from _utils.dataset_manager import DatasetManager
from morpheus import messages
# pylint: disable=morpheus-incorrect-lib-from-import
from morpheus.messages import TensorMemory

# pylint: disable=unsupported-membership-test
@@ -206,21 +204,17 @@ def test_control_message_set():
assert (control_message.has_task("load"))


def test_control_message_set_and_get_payload():
df = cudf.DataFrame({
'col1': [1, 2, 3, 4, 5],
'col2': [1.1, 2.2, 3.3, 4.4, 5.5],
'col3': ['a', 'b', 'c', 'd', 'e'],
'col4': [True, False, True, False, True]
})
def test_control_message_set_and_get_payload(dataset: DatasetManager):
df = dataset["test_dataframe.jsonlines"]

msg = messages.ControlMessage()
payload = messages.MessageMeta(df)
msg.payload(payload)

payload2 = msg.payload()
assert payload2 is not None
assert payload.df == payload2.df

DatasetManager.assert_df_equal(payload.df, payload2.df)


@pytest.mark.usefixtures("config_only_cpp")
