Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add strings::repeat_strings API that can repeat each string a different number of times #8561

Merged
merged 37 commits into from
Jul 20, 2021
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
f0b5ff7
Add doxygen
ttnghia Jun 17, 2021
a61c1e7
Finish implementation
ttnghia Jun 18, 2021
2639f9a
Finish unit tests
ttnghia Jun 18, 2021
dbbfbf9
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jun 18, 2021
0dec873
Fix merge conflicts
ttnghia Jun 18, 2021
143f853
Rename parameter back to `input`
ttnghia Jun 21, 2021
b39bb06
Fix typo
ttnghia Jun 21, 2021
ae40591
Rewrite using type_dispatcher for different integer types
ttnghia Jun 21, 2021
8612f61
Fix comment typo
ttnghia Jun 21, 2021
3dfec42
Remove input check for int32_t data type
ttnghia Jun 21, 2021
d230498
Remove bool type from the expecting types for `repeat_times` data type
ttnghia Jun 21, 2021
f534372
Implement overflow check for the new API, as it can't be done outside…
ttnghia Jun 21, 2021
554d20d
Update doxygen
ttnghia Jun 21, 2021
d00ba01
Add typed tests for various types of `repeat_times` column
ttnghia Jun 21, 2021
5b5c2a4
Fix doxygen
ttnghia Jun 21, 2021
d488eca
Simplify overflow checking
ttnghia Jun 21, 2021
8498fc6
Just re-order code
ttnghia Jun 21, 2021
855a774
Add a parameter to allow turning on/off overflow checking
ttnghia Jun 24, 2021
e5f5db8
Implement overflow checking
ttnghia Jun 25, 2021
50f05fd
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jun 25, 2021
0bcf8d8
Redesign the API and update doxygen
ttnghia Jun 25, 2021
c7b7c3b
Add an optional column of pre-computed output strings offsets
ttnghia Jul 7, 2021
f6d7ee3
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 8, 2021
d22c4e5
Finish implementation
ttnghia Jul 8, 2021
6124e83
Fix JNI
ttnghia Jul 8, 2021
90517aa
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 8, 2021
7795f54
Cleanup
ttnghia Jul 8, 2021
9158e4f
Remove duplicate code
ttnghia Jul 9, 2021
5e37782
Add test for computing string output sizes that causes overflow
ttnghia Jul 9, 2021
4dfca75
Fix test build error
ttnghia Jul 9, 2021
95cb6c0
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 9, 2021
91c414e
Simple fix comment typo
ttnghia Jul 9, 2021
00ad095
Address review comments, fixing doxygen and some code improvements
ttnghia Jul 19, 2021
b7956e4
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 19, 2021
bebb9e6
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 19, 2021
5c554c1
Cleanup header
ttnghia Jul 20, 2021
b294f52
Merge branch 'branch-21.08' into repeat_strings
ttnghia Jul 20, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implement overflow check for the new API, as it can't be done outside…
… of cudf
  • Loading branch information
ttnghia committed Jun 21, 2021
commit f5343722ff2297aa54a2d61d3c1a1df933128bd5
73 changes: 66 additions & 7 deletions cpp/src/strings/repeat_strings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,50 @@ std::unique_ptr<column> repeat_strings(strings_column_view const& input,
}

