Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mdspan/mdarray template functions and utilities #601

Merged
merged 25 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
119d75b
working through is_device_mdspan
divyegala Mar 30, 2022
c205a40
Merge remote-tracking branch 'upstream/branch-22.06' into imp-22.06-l…
divyegala Mar 30, 2022
b6799f5
specializations for is_mdspan
divyegala Mar 30, 2022
3451b07
flatten and tests
divyegala Mar 30, 2022
1eb7565
checking for derived mdspan
divyegala Apr 1, 2022
fb2de54
working flatten for mdarray and mdspan
divyegala Apr 6, 2022
6539936
add copyright
divyegala Apr 6, 2022
eb400c6
working through reshape
divyegala Apr 19, 2022
27c5c19
finishing up host/device flatten with tests
divyegala Apr 19, 2022
ff282f3
working host reshape with tests
divyegala Apr 19, 2022
b8aec4c
working reshape for device arrays
divyegala Apr 19, 2022
ceb7dd4
adding docstrings
divyegala Apr 19, 2022
fa88e23
Apply suggestions from code review
divyegala Apr 20, 2022
e3f52ce
static extents tests, some variable renaming
divyegala Apr 21, 2022
95e66b9
Merge remote-tracking branch 'origin/fea-22.06-mdspan_utils' into fea…
divyegala Apr 21, 2022
de6bea6
merging upstream
divyegala Apr 21, 2022
4fc4a5d
working array_interface
divyegala Apr 21, 2022
f87d02f
small fix to docstring
divyegala Apr 21, 2022
30d878c
removing unneeded aliases from array_interface
divyegala Apr 21, 2022
65e54a5
using array_interface with CRTP
divyegala Apr 27, 2022
8659b39
Merge remote-tracking branch 'upstream/branch-22.06' into fea-22.06-m…
divyegala Apr 27, 2022
66ade4b
remove implicit operator converters of mdspan from mdarray
divyegala Apr 27, 2022
a407aee
remove flatten overloads
divyegala Apr 27, 2022
7848219
explicit cstddef include
divyegala Apr 28, 2022
30d0b8b
stddef.h
divyegala Apr 28, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
working host reshape with tests
  • Loading branch information
divyegala committed Apr 19, 2022
commit ff282f3d16f2796db9dd690f79a1175879662c68
82 changes: 45 additions & 37 deletions cpp/include/raft/mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
#include <rmm/cuda_stream_view.hpp>

