Fixes issues with NLP pipelines when data is not truncated (#316)
* Fixes typos in `TensorObject::read_element` as suggested in the description of #305
* In cases where the length of the output results does not match the length of the dataframe, the `seq_ids` array is used to reduce the rows: if, for example, rows 5, 6 & 7 of the output results map to row 5 in the dataframe, the column-wise max of those rows is stored in the response output (see the sketch below).
* Adds a new method `MatxUtil::reduce_max` to perform the reduction.

fixes #305
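
As a hedged, CPU-side sketch of the reduction described above (the actual implementation added by this PR runs on the GPU via MatX; all names here are illustrative only):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// output_rows: one probability row per inference row
// seq_ids:     maps each inference row to a dataframe row (non-decreasing)
std::vector<std::vector<float>> reduce_max_by_seq_id(const std::vector<std::vector<float>>& output_rows,
                                                     const std::vector<int>& seq_ids)
{
    const int num_out = seq_ids.back() - seq_ids.front() + 1;
    std::vector<std::vector<float>> reduced(num_out);

    for (std::size_t i = 0; i < output_rows.size(); ++i)
    {
        auto& dst = reduced[seq_ids[i] - seq_ids.front()];
        if (dst.empty())
        {
            dst = output_rows[i];  // first output row mapped to this dataframe row
        }
        else
        {
            for (std::size_t c = 0; c < dst.size(); ++c)
            {
                dst[c] = std::max(dst[c], output_rows[i][c]);  // column-wise max
            }
        }
    }
    return reduced;
}
```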

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #316
dagardner-nv authored Aug 8, 2022
1 parent bb200ab commit 99d767f
Showing 87 changed files with 714 additions and 71 deletions.
4 changes: 3 additions & 1 deletion morpheus/_lib/include/morpheus/objects/dev_mem_info.hpp
@@ -42,7 +42,9 @@ struct DevMemInfo
size_t offset;

/**
* TODO(Documentation)
* @brief Returns a raw pointer to the underlying buffer, offset by `offset`
*
* @return void*
*/
void *data() const;
};
10 changes: 7 additions & 3 deletions morpheus/_lib/include/morpheus/objects/tensor_object.hpp
@@ -476,8 +476,10 @@ struct TensorObject final
auto stride = this->get_stride();
auto shape = this->get_shape();

CHECK(shape.size() == N) << "Length of idx must match length of shape";

CHECK(std::transform_reduce(
stride.begin(), stride.end(), std::begin(idx), 0, std::logical_and<>(), std::less<>()))
shape.begin(), shape.end(), std::begin(idx), 1, std::logical_and<>(), std::greater<>()))
<< "Index is outsize of the bounds of the tensor. Index="
<< detail::array_to_str(std::begin(idx), std::begin(idx) + N)
<< ", Size=" << detail::array_to_str(shape.begin(), shape.end()) << "";
@@ -504,8 +506,10 @@
auto stride = this->get_stride();
auto shape = this->get_shape();

CHECK(std::transform_reduce(
stride.begin(), stride.end(), std::begin(idx), 0, std::logical_and<>(), std::less<>()))
CHECK(shape.size() == N) << "Length of idx must match length of shape";

CHECK(
std::transform_reduce(shape.begin(), shape.end(), std::begin(idx), 1, std::logical_and<>(), std::greater<>()))
<< "Index is outside the bounds of the tensor. Index="
<< detail::array_to_str(std::begin(idx), std::begin(idx) + N)
<< ", Size=" << detail::array_to_str(shape.begin(), shape.end()) << "";
20 changes: 20 additions & 0 deletions morpheus/_lib/include/morpheus/utilities/matx_util.hpp
@@ -24,6 +24,7 @@

#include <cstddef>
#include <memory>
#include <vector>

namespace morpheus {
struct MatxUtil
@@ -63,5 +64,24 @@ struct MatxUtil
const std::vector<TensorIndex> &stride,
double thresh_val,
bool by_row);

/**
* @brief Returns a buffer with `output_shape` containing the max value from the values in `input`, with rows
* grouped according to `seq_ids`.
* For example, given a hypothetical input of:
*
* input = [5, 2, 8, 9, 8, 2, 1]
* seq_ids = [0, 0, 0, 1, 2, 3, 3]
*
* this will return:
* [8, 9, 8, 2]
* @return std::shared_ptr<rmm::device_buffer>
*/
static std::shared_ptr<rmm::device_buffer> reduce_max(const DevMemInfo &input,
const std::vector<int32_t> &seq_ids,
size_t seq_id_offset,
const std::vector<int64_t> &input_shape,
const std::vector<int64_t> &input_stride,
const std::vector<int64_t> &output_shape);
};
} // namespace morpheus
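
A hedged usage sketch of the new method, mirroring the doc-comment example with a single column. The 7 x 1 float32 buffer, `TypeId::FLOAT32`, and the aggregate `DevMemInfo` initialization are assumptions based on patterns used elsewhere in this PR (and assume the relevant Morpheus/RMM headers are in scope); this is not code taken from the PR itself.

```cpp
std::vector<int32_t> seq_ids{0, 0, 0, 1, 2, 3, 3};

auto buffer = std::make_shared<rmm::device_buffer>(7 * sizeof(float), rmm::cuda_stream_per_thread);
// ... copy the input values {5, 2, 8, 9, 8, 2, 1} into `buffer` ...

auto reduced = MatxUtil::reduce_max(DevMemInfo{7, TypeId::FLOAT32, buffer, 0},
                                    seq_ids,
                                    /*seq_id_offset=*/0,
                                    /*input_shape=*/{7, 1},
                                    /*input_stride=*/{1, 1},
                                    /*output_shape=*/{4, 1});
// `reduced` should now hold {8, 9, 8, 2} on the device.
```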
17 changes: 1 addition & 16 deletions morpheus/_lib/src/messages/multi_inference.cpp
@@ -73,16 +73,14 @@ void MultiInferenceMessage::get_slice_impl(std::shared_ptr<MultiMessage> new_mess
std::size_t start,
std::size_t stop) const
{
CHECK(this->mess_count == this->count) << "At this time, mess_count and count must be the same for slicing";

auto sliced_message = DCHECK_NOTNULL(std::dynamic_pointer_cast<MultiInferenceMessage>(new_message));

sliced_message->offset = start;
sliced_message->count = stop - start;

// If we have more inference rows than message rows, we need to use the seq_ids to figure out the slicing. This
// will be slow and should be avoided at all costs
if (this->memory->has_input("seq_ids") && this->count != this->mess_count)
if (this->count != this->mess_count && this->memory->has_input("seq_ids"))
{
auto seq_ids = this->get_input("seq_ids");

@@ -146,26 +144,13 @@ std::size_t MultiInferenceMessageInterfaceProxy::count(MultiInferenceMessage &se
pybind11::object MultiInferenceMessageInterfaceProxy::get_input(MultiInferenceMessage &self, const std::string &name)
{
const auto &py_tensor = CupyUtil::tensor_to_cupy(self.get_input(name));

// // Need to get just our portion. TODO(MDD): THis should be handled in get_input
// py::object sliced = py_tensor[py::make_tuple(
// py::slice(py::int_(self.offset), py::int_(self.offset + self.count), py::none()),
// py::slice(py::none(), py::none(), py::none()))];

return py_tensor;
}

std::shared_ptr<MultiInferenceMessage> MultiInferenceMessageInterfaceProxy::get_slice(MultiInferenceMessage &self,
std::size_t start,
std::size_t stop)
{
// py::object seq_ids = CupyUtil::tensor_to_cupy(self.get_input("seq_ids"), m);

// int mess_start = seq_ids[py::make_tuple(start, 0)].attr("item")().cast<int>();
// int mess_stop = seq_ids[py::make_tuple(stop - 1, 0)].attr("item")().cast<int>() + 1;

// return std::make_shared<MultiInferenceMessage>(
// self.meta, mess_start, mess_stop - mess_start, self.memory, start, stop - start);
return self.get_slice(start, stop);
}
} // namespace morpheus
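
For reference, a hedged host-side sketch of the seq_ids-based slice mapping referred to in the comments above (and in the commented-out code removed here): a slice `[start, stop)` of inference rows corresponds to dataframe rows `[seq_ids[start], seq_ids[stop - 1] + 1)`. Names are illustrative only.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// seq_ids here is the first column of the seq_ids tensor: inference row -> dataframe row
std::pair<int32_t, int32_t> message_row_range(const std::vector<int32_t>& seq_ids,
                                              std::size_t start,
                                              std::size_t stop)
{
    return {seq_ids[start], seq_ids[stop - 1] + 1};
}
```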
8 changes: 6 additions & 2 deletions morpheus/_lib/src/messages/multi_response.cpp
@@ -91,13 +91,17 @@ void MultiResponseMessage::get_slice_impl(std::shared_ptr<MultiMessage> new_mess
std::size_t start,
std::size_t stop) const
{
CHECK(this->mess_count == this->count) << "At this time, mess_count and count must be the same for slicing";

auto sliced_message = DCHECK_NOTNULL(std::dynamic_pointer_cast<MultiResponseMessage>(new_message));

sliced_message->offset = start;
sliced_message->count = stop - start;

// Currently our output lengths should always match mess_count, and even if they didn't, we wouldn't have any way
// to associate rows in the output with rows in the dataframe. Note that on the input side we have the seq_ids
// array, but we don't have any equivalent for the output.
DCHECK(this->count == this->mess_count)
<< "Number of rows in response output does not match number of messages in DF";

// Pass onto the base
DerivedMultiMessage::get_slice_impl(new_message, start, stop);
}
97 changes: 79 additions & 18 deletions morpheus/_lib/src/stages/triton_inference.cpp
@@ -54,6 +54,12 @@ void InferenceClientStage__check_triton_errors(triton::client::Error status,
throw std::runtime_error(err_msg);
}
}

template <typename IndexT>
inline IndexT get_elem_count(const std::vector<IndexT> &shape)
{
return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>());
}
} // namespace

namespace morpheus {
@@ -87,43 +93,76 @@ InferenceClientStage::subscribe_fn_t InferenceClientStage::build_operator()

return input.subscribe(rxcpp::make_observer<sink_type_t>(
[this, &output, &client](sink_type_t x) {
auto reponse_memory = std::make_shared<ResponseMemory>(x->count);
// When our tensor lengths are longer than our dataframe, we will need to use the seq_ids
// array to look up how the values map back into the dataframe
const bool needs_seq_ids = x->mess_count != x->count;
auto reponse_memory = std::make_shared<ResponseMemory>(x->mess_count);

// Create the output memory blocks
for (auto &model_output : m_model_outputs)
{
auto total_shape = model_output.shape;
std::vector<TensorIndex> total_shape{model_output.shape.begin(), model_output.shape.end()};

// First dimension will always end up being the number of rows
total_shape[0] = x->count;

auto elem_count = std::accumulate(total_shape.begin(), total_shape.end(), 1, std::multiplies<>());
// First dimension will always end up being the number of rows in the dataframe
total_shape[0] = static_cast<TensorIndex>(x->mess_count);
auto elem_count = get_elem_count(total_shape);

// Create the output memory
auto output_buffer = std::make_shared<rmm::device_buffer>(
elem_count * model_output.datatype.item_size(), rmm::cuda_stream_per_thread);

reponse_memory->tensors[model_output.mapped_name] = Tensor::create(
std::move(output_buffer),
model_output.datatype,
std::vector<TensorIndex>{static_cast<int>(total_shape[0]), static_cast<int>(total_shape[1])},
std::vector<TensorIndex>{},
0);
std::move(output_buffer), model_output.datatype, total_shape, std::vector<TensorIndex>{}, 0);
}

// This will be the final output of all mini-batches
auto response = std::make_shared<MultiResponseProbsMessage>(
x->meta, x->mess_offset, x->mess_count, std::move(reponse_memory), 0, reponse_memory->count);

std::unique_ptr<std::vector<int32_t>> host_seq_ids{nullptr};
if (needs_seq_ids)
{
// Take a copy of the sequence IDs, allowing us to map rows in the response to rows in the dataframe.
// The output tensors we store in `reponse_memory` will all be of the same length as the
// dataframe. seq_ids has three columns, but we are only interested in the first column.
auto seq_ids = x->get_input("seq_ids");
const auto item_size = seq_ids.dtype().item_size();

host_seq_ids = std::make_unique<std::vector<int32_t>>(x->count);
SRF_CHECK_CUDA(cudaMemcpy2D(host_seq_ids->data(),
item_size,
seq_ids.data(),
seq_ids.stride(0) * item_size,
item_size,
host_seq_ids->size(),
cudaMemcpyDeviceToHost));
}

for (size_t i = 0; i < x->count; i += m_max_batch_size)
{
triton::client::InferInput *input1;

size_t start = i;
size_t stop = std::min(i + m_max_batch_size, x->count);

sink_type_t mini_batch_input = x->get_slice(start, stop);
source_type_t mini_batch_output = response->get_slice(start, stop);
sink_type_t mini_batch_input = x->get_slice(start, stop);

size_t out_start = start;
size_t out_stop = stop;
if (needs_seq_ids)
{
out_start = (*host_seq_ids)[out_start];
if (out_stop < host_seq_ids->size())
{
out_stop = (*host_seq_ids)[out_stop];
}
else
{
out_stop = x->mess_count;
}
}

source_type_t mini_batch_output = response->get_slice(out_start, out_stop);

// Iterate on the model inputs in case the model takes fewer inputs than the tensors available
std::vector<std::pair<std::shared_ptr<triton::client::InferInput>, std::vector<uint8_t>>>
@@ -199,21 +238,43 @@ InferenceClientStage::subscribe_fn_t InferenceClientStage::build_operator()
SRF_CHECK_CUDA(
cudaMemcpy(output_buffer->data(), output_ptr, output_ptr_size, cudaMemcpyHostToDevice));

if (needs_seq_ids && output_shape[0] != mini_batch_output->count)
{
// Since we are working with slices of both the input and the output, the seq_ids have
// already been applied to the output's start & stop, so we only need to reduce the
// response tensor when the size doesn't match our output
std::vector<int64_t> mapped_output_shape{output_shape};
mapped_output_shape[0] = mini_batch_output->count;

size_t element_count = get_elem_count(output_shape);

// Triton results are always in row-major as required by the KServe protocol
// https://github.com/kserve/kserve/blob/master/docs/predict-api/v2/required_api.md#tensor-data
std::vector<int64_t> stride{output_shape[1], 1};
output_buffer = MatxUtil::reduce_max(
DevMemInfo{element_count, model_output.datatype.type_id(), output_buffer, 0},
*host_seq_ids,
mini_batch_input->offset,
output_shape,
stride,
mapped_output_shape);
output_shape = std::move(mapped_output_shape);
}

// If we need to do logits, do that here
if (m_needs_logits)
{
size_t element_count =
std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<>());
output_buffer = MatxUtil::logits(
size_t element_count = get_elem_count(output_shape);
output_buffer = MatxUtil::logits(
DevMemInfo{element_count, model_output.datatype.type_id(), output_buffer, 0});
}

mini_batch_output->set_output(
model_output.mapped_name,
Tensor::create(std::move(output_buffer),
model_output.datatype,
std::vector<TensorIndex>{static_cast<int>(output_shape[0]),
static_cast<int>(output_shape[1])},
std::vector<TensorIndex>{static_cast<TensorIndex>(output_shape[0]),
static_cast<TensorIndex>(output_shape[1])},
std::vector<TensorIndex>{},
0));
}
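
The `cudaMemcpy2D` call earlier in this hunk extracts just the first column of the `seq_ids` tensor. A standalone sketch of the same strided-copy trick, with illustrative names and a hard-coded three-column row-major layout; error checking is omitted for brevity:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

#include <cuda_runtime.h>

std::vector<int32_t> copy_first_column(const int32_t* d_matrix, std::size_t num_rows)
{
    constexpr std::size_t num_cols  = 3;                // seq_ids has three columns
    constexpr std::size_t item_size = sizeof(int32_t);  // 4-byte elements

    std::vector<int32_t> host(num_rows);

    // width = one element per row, source pitch = one full device row,
    // destination pitch = one element (i.e. the host copy is contiguous)
    cudaMemcpy2D(host.data(),
                 item_size,             // dpitch
                 d_matrix,
                 num_cols * item_size,  // spitch
                 item_size,             // width in bytes
                 num_rows,              // height (number of rows)
                 cudaMemcpyDeviceToHost);
    return host;
}
```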
92 changes: 92 additions & 0 deletions morpheus/_lib/src/utilities/matx_util.cu
@@ -243,6 +243,45 @@ namespace morpheus {
}
};

struct MatxUtil__MatxReduceMax {
matx::index_t num_input_rows;
matx::index_t num_cols;
std::vector<matx::index_t> input_stride;
matx::index_t num_output_rows;
void *input_data;
void *output_data;
rmm::cuda_stream_view stream;

template<typename InputT, std::enable_if_t<!cudf::is_floating_point<InputT>()> * = nullptr>
void operator()(std::size_t start, std::size_t stop, int32_t output_idx) {
throw std::invalid_argument("Unsupported conversion");
}

template<typename InputT, std::enable_if_t<cudf::is_floating_point<InputT>()> * = nullptr>
void operator()(std::size_t start, std::size_t stop, int32_t output_idx) {
auto input_count = stop - start;
matx::tensorShape_t<2> input_shape({static_cast<matx::index_t>(input_count), num_cols});
matx::tensorShape_t<1> output_shape({num_cols});

matx::index_t output_stride[2] = {input_stride[0], input_stride[1]};
if (output_stride[0] == 1)
{
output_stride[1] = num_output_rows;
}

auto input_ptr = static_cast<InputT *>(input_data) + (start * input_stride[0]);
auto output_ptr = static_cast<InputT *>(output_data) + (output_idx * output_stride[0]);

matx::tensor_t<InputT, 2> input_tensor(input_ptr, input_shape, {input_stride[0], input_stride[1]});
matx::tensor_t<InputT, 1> output_tensor(output_ptr, output_shape, {output_stride[1]});

// We need to transpose the input such that rmax will reduce the rows
// Matx performs reductions over the innermost dimensions.
// see https://nvidia.github.io/MatX/api/reduce.html
matx::rmax(output_tensor, input_tensor.Permute({1, 0}), stream.value());
}
};

// Component public implementations
// ************ MatxUtil************************* //
std::shared_ptr<rmm::device_buffer> MatxUtil::cast(const DevMemInfo &input, TypeId output_type) {
@@ -337,4 +376,57 @@ namespace morpheus {

return output;
}

std::shared_ptr<rmm::device_buffer>
MatxUtil::reduce_max(const DevMemInfo &input,
const std::vector<int32_t> &seq_ids,
size_t seq_id_offset,
const std::vector<int64_t> &input_shape,
const std::vector<int64_t> &input_stride,
const std::vector<int64_t> &output_shape)
{
auto dtype = DType(input.type_id);
auto elem_size = dtype.item_size();
auto cudf_type = cudf::data_type{dtype.cudf_type_id()};
auto num_input_rows = input_shape[0];
auto num_input_cols = input_shape[1];

std::vector<matx::index_t> matx_stride{input_stride[0], input_stride[1]};
std::size_t output_element_count = output_shape[0] * output_shape[1];
std::size_t output_buff_size = elem_size * output_element_count;

DCHECK(output_element_count <= input.element_count) << "Output buffer size should be less than or equal to the input";
DCHECK(num_input_cols == output_shape[1]) << "Number of input and output columns must match";

auto output = std::make_shared<rmm::device_buffer>(output_buff_size,
input.buffer->stream(),
input.buffer->memory_resource());

MatxUtil__MatxReduceMax matx_reduce_max{num_input_rows, num_input_cols, matx_stride, output_shape[0], input.data(), output->data(), output->stream()};

std::size_t start = 0;
auto output_offset = seq_ids[seq_id_offset];
for (std::size_t i=0; i < num_input_rows; ++i)
{
auto idx = seq_ids[i+seq_id_offset];
if (idx != seq_ids[start+seq_id_offset])
{
cudf::type_dispatcher(cudf_type,
matx_reduce_max,
start,
i,
seq_ids[start+seq_id_offset]-output_offset);
start = i;
}
}

cudf::type_dispatcher(cudf_type,
matx_reduce_max,
start,
num_input_rows,
seq_ids[start+seq_id_offset]-output_offset);

srf::enqueue_stream_sync_event(output->stream()).get();
return output;
}
}
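
A hedged, host-only sketch of the grouping loop in `reduce_max` above: for the doc-comment example `seq_ids = {0, 0, 0, 1, 2, 3, 3}` it yields the `(start, stop, output_idx)` groups `(0,3,0)`, `(3,4,1)`, `(4,5,2)`, `(5,7,3)`, which are the arguments passed to the type dispatcher (with `seq_id_offset = 0`).

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> seq_ids{0, 0, 0, 1, 2, 3, 3};
    std::size_t start    = 0;
    const int out_offset = seq_ids[0];  // first mapped output row

    for (std::size_t i = 0; i < seq_ids.size(); ++i)
    {
        if (seq_ids[i] != seq_ids[start])
        {
            // Close the previous group of contiguous rows mapping to one output row
            std::printf("group [%zu, %zu) -> output row %d\n", start, i, seq_ids[start] - out_offset);
            start = i;
        }
    }
    // The last group runs to the end of the input
    std::printf("group [%zu, %zu) -> output row %d\n", start, seq_ids.size(), seq_ids[start] - out_offset);
}
```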
2 changes: 2 additions & 0 deletions morpheus/_lib/tests/CMakeLists.txt
@@ -19,6 +19,8 @@ list(APPEND CMAKE_MESSAGE_CONTEXT "tests")
add_executable(test_libmorpheus
# test_cuda.cu
test_main.cpp
test_matx_util.cpp
test_morpheus.cpp
test_multi_slices.cpp
test_tensor.cpp
test_type_util_detail.cpp