namespace {
/**
* @brief Check if the size of the output strings column exceeds the maximum indexable value.
*/
template <typename SizeCompFunc>
bool is_output_overflow(SizeCompFunc size_comp_fn,
size_type strings_count,
rmm::cuda_stream_view stream)
{
// Firstly, compute size of the output strings.
auto string_sizes = rmm::device_uvector<size_type>(strings_count + 1, stream);
size_comp_fn.d_offsets = string_sizes.begin();

thrust::for_each_n(rmm::exec_policy(stream),
thrust::make_counting_iterator<size_type>(0),
strings_count,
size_comp_fn);

// Compute offsets of the output strings.
// If there is overflow then we write a marker value `invalid_offset` to the offsets.
static constexpr size_type invalid_offset = std::numeric_limits<size_type>::lowest();
static constexpr int64_t max_offset = static_cast<int64_t>(std::numeric_limits<size_type>::max());
thrust::exclusive_scan(rmm::exec_policy(stream),
string_sizes.begin(),
string_sizes.end(),
string_sizes.begin(),
size_type{0},
[] __device__(auto const lhs, auto const rhs) {
// If there was already overflow...
if (lhs == invalid_offset || rhs == invalid_offset) {
return invalid_offset;
}
auto const sum = static_cast<int64_t>(lhs) + static_cast<int64_t>(rhs);
if (sum > max_offset) { return invalid_offset; }
return static_cast<size_type>(sum);
});

auto const invalid_count =
thrust::count_if(rmm::exec_policy(stream),
string_sizes.begin(),
string_sizes.end(),
[] __device__(auto const offset) { return offset == invalid_offset; });
return invalid_count > 0;
}

/**
* @brief Functor to compute string sizes and repeat the input strings, where each string is
* repeated its own number of times.
Expand Down Expand Up @@ -212,7 +256,7 @@ struct compute_size_and_repeat_separately_fn {
repeat_times > 0 ? repeat_times * strings_dv.element<string_view>(idx).size_bytes() : 0;

// We will allocate memory for `d_validities` only when both input columns have nulls.
if (strings_has_nulls && rtimes_has_nulls) { d_validities[idx] = is_valid; }
if (d_validities) { d_validities[idx] = is_valid; }
}

if (d_chars && repeat_times > 0) {
Expand All @@ -238,19 +282,27 @@ struct compute_size_and_repeat_separately_fn {
struct dispatch_repeat_strings_separately_fn {
template <class T, std::enable_if_t<cudf::is_index_type<T>()>* = nullptr>
std::tuple<std::unique_ptr<column>, std::unique_ptr<column>, rmm::device_buffer, size_type>
operator()(strings_column_view const& input,
operator()(size_type strings_count,
strings_column_view const& input,
column_view const& repeat_times,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
{
auto const strings_count = input.size();
auto const strings_dv_ptr = column_device_view::create(input.parent(), stream);
auto const repeat_times_dv_ptr = column_device_view::create(repeat_times, stream);
auto const strings_has_nulls = input.has_nulls();
auto const rtimes_has_nulls = repeat_times.has_nulls();
auto const fn = compute_size_and_repeat_separately_fn<T>{
*strings_dv_ptr, *repeat_times_dv_ptr, strings_has_nulls, rtimes_has_nulls};

// Check for overflow (whether the total size of the output strings column exceeds
// numeric_limits<size_type>::max()).
if (is_output_overflow(fn, strings_count, stream)) {
CUDF_FAIL(
"Size of the output strings column exceeds the maximum value that can be indexed by "
"size_type (offset_type)");
}

// Repeat the strings in each row.
// Note that this cannot handle the cases when the size of the output column exceeds the maximum
// value that can be indexed by size_type (offset_type).
Expand Down Expand Up @@ -279,12 +331,13 @@ struct dispatch_repeat_strings_separately_fn {

template <class T, std::enable_if_t<!cudf::is_index_type<T>()>* = nullptr>
std::tuple<std::unique_ptr<column>, std::unique_ptr<column>, rmm::device_buffer, size_type>
operator()(strings_column_view const&,
operator()(size_type,
strings_column_view const&,
column_view const&,
rmm::cuda_stream_view,
rmm::mr::device_memory_resource*) const
{
CUDF_FAIL("repeat_strings is expecting an integer type for the `repeat_times` input column.");
CUDF_FAIL("repeat_strings expects an integer type for the `repeat_times` input column.");
}
};

Expand All @@ -300,8 +353,14 @@ std::unique_ptr<column> repeat_strings(strings_column_view const& input,
auto const strings_count = input.size();
if (strings_count == 0) { return make_empty_column(data_type{type_id::STRING}); }

auto [offsets_column, chars_column, null_mask, null_count] = type_dispatcher(
repeat_times.type(), dispatch_repeat_strings_separately_fn{}, input, repeat_times, stream, mr);
auto [offsets_column, chars_column, null_mask, null_count] =
type_dispatcher(repeat_times.type(),
dispatch_repeat_strings_separately_fn{},
strings_count,
input,
repeat_times,
stream,
mr);
return make_strings_column(strings_count,
std::move(offsets_column),
std::move(chars_column),
Expand Down