namespace raft {

/**
 * @brief Alias for std::experimental::extents re-exported into the raft
 *        namespace, so callers can spell shapes as raft::extents<...>.
 */
template <size_t... ExtentsPack>
using extents = std::experimental::extents<ExtentsPack...>;

/**
* @brief C-Contiguous layout for mdarray and mdspan. Implies row-major and contiguous memory.
*/
Expand Down Expand Up @@ -814,39 +818,41 @@ auto reshape(host_mdspan_type h_mds, std::experimental::extents<Extents...> new_
{
RAFT_EXPECTS(h_mds.is_contiguous(), "Input must be contiguous.");

if (new_shape == h_mds.extents()) {
return h_mds;
} else if (new_shape.rank(1)) {
auto new_size = new_shape.extent(0);
RAFT_EXPECTS(new_size <= h_mds.size(),
"Cannot reshape array of size %ul into %ul",
h_mds.size(),
new_size());

if (new_size == 1) {
return make_host_scalar_view<typename host_mdspan_type::element_type>(h_mds.data());
} else {
return make_host_vector_view<typename host_mdspan_type::element_type,
typename host_mdspan_type::layout_type>(h_mds.data(), new_size);
}
} else if (new_shape.rank(2)) {
auto new_size = new_shape.extent(0) * new_shape.extent(1);
RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch");

return make_host_matrix_view<typename host_mdspan_type::element_type,
typename host_mdspan_type::layout_type>(
h_mds.data(), new_shape.extent(0), new_shape.extent(1));
} else {
size_t new_size = 1;
for (size_t i = 0; i < new_shape.rank(); ++i) {
new_size *= new_shape.extent(i);
}
RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch");

return detail::stdex::mdspan<typename host_mdspan_type::element_type,
decltype(new_shape),
typename host_mdspan_type::layout_type>(h_mds.data(), new_shape);
// if (new_shape == h_mds.extents()) {
// return h_mds;
// } else if (new_shape.rank() == 1) {
// auto new_size = new_shape.extent(0);
// RAFT_EXPECTS(new_size <= h_mds.size(),
// "Cannot reshape array of size %ul into %ul",
// h_mds.size(),
// new_size());

// if (new_size == 1) {
// return make_host_scalar_view<typename host_mdspan_type::element_type>(h_mds.data());
// } else {
// return make_host_vector_view<typename host_mdspan_type::element_type,
// typename host_mdspan_type::layout_type>(h_mds.data(),
// new_size);
// }
// } else if (new_shape.rank() == 2) {
// auto new_size = new_shape.extent(0) * new_shape.extent(1);
// RAFT_EXPECTS(new_size == h_mds.size(), "Cannot reshape array with size mismatch");

// return make_host_matrix_view<typename host_mdspan_type::element_type,
// typename host_mdspan_type::layout_type>(
// h_mds.data(), new_shape.extent(0), new_shape.extent(1));
// } else {
size_t new_size = 1;
for (size_t i = 0; i < new_shape.rank(); ++i) {
new_size *= new_shape.extent(i);
}
RAFT_EXPECTS(new_size <= h_mds.size(), "Cannot reshape array with size mismatch");

return detail::stdex::mdspan<typename host_mdspan_type::element_type,
decltype(new_shape),
typename host_mdspan_type::layout_type,
typename host_mdspan_type::accessor_type>(h_mds.data(), new_shape);
// }
}

// template <typename device_mdspan_type,
Expand All @@ -859,10 +865,12 @@ auto reshape(host_mdspan_type h_mds, std::experimental::extents<Extents...> new_
// d_mds.size());
// }

// template <typename mdarray_type, std::enable_if_t<is_mdarray_v<mdarray_type>>* = nullptr>
// auto reshape(const mdarray_type& mda)
// {
// return reshape(mda.view());
// }
/**
 * @brief Reshape an mdarray to the given extents by forwarding its view to the
 *        mdspan overload of reshape; the underlying data is not copied.
 *
 * @param mda       mdarray whose view() is reshaped
 * @param new_shape target extents (size constraints are enforced by the mdspan
 *                  overload of reshape)
 * @return the reshaped view over mda's data
 */
template <typename mdarray_type,
size_t... Extents,
std::enable_if_t<is_mdarray_v<mdarray_type>>* = nullptr>
auto reshape(const mdarray_type& mda, extents<Extents...> new_shape)
{
return reshape(mda.view(), new_shape);
}

} // namespace raft
67 changes: 63 additions & 4 deletions cpp/test/mdspan_utils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ TEST(MDSpan, TemplateAsserts) { test_template_asserts(); }

void test_host_flatten()
{
// flatten 3d host matrix
// flatten 3d host mdspan
{
using three_d_extents = stdex::extents<dynamic_extent, dynamic_extent, dynamic_extent>;
using three_d_mdarray = host_mdarray<int, three_d_extents>;
Expand Down Expand Up @@ -116,7 +116,7 @@ TEST(MDArray, HostFlatten) { test_host_flatten(); }
void test_device_flatten()
{
raft::handle_t handle{};
// flatten 3d host matrix
// flatten 3d device mdspan
{
using three_d_extents = stdex::extents<dynamic_extent, dynamic_extent, dynamic_extent>;
using three_d_mdarray = device_mdarray<int, three_d_extents>;
Expand All @@ -135,7 +135,7 @@ void test_device_flatten()
ASSERT_EQ(flat_view.extent(0), 27);
}

// flatten host vector
// flatten device vector
{
auto dv = make_device_vector<int>(27, handle.get_stream());
auto flat_view = flatten(dv.view());
Expand All @@ -146,7 +146,7 @@ void test_device_flatten()
ASSERT_EQ(dv.extent(0), flat_view.extent(0));
}

// flatten host scalar
// flatten device scalar
{
auto ds = make_device_scalar<int>(27, handle.get_stream());
auto flat_view = flatten(ds.view());
Expand All @@ -159,4 +159,63 @@ void test_device_flatten()

TEST(MDArray, DeviceFlatten) { test_device_flatten(); }

void test_host_reshape()
{
// reshape 3d host matrix to vector
{
using three_d_extents = stdex::extents<dynamic_extent, dynamic_extent, dynamic_extent>;
using three_d_mdarray = host_mdarray<int, three_d_extents>;

three_d_extents extents{3, 3, 3};
three_d_mdarray::container_policy_type policy;
three_d_mdarray mda{extents, policy};

auto flat_view = reshape(mda, raft::extents<dynamic_extent>{27});
// this confirms aliasing works as intended
static_assert(std::is_same_v<decltype(flat_view),
host_vector_view<typename decltype(flat_view)::element_type,
typename decltype(flat_view)::layout_type>>,
"types not the same");

ASSERT_EQ(flat_view.extents().rank(), 1);
ASSERT_EQ(flat_view.extent(0), 27);
}

// reshape 4d host matrix to 2d
{
using four_d_extents =
stdex::extents<dynamic_extent, dynamic_extent, dynamic_extent, dynamic_extent>;
using four_d_mdarray = host_mdarray<int, four_d_extents>;

four_d_extents extents{2, 2, 2, 2};
four_d_mdarray::container_policy_type policy;
four_d_mdarray mda{extents, policy};

auto matrix = reshape(mda, raft::extents<dynamic_extent, dynamic_extent>{4, 4});
// this confirms aliasing works as intended
static_assert(std::is_same_v<decltype(matrix),
host_matrix_view<typename decltype(matrix)::element_type,
typename decltype(matrix)::layout_type>>,
"types not the same");

ASSERT_EQ(matrix.extents().rank(), 2);
ASSERT_EQ(matrix.extent(0), 4);
ASSERT_EQ(matrix.extent(1), 4);
}

// shrink host vector
{
auto hv = make_host_vector<int>(27);
auto shrunk_vector = reshape(hv.view(), raft::extents<dynamic_extent>(20));

static_assert(std::is_same_v<decltype(hv.view()), decltype(shrunk_vector)>,
"types not the same");

ASSERT_EQ(hv.extents().rank(), shrunk_vector.extents().rank());
ASSERT_EQ(shrunk_vector.extent(0), 20);
}
}

TEST(MDArray, HostReshape) { test_host_reshape(); }

} // namespace raft