Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for C++ impl for DeserializeStage and add missing get_info overloads to SlicedMessageMeta #1749

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions morpheus/_lib/include/morpheus/messages/meta.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ class MORPHEUS_EXPORT SlicedMessageMeta : public MessageMeta

TableInfo get_info() const override;

TableInfo get_info(const std::string& col_name) const override;

TableInfo get_info(const std::vector<std::string>& column_names) const override;

MutableTableInfo get_mutable_info() const override;

std::optional<std::string> ensure_sliceable_index() override;
Expand Down
14 changes: 14 additions & 0 deletions morpheus/_lib/src/messages/meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,20 @@ TableInfo SlicedMessageMeta::get_info() const
return this->m_data->get_info().get_slice(m_start, m_stop, m_column_names);
}

TableInfo SlicedMessageMeta::get_info(const std::string& col_name) const
{
auto full_info = this->m_data->get_info();

return full_info.get_slice(m_start, m_stop, {col_name});
}

TableInfo SlicedMessageMeta::get_info(const std::vector<std::string>& column_names) const
{
auto full_info = this->m_data->get_info();

return full_info.get_slice(m_start, m_stop, column_names);
}

MutableTableInfo SlicedMessageMeta::get_mutable_info() const
{
return this->m_data->get_mutable_info().get_slice(m_start, m_stop, m_column_names);
Expand Down
25 changes: 12 additions & 13 deletions morpheus/_lib/src/stages/preprocess_fil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,32 +229,32 @@ std::shared_ptr<MultiInferenceMessage> PreprocessFILStage<MultiMessage, MultiInf
template <>
std::shared_ptr<ControlMessage> PreprocessFILStage<ControlMessage, ControlMessage>::on_control_message(
std::shared_ptr<ControlMessage> x)

{
auto num_rows = x->payload()->get_info().num_rows();
auto df_meta = this->fix_bad_columns(x);
const auto num_rows = df_meta.num_rows();

auto packed_data =
std::make_shared<rmm::device_buffer>(m_fea_cols.size() * num_rows * sizeof(float), rmm::cuda_stream_per_thread);
auto df_meta = this->fix_bad_columns(x);

for (size_t i = 0; i < df_meta.num_columns(); ++i)
{
auto curr_col = df_meta.get_column(i);

auto curr_ptr = static_cast<float*>(packed_data->data()) + i * df_meta.num_rows();
auto curr_ptr = static_cast<float*>(packed_data->data()) + i * num_rows;

// Check if we are something other than float
if (curr_col.type().id() != cudf::type_id::FLOAT32)
{
auto float_data = cudf::cast(curr_col, cudf::data_type(cudf::type_id::FLOAT32))->release();

// Do the copy here before it goes out of scope
MRC_CHECK_CUDA(cudaMemcpy(
curr_ptr, float_data.data->data(), df_meta.num_rows() * sizeof(float), cudaMemcpyDeviceToDevice));
MRC_CHECK_CUDA(
cudaMemcpy(curr_ptr, float_data.data->data(), num_rows * sizeof(float), cudaMemcpyDeviceToDevice));
}
else
{
MRC_CHECK_CUDA(cudaMemcpy(curr_ptr,
curr_col.template data<float>(),
df_meta.num_rows() * sizeof(float),
cudaMemcpyDeviceToDevice));
MRC_CHECK_CUDA(cudaMemcpy(
curr_ptr, curr_col.template data<float>(), num_rows * sizeof(float), cudaMemcpyDeviceToDevice));
}
}

Expand All @@ -279,10 +279,9 @@ std::shared_ptr<ControlMessage> PreprocessFILStage<ControlMessage, ControlMessag
auto memory = std::make_shared<TensorMemory>(num_rows);
memory->set_tensor("input__0", std::move(input__0));
memory->set_tensor("seq_ids", std::move(seq_ids));
auto next = x;
next->tensors(memory);
x->tensors(memory);

return next;
return x;
}

template class PreprocessFILStage<MultiMessage, MultiInferenceMessage>;
Expand Down
30 changes: 27 additions & 3 deletions morpheus/_lib/tests/messages/test_sliced_message_meta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,28 @@
#include <filesystem> // for std::filesystem::path
#include <memory> // for shared_ptr
#include <utility> // for move
#include <vector>

using namespace morpheus;

using TestSlicedMessageMeta = morpheus::test::TestMessages; // NOLINT(readability-identifier-naming)

TEST_F(TestSlicedMessageMeta, TestCount)
std::shared_ptr<MessageMeta> create_test_meta()
{
// Test for issue #970
auto test_data_dir = test::get_morpheus_root() / "tests/tests_data";

auto input_file{test_data_dir / "filter_probs.csv"};

auto table = load_table_from_file(input_file);
auto index_col_count = prepare_df_index(table);

auto meta = MessageMeta::create_from_cpp(std::move(table), index_col_count);
return MessageMeta::create_from_cpp(std::move(table), index_col_count);
}

TEST_F(TestSlicedMessageMeta, TestCount)
{
// Test for issue #970
auto meta = create_test_meta();
EXPECT_EQ(meta->count(), 20);

SlicedMessageMeta sliced_meta(meta, 5, 15);
Expand All @@ -60,3 +66,21 @@ TEST_F(TestSlicedMessageMeta, TestCount)
EXPECT_EQ(p_meta->count(), 10);
EXPECT_EQ(p_meta->get_info().num_rows(), p_meta->count());
}

TEST_F(TestSlicedMessageMeta, TestGetInfo)
{
// Test for bug #1747 where get_info() wasn't being overridden for column overloads
auto meta = create_test_meta();
std::unique_ptr<MessageMeta> sliced_meta = std::make_unique<SlicedMessageMeta>(meta, 5, 15);

const auto num_rows = sliced_meta->count();

pybind11::gil_scoped_release no_gil;
EXPECT_EQ(num_rows, sliced_meta->get_info().num_rows());

std::string column_name("v1");
EXPECT_EQ(num_rows, sliced_meta->get_info(column_name).num_rows());

std::vector<std::string> column_names{"v1", "v2"};
EXPECT_EQ(num_rows, sliced_meta->get_info(column_names).num_rows());
}
27 changes: 19 additions & 8 deletions morpheus/stages/preprocess/deserialize_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import mrc

import morpheus._lib.stages as _stages
from morpheus.cli.register_stage import register_stage
from morpheus.config import Config
from morpheus.config import PipelineModes
Expand Down Expand Up @@ -63,8 +62,7 @@ def __init__(self,
c: Config,
*,
ensure_sliceable_index: bool = True,
message_type: typing.Union[typing.Literal[MultiMessage],
typing.Literal[ControlMessage]] = MultiMessage,
message_type: typing.Union[MultiMessage, ControlMessage] = MultiMessage,
dagardner-nv marked this conversation as resolved.
Show resolved Hide resolved
task_type: str = None,
task_payload: dict = None):
super().__init__(c)
Expand All @@ -81,10 +79,10 @@ def __init__(self,
self._task_type = task_type
self._task_payload = task_payload

if (self._message_type == ControlMessage):
if (self._message_type is ControlMessage):
if ((self._task_type is None) != (self._task_payload is None)):
raise ValueError("Both `task_type` and `task_payload` must be specified if either is specified.")
elif (self._message_type == MultiMessage):
elif (self._message_type is MultiMessage):
if (self._task_type is not None or self._task_payload is not None):
raise ValueError("Cannot specify `task_type` or `task_payload` for non-control messages.")
else:
Expand Down Expand Up @@ -113,14 +111,27 @@ def accepted_types(self) -> typing.Tuple:

def supports_cpp_node(self):
# Enable support by default
return False
return True

def compute_schema(self, schema: StageSchema):
schema.output_schema.set_type(self._message_type)

def _build_single(self, builder: mrc.Builder, input_node: mrc.SegmentObject) -> mrc.SegmentObject:
if (self.supports_cpp_node()):
out_node = _stages.DeserializeStage(builder, self.unique_name, self._batch_size)
if (self._build_cpp_node()):
import morpheus._lib.stages as _stages
if (self._message_type is ControlMessage):
out_node = _stages.DeserializeControlMessageStage(builder,
self.unique_name,
batch_size=self._batch_size,
ensure_sliceable_index=self._ensure_sliceable_index,
task_type=self._task_type,
task_payload=self._task_payload)
else:
out_node = _stages.DeserializeMultiMessageStage(builder,
self.unique_name,
batch_size=self._batch_size,
ensure_sliceable_index=self._ensure_sliceable_index)

builder.make_edge(input_node, out_node)
else:
module_loader = DeserializeLoaderFactory.get_instance(module_name=f"deserialize_{self.unique_name}",
Expand Down
Loading