Add strings::repeat_strings API that can repeat each string a different number of times (#8561)

This work was requested by the Spark team and follows up on #8423, so that cudf's `strings::repeat_strings` fully supports the `StringRepeat` SQL expression in Apache Spark.

Note that this API requires an explicit overflow check on the size of the output strings column. The check is not trivial and cannot be performed outside of cudf.
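
To illustrate why the check needs care, here is a minimal host-side sketch of the size arithmetic, not the cudf implementation. Each output string occupies `string_bytes * repeat_times` bytes, so both the per-row product and the running total can exceed the 32-bit `size_type` range and are accumulated in 64 bits. The helper name `total_output_bytes`, the plain `std::vector` inputs, and the choice to reject only totals strictly greater than the maximum `int32_t` value are assumptions made for this example.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of cudf). Returns the total number of output
// bytes, or std::nullopt if the total would not fit in a 32-bit size_type.
// string_bytes[i] is the byte length of row i (0 for a null or empty row).
std::optional<std::int64_t> total_output_bytes(
  std::vector<std::int32_t> const& string_bytes,
  std::vector<std::int32_t> const& repeat_times)
{
  if (string_bytes.size() != repeat_times.size()) {
    throw std::invalid_argument("inputs must have the same number of rows");
  }
  constexpr std::int64_t limit = std::numeric_limits<std::int32_t>::max();

  std::int64_t total = 0;
  for (std::size_t i = 0; i < string_bytes.size(); ++i) {
    // A non-positive repeat count produces an empty string.
    if (repeat_times[i] <= 0) { continue; }
    // Widen before multiplying: the product of two int32_t values fits in int64_t.
    total += std::int64_t{string_bytes[i]} * repeat_times[i];
    // Stop early so the 64-bit accumulator itself can never overflow.
    if (total > limit) { return std::nullopt; }
  }
  return total;
}
```

With byte lengths `[2, 0, 0, 4]` and repeat counts `[1, 2, 3, 4]` the helper returns 18. Two rows of 2^30 bytes repeated once each exceed the limit and yield `std::nullopt`.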

This PR also rewrites some existing code, including renaming variables and updating doxygen comments.

### Follow-up work depending on this PR:
 * Benchmark: #8589
 * Java binding: #8572

Authors:
  - Nghia Truong (https://github.com/ttnghia)

Approvers:
  - Vukasin Milovanovic (https://github.com/vuule)
  - Jason Lowe (https://github.com/jlowe)
  - Conor Hoekstra (https://github.com/codereport)
  - David Wendt (https://github.com/davidwendt)

URL: #8561
ttnghia authored Jul 20, 2021
1 parent b490a32 commit 799f688
Showing 4 changed files with 892 additions and 122 deletions.
112 changes: 96 additions & 16 deletions cpp/include/cudf/strings/repeat_strings.hpp
@@ -30,10 +30,13 @@ namespace strings {
 /**
  * @brief Repeat the given string scalar by a given number of times.
  *
- * For a given string scalar, an output string scalar is generated by repeating the input string by
- * a number of times given by the @p `repeat_times` parameter. If `repeat_times` is not a positive
- * value, an empty (valid) string scalar will be returned. An invalid input scalar will always
- * result in an invalid output scalar regardless of the value of `repeat_times` parameter.
+ * An output string scalar is generated by repeating the input string by a number of times given by
+ * the @p `repeat_times` parameter.
+ *
+ * In special cases:
+ * - If @p `repeat_times` is not a positive value, an empty (valid) string scalar will be returned.
+ * - An invalid input scalar will always result in an invalid output scalar regardless of the
+ * value of @p `repeat_times` parameter.
  *
  * @code{.pseudo}
  * Example:
@@ -47,26 +50,29 @@ namespace strings {
  * (i.e., `input.size() * repeat_times > numeric_limits<size_type>::max()`).
  *
  * @param input The scalar containing the string to repeat.
- * @param repeat_times The number of times the `input` string is copied to the output.
+ * @param repeat_times The number of times the input string is repeated.
  * @param mr Device memory resource used to allocate the returned string scalar.
- * @return New string scalar in which the string is repeated from the input.
+ * @return New string scalar in which the input string is repeated.
  */
-std::unique_ptr<string_scalar> repeat_strings(
+std::unique_ptr<string_scalar> repeat_string(
   string_scalar const& input,
   size_type repeat_times,
   rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

 /**
  * @brief Repeat each string in the given strings column by a given number of times.
  *
- * For a given strings column, an output strings column is generated by repeating each string from
- * the input by a number of times given by the @p `repeat_times` parameter. If `repeat_times` is not
- * a positive value, all the rows of the output strings column will be an empty string. Any null row
- * will result in a null row regardless of the value of `repeat_times` parameter.
+ * An output strings column is generated by repeating each string from the input strings column by a
+ * number of times given by the @p `repeat_times` parameter.
+ *
+ * In special cases:
+ * - If @p `repeat_times` is not a positive number, a non-null input string will always result in
+ * an empty output string.
+ * - A null input string will always result in a null output string regardless of the value of the
+ * @p `repeat_times` parameter.
  *
- * Note that this function cannot handle the cases when the size of the output column exceeds the
- * maximum value that can be indexed by size_type (offset_type). In such situations, an exception
- * may be thrown, or the output result is undefined.
+ * The caller is responsible for checking the output column size will not exceed the maximum size of
+ * a strings column (number of total characters is less than the max size_type value).
  *
  * @code{.pseudo}
  * Example:
@@ -76,15 +82,89 @@ std::unique_ptr<string_scalar> repeat_strings(
  * @endcode
  *
  * @param input The column containing strings to repeat.
- * @param repeat_times The number of times each input string is copied to the output.
+ * @param repeat_times The number of times each input string is repeated.
  * @param mr Device memory resource used to allocate the returned strings column.
- * @return New column with concatenated results.
+ * @return New column containing the repeated strings.
  */
 std::unique_ptr<column> repeat_strings(
   strings_column_view const& input,
   size_type repeat_times,
   rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

+/**
+ * @brief Repeat each string in the given strings column by the numbers of times given in another
+ * numeric column.
+ *
+ * An output strings column is generated by repeating each of the input string by a number of times
+ * given by the corresponding row in a @p `repeat_times` numeric column. The computational time can
+ * be reduced if sizes of the output strings are known and provided.
+ *
+ * In special cases:
+ * - Any null row (from either the input strings column or the `repeat_times` column) will always
+ * result in a null output string.
+ * - If any value in the `repeat_times` column is not a positive number and its corresponding input
+ * string is not null, the output string will be an empty string.
+ *
+ * The caller is responsible for checking the output column size will not exceed the maximum size of
+ * a strings column (number of total characters is less than the max size_type value).
+ *
+ * @code{.pseudo}
+ * Example:
+ * strs = ['aa', null, '', 'bbc-']
+ * repeat_times = [ 1, 2, 3, 4 ]
+ * out = repeat_strings(strs, repeat_times)
+ * out is ['aa', null, '', 'bbc-bbc-bbc-bbc-']
+ * @endcode
+ *
+ * @throw cudf::logic_error if the input `repeat_times` column has data type other than integer.
+ * @throw cudf::logic_error if the input columns have different sizes.
+ *
+ * @param input The column containing strings to repeat.
+ * @param repeat_times The column containing numbers of times that the corresponding input strings
+ * are repeated.
+ * @param output_strings_sizes The optional column containing pre-computed sizes of the output
+ * strings.
+ * @param mr Device memory resource used to allocate the returned strings column.
+ * @return New column containing the repeated strings.
+ */
+std::unique_ptr<column> repeat_strings(
+  strings_column_view const& input,
+  column_view const& repeat_times,
+  std::optional<column_view> output_strings_sizes = std::nullopt,
+  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
+
+/**
+ * @brief Compute sizes of the output strings if each string in the input strings column
+ * is repeated by the numbers of times given in another numeric column.
+ *
+ * The output column storing string output sizes is not nullable. These string sizes are
+ * also summed up and returned (in an `int64_t` value), which can be used to detect if the input
+ * strings column can be safely repeated without data corruption due to overflow in string indexing.
+ *
+ * @code{.pseudo}
+ * Example:
+ * strs = ['aa', null, '', 'bbc-']
+ * repeat_times = [ 1, 2, 3, 4 ]
+ * [output_sizes, total_size] = repeat_strings_output_sizes(strs, repeat_times)
+ * out is [2, 0, 0, 16], and total_size = 18
+ * @endcode
+ *
+ * @throw cudf::logic_error if the input `repeat_times` column has data type other than integer.
+ * @throw cudf::logic_error if the input columns have different sizes.
+ *
+ * @param input The column containing strings to repeat.
+ * @param repeat_times The column containing numbers of times that the corresponding input strings
+ * are repeated.
+ * @param mr Device memory resource used to allocate the returned strings column.
+ * @return A pair with the first item is an int32_t column containing sizes of the output strings,
+ * and the second item is an int64_t number containing the total sizes (in bytes) of the
+ * output strings column.
+ */
+std::pair<std::unique_ptr<column>, int64_t> repeat_strings_output_sizes(
+  strings_column_view const& input,
+  column_view const& repeat_times,
+  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());
+
 /** @} */ // end of doxygen group
 } // namespace strings
 } // namespace cudf
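
The null and non-positive-count rules documented in the header above can be modelled on the host as follows. This is a reference sketch of the semantics only, under the assumption that `std::optional` stands in for a nullable row. `repeat_strings_model` and `output_sizes_model` are hypothetical names, and nothing here reflects how cudf implements the operation on the GPU.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// std::nullopt models a null row in either column (an assumption of this sketch).
using nullable_string = std::optional<std::string>;
using nullable_count  = std::optional<std::int32_t>;

// Hypothetical model of the column-wise overload: row i of the input is
// repeated repeat_times[i] times.
std::vector<nullable_string> repeat_strings_model(
  std::vector<nullable_string> const& strs,
  std::vector<nullable_count> const& repeat_times)
{
  if (strs.size() != repeat_times.size()) {
    throw std::invalid_argument("input columns must have the same size");
  }
  std::vector<nullable_string> out;
  out.reserve(strs.size());
  for (std::size_t i = 0; i < strs.size(); ++i) {
    // A null in either column gives a null output row.
    if (!strs[i] || !repeat_times[i]) {
      out.push_back(std::nullopt);
      continue;
    }
    // A non-positive count leaves the loop body unexecuted: empty string.
    std::string repeated;
    for (std::int32_t n = 0; n < *repeat_times[i]; ++n) { repeated += *strs[i]; }
    out.push_back(std::move(repeated));
  }
  return out;
}

// Hypothetical model of the size pre-computation: per-row output sizes
// (0 for null rows) plus their 64-bit sum.
std::pair<std::vector<std::int32_t>, std::int64_t> output_sizes_model(
  std::vector<nullable_string> const& strs,
  std::vector<nullable_count> const& repeat_times)
{
  if (strs.size() != repeat_times.size()) {
    throw std::invalid_argument("input columns must have the same size");
  }
  std::vector<std::int32_t> sizes;
  sizes.reserve(strs.size());
  std::int64_t total = 0;
  for (std::size_t i = 0; i < strs.size(); ++i) {
    std::int64_t bytes = 0;
    if (strs[i] && repeat_times[i] && *repeat_times[i] > 0) {
      bytes = static_cast<std::int64_t>(strs[i]->size()) * *repeat_times[i];
    }
    // The narrowing cast is only meaningful when no single row overflows int32_t;
    // the 64-bit total is what a caller would compare against the limit.
    sizes.push_back(static_cast<std::int32_t>(bytes));
    total += bytes;
  }
  return {std::move(sizes), total};
}
```

Applied to the doxygen example (`['aa', null, '', 'bbc-']` with counts `[1, 2, 3, 4]`), the first function reproduces `['aa', null, '', 'bbc-bbc-bbc-bbc-']` and the second gives sizes `[2, 0, 0, 16]` with a total of 18.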