diff --git a/cpp/doxygen/developer_guide/DEVELOPER_GUIDE.md b/cpp/doxygen/developer_guide/DEVELOPER_GUIDE.md index 23b129fdf4b..05f8e4585cc 100644 --- a/cpp/doxygen/developer_guide/DEVELOPER_GUIDE.md +++ b/cpp/doxygen/developer_guide/DEVELOPER_GUIDE.md @@ -943,13 +943,14 @@ Use the `CUDF_EXPECTS` macro to enforce runtime conditions necessary for correct Example usage: ```c++ -CUDF_EXPECTS(lhs.type() == rhs.type(), "Column type mismatch"); +CUDF_EXPECTS(cudf::have_same_types(lhs, rhs), "Type mismatch", cudf::data_type_error); ``` The first argument is the conditional expression expected to resolve to `true` under normal -conditions. If the conditional evaluates to `false`, then an error has occurred and an instance of -`cudf::logic_error` is thrown. The second argument to `CUDF_EXPECTS` is a short description of the -error that has occurred and is used for the exception's `what()` message. +conditions. The second argument to `CUDF_EXPECTS` is a short description of the error that has +occurred and is used for the exception's `what()` message. If the conditional evaluates to +`false`, then an error has occurred and an instance of the exception class in the third argument +(or the default, `cudf::logic_error`) is thrown. There are times where a particular code path, if reached, should indicate an error no matter what. For example, often the `default` case of a `switch` statement represents an invalid alternative. @@ -1048,6 +1049,12 @@ types such as numeric types and timestamps/durations, adding support for nested Enabling an algorithm differently for different types uses either template specialization or SFINAE, as discussed in [Specializing Type-Dispatched Code Paths](#specializing-type-dispatched-code-paths). +## Comparing Data Types + +When comparing the data types of two columns or scalars, do not directly compare +`a.type() == b.type()`. Nested types such as lists of structs of integers will not be handled +properly if only the top level type is compared. Instead, use the `cudf::have_same_types` function. + # Type Dispatcher libcudf stores data (for columns and scalars) "type erased" in `void*` device memory. This diff --git a/cpp/include/cudf/detail/scatter.cuh b/cpp/include/cudf/detail/scatter.cuh index 7eb661f7833..80bc87731ca 100644 --- a/cpp/include/cudf/detail/scatter.cuh +++ b/cpp/include/cudf/detail/scatter.cuh @@ -29,7 +29,9 @@ #include #include #include +#include #include +#include #include #include @@ -213,8 +215,9 @@ struct column_scatterer_impl { // check the keys match dictionary_column_view const source(source_in); dictionary_column_view const target(target_in); - CUDF_EXPECTS(source.keys().type() == target.keys().type(), - "scatter dictionary keys must be the same type"); + CUDF_EXPECTS(cudf::have_same_types(source.keys(), target.keys()), + "scatter dictionary keys must be the same type", + cudf::data_type_error); // first combine keys so both dictionaries have the same set auto target_matched = dictionary::detail::add_keys(target, source.keys(), stream, mr); diff --git a/cpp/include/cudf/lists/detail/scatter.cuh b/cpp/include/cudf/lists/detail/scatter.cuh index d0d5b1ad823..c550ad5b94f 100644 --- a/cpp/include/cudf/lists/detail/scatter.cuh +++ b/cpp/include/cudf/lists/detail/scatter.cuh @@ -101,7 +101,7 @@ std::unique_ptr scatter_impl(rmm::device_uvector cons rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(column_types_equal(source, target), "Mismatched column types."); + CUDF_EXPECTS(have_same_types(source, target), "Mismatched column types."); auto const child_column_type = lists_column_view(target).child().type(); diff --git a/cpp/include/cudf/table/table_view.hpp b/cpp/include/cudf/table/table_view.hpp index 4f3b23747e6..ad12b1eef4e 100644 --- a/cpp/include/cudf/table/table_view.hpp +++ b/cpp/include/cudf/table/table_view.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -339,15 +339,6 @@ bool has_nested_nullable_columns(table_view const& input); */ std::vector get_nullable_columns(table_view const& table); -/** - * @brief Checks if two `table_view`s have columns of same types - * - * @param lhs left-side table_view operand - * @param rhs right-side table_view operand - * @return boolean comparison result - */ -bool have_same_types(table_view const& lhs, table_view const& rhs); - /** * @brief Copy column_views from a table_view into another table_view according to * a column indices map. diff --git a/cpp/include/cudf/utilities/type_checks.hpp b/cpp/include/cudf/utilities/type_checks.hpp index b925fc8ae92..fd3b0581c11 100644 --- a/cpp/include/cudf/utilities/type_checks.hpp +++ b/cpp/include/cudf/utilities/type_checks.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,11 +16,16 @@ #pragma once #include +#include + +#include namespace cudf { /** - * @brief Compares the type of two `column_view`s + * @brief Compare the types of two `column_view`s + * + * @deprecated Since 24.06. Use cudf::have_same_types instead. * * This function returns true if the type of `lhs` equals that of `rhs`. * - For fixed point types, the scale is compared. @@ -34,10 +39,11 @@ namespace cudf { * @param rhs The second `column_view` to compare * @return true if column types match */ -bool column_types_equal(column_view const& lhs, column_view const& rhs); +[[deprecated]] bool column_types_equal(column_view const& lhs, column_view const& rhs); /** * @brief Compare the type IDs of two `column_view`s + * * This function returns true if the type of `lhs` equals that of `rhs`. * - For fixed point types, the scale is ignored. * @@ -47,4 +53,98 @@ bool column_types_equal(column_view const& lhs, column_view const& rhs); */ bool column_types_equivalent(column_view const& lhs, column_view const& rhs); +/** + * @brief Compares the type of two `column_view`s + * + * This function returns true if the type of `lhs` equals that of `rhs`. + * - For fixed point types, the scale is compared. + * - For dictionary types, the type of the keys are compared if both are + * non-empty columns. + * - For lists types, the type of child columns are compared recursively. + * - For struct types, the type of each field are compared in order. + * - For all other types, the `id` of `data_type` is compared. + * + * @param lhs The first `column_view` to compare + * @param rhs The second `column_view` to compare + * @return true if types match + */ +bool have_same_types(column_view const& lhs, column_view const& rhs); + +/** + * @brief Compare the types of a `column_view` and a `scalar` + * + * This function returns true if the type of `lhs` equals that of `rhs`. + * - For fixed point types, the scale is compared. + * - For dictionary column types, the type of the keys is compared to the + * scalar type. + * - For lists types, the types of child columns are compared recursively. + * - For struct types, the types of each field are compared in order. + * - For all other types, the `id` of `data_type` is compared. + * + * @param lhs The `column_view` to compare + * @param rhs The `scalar` to compare + * @return true if types match + */ +bool have_same_types(column_view const& lhs, scalar const& rhs); + +/** + * @brief Compare the types of a `scalar` and a `column_view` + * + * This function returns true if the type of `lhs` equals that of `rhs`. + * - For fixed point types, the scale is compared. + * - For dictionary column types, the type of the keys is compared to the + * scalar type. + * - For lists types, the types of child columns are compared recursively. + * - For struct types, the types of each field are compared in order. + * - For all other types, the `id` of `data_type` is compared. + * + * @param lhs The `scalar` to compare + * @param rhs The `column_view` to compare + * @return true if types match + */ +bool have_same_types(scalar const& lhs, column_view const& rhs); + +/** + * @brief Compare the types of two `scalar`s + * + * This function returns true if the type of `lhs` equals that of `rhs`. + * - For fixed point types, the scale is compared. + * - For lists types, the types of child columns are compared recursively. + * - For struct types, the types of each field are compared in order. + * - For all other types, the `id` of `data_type` is compared. + * + * @param lhs The first `scalar` to compare + * @param rhs The second `scalar` to compare + * @return true if types match + */ +bool have_same_types(scalar const& lhs, scalar const& rhs); + +/** + * @brief Checks if two `table_view`s have columns of same types + * + * @param lhs left-side table_view operand + * @param rhs right-side table_view operand + * @return boolean comparison result + */ +bool have_same_types(table_view const& lhs, table_view const& rhs); + +/** + * @brief Compare the types of a range of `column_view` or `scalar` objects + * + * This function returns true if all objects in the range have the same type, in the sense of + * cudf::have_same_types. + * + * @tparam ForwardIt Forward iterator + * @param first The first iterator + * @param last The last iterator + * @return true if all types match + */ +template +inline bool all_have_same_types(ForwardIt first, ForwardIt last) +{ + return first == last || std::all_of(std::next(first), last, [want = *first](auto const& c) { + return cudf::have_same_types(want, c); + }); +} + } // namespace cudf diff --git a/cpp/src/copying/concatenate.cu b/cpp/src/copying/concatenate.cu index 7c57be8e7c0..b1136a9eeb3 100644 --- a/cpp/src/copying/concatenate.cu +++ b/cpp/src/copying/concatenate.cu @@ -30,6 +30,8 @@ #include #include #include +#include +#include #include #include @@ -461,12 +463,9 @@ void traverse_children::operator()(host_span */ void bounds_and_type_check(host_span cols, rmm::cuda_stream_view stream) { - CUDF_EXPECTS(std::all_of(cols.begin(), - cols.end(), - [expected_type = cols.front().type()](auto const& c) { - return c.type() == expected_type; - }), - "Type mismatch in columns to concatenate."); + CUDF_EXPECTS(cudf::all_have_same_types(cols.begin(), cols.end()), + "Type mismatch in columns to concatenate.", + cudf::data_type_error); // total size of all concatenated rows size_t const total_row_count = diff --git a/cpp/src/copying/copy.cu b/cpp/src/copying/copy.cu index 92fb2e61741..e86a1f8d6f1 100644 --- a/cpp/src/copying/copy.cu +++ b/cpp/src/copying/copy.cu @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -362,9 +363,10 @@ std::unique_ptr copy_if_else(column_view const& lhs, CUDF_EXPECTS(boolean_mask.size() == lhs.size(), "Boolean mask column must be the same size as lhs and rhs columns", std::invalid_argument); - CUDF_EXPECTS(lhs.size() == rhs.size(), "Both columns must be of the size", std::invalid_argument); CUDF_EXPECTS( - lhs.type() == rhs.type(), "Both inputs must be of the same type", cudf::data_type_error); + lhs.size() == rhs.size(), "Both columns must be of the same size", std::invalid_argument); + CUDF_EXPECTS( + cudf::have_same_types(lhs, rhs), "Both inputs must be of the same type", cudf::data_type_error); return copy_if_else(lhs, rhs, lhs.has_nulls(), rhs.has_nulls(), boolean_mask, stream, mr); } @@ -378,11 +380,8 @@ std::unique_ptr copy_if_else(scalar const& lhs, CUDF_EXPECTS(boolean_mask.size() == rhs.size(), "Boolean mask column must be the same size as rhs column", std::invalid_argument); - - auto rhs_type = - cudf::is_dictionary(rhs.type()) ? cudf::dictionary_column_view(rhs).keys_type() : rhs.type(); CUDF_EXPECTS( - lhs.type() == rhs_type, "Both inputs must be of the same type", cudf::data_type_error); + cudf::have_same_types(rhs, lhs), "Both inputs must be of the same type", cudf::data_type_error); return copy_if_else(lhs, rhs, !lhs.is_valid(stream), rhs.has_nulls(), boolean_mask, stream, mr); } @@ -396,11 +395,8 @@ std::unique_ptr copy_if_else(column_view const& lhs, CUDF_EXPECTS(boolean_mask.size() == lhs.size(), "Boolean mask column must be the same size as lhs column", std::invalid_argument); - - auto lhs_type = - cudf::is_dictionary(lhs.type()) ? cudf::dictionary_column_view(lhs).keys_type() : lhs.type(); CUDF_EXPECTS( - lhs_type == rhs.type(), "Both inputs must be of the same type", cudf::data_type_error); + cudf::have_same_types(lhs, rhs), "Both inputs must be of the same type", cudf::data_type_error); return copy_if_else(lhs, rhs, lhs.has_nulls(), !rhs.is_valid(stream), boolean_mask, stream, mr); } @@ -412,7 +408,7 @@ std::unique_ptr copy_if_else(scalar const& lhs, rmm::device_async_resource_ref mr) { CUDF_EXPECTS( - lhs.type() == rhs.type(), "Both inputs must be of the same type", cudf::data_type_error); + cudf::have_same_types(lhs, rhs), "Both inputs must be of the same type", cudf::data_type_error); return copy_if_else( lhs, rhs, !lhs.is_valid(stream), !rhs.is_valid(stream), boolean_mask, stream, mr); } diff --git a/cpp/src/copying/copy_range.cu b/cpp/src/copying/copy_range.cu index d2ea7036952..dd18f99a3c8 100644 --- a/cpp/src/copying/copy_range.cu +++ b/cpp/src/copying/copy_range.cu @@ -32,6 +32,7 @@ #include #include #include +#include #include #include @@ -147,8 +148,9 @@ std::unique_ptr out_of_place_copy_range_dispatch::operator() copy_range(column_view const& source, (target_begin <= target.size() - (source_end - source_begin)), "Range is out of bounds.", std::out_of_range); - CUDF_EXPECTS(target.type() == source.type(), "Data type mismatch.", cudf::data_type_error); + CUDF_EXPECTS(cudf::have_same_types(target, source), "Data type mismatch.", cudf::data_type_error); return cudf::type_dispatcher( target.type(), diff --git a/cpp/src/copying/scatter.cu b/cpp/src/copying/scatter.cu index cfcbe4724df..993ee074f14 100644 --- a/cpp/src/copying/scatter.cu +++ b/cpp/src/copying/scatter.cu @@ -32,6 +32,8 @@ #include #include #include +#include +#include #include #include @@ -112,7 +114,7 @@ struct column_scalar_scatterer_impl { rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) const { - CUDF_EXPECTS(source.get().type() == target.type(), + CUDF_EXPECTS(cudf::have_same_types(target, source.get()), "scalar and column types must match", cudf::data_type_error); @@ -145,7 +147,7 @@ struct column_scalar_scatterer_impl { rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) const { - CUDF_EXPECTS(source.get().type() == target.type(), + CUDF_EXPECTS(cudf::have_same_types(target, source.get()), "scalar and column types must match", cudf::data_type_error); @@ -315,12 +317,7 @@ std::unique_ptr scatter(table_view const& source, CUDF_EXPECTS(scatter_map.size() <= source.num_rows(), "Size of scatter map must be equal to or less than source rows", std::invalid_argument); - CUDF_EXPECTS(std::equal(source.begin(), - source.end(), - target.begin(), - [](auto const& col1, auto const& col2) { - return col1.type().id() == col2.type().id(); - }), + CUDF_EXPECTS(cudf::have_same_types(source, target), "Column types do not match between source and target", cudf::data_type_error); CUDF_EXPECTS(not scatter_map.has_nulls(), "Scatter map contains nulls", std::invalid_argument); @@ -452,14 +449,9 @@ std::unique_ptr
boolean_mask_scatter(table_view const& input, "Mask must be of Boolean type", cudf::data_type_error); // Count valid pair of input and columns as per type at each column index i - CUDF_EXPECTS( - std::all_of(thrust::counting_iterator(0), - thrust::counting_iterator(target.num_columns()), - [&input, &target](auto index) { - return ((input.column(index).type().id()) == (target.column(index).type().id())); - }), - "Type mismatch in input column and target column", - cudf::data_type_error); + CUDF_EXPECTS(cudf::have_same_types(input, target), + "Type mismatch in input column and target column", + cudf::data_type_error); if (target.num_rows() != 0) { std::vector> out_columns(target.num_columns()); @@ -496,14 +488,13 @@ std::unique_ptr
boolean_mask_scatter( cudf::data_type_error); // Count valid pair of input and columns as per type at each column/scalar index i - CUDF_EXPECTS( - std::all_of(thrust::counting_iterator(0), - thrust::counting_iterator(target.num_columns()), - [&input, &target](auto index) { - return (input[index].get().type().id() == target.column(index).type().id()); - }), - "Type mismatch in input scalar and target column", - cudf::data_type_error); + CUDF_EXPECTS(std::all_of(thrust::counting_iterator(0), + thrust::counting_iterator(target.num_columns()), + [&input, &target](auto index) { + return cudf::have_same_types(target.column(index), input[index].get()); + }), + "Type mismatch in input scalar and target column", + cudf::data_type_error); if (target.num_rows() != 0) { std::vector> out_columns(target.num_columns()); diff --git a/cpp/src/copying/shift.cu b/cpp/src/copying/shift.cu index bdc741887f7..91254f21170 100644 --- a/cpp/src/copying/shift.cu +++ b/cpp/src/copying/shift.cu @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -158,7 +159,7 @@ std::unique_ptr shift(column_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(input.type() == fill_value.type(), + CUDF_EXPECTS(cudf::have_same_types(input, fill_value), "shift requires each fill value type to match the corresponding column type.", cudf::data_type_error); diff --git a/cpp/src/dictionary/add_keys.cu b/cpp/src/dictionary/add_keys.cu index 5fd21ee0094..0ed9006f88b 100644 --- a/cpp/src/dictionary/add_keys.cu +++ b/cpp/src/dictionary/add_keys.cu @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include @@ -54,7 +56,8 @@ std::unique_ptr add_keys(dictionary_column_view const& dictionary_column { CUDF_EXPECTS(!new_keys.has_nulls(), "Keys must not have nulls"); auto old_keys = dictionary_column.keys(); // [a,b,c,d,f] - CUDF_EXPECTS(new_keys.type() == old_keys.type(), "Keys must be the same type"); + CUDF_EXPECTS( + cudf::have_same_types(new_keys, old_keys), "Keys must be the same type", cudf::data_type_error); // first, concatenate the keys together // [a,b,c,d,f] + [d,b,e] = [a,b,c,d,f,d,b,e] auto combined_keys = cudf::detail::concatenate( diff --git a/cpp/src/dictionary/detail/concatenate.cu b/cpp/src/dictionary/detail/concatenate.cu index 62a6c816493..fdc3d9d0ecf 100644 --- a/cpp/src/dictionary/detail/concatenate.cu +++ b/cpp/src/dictionary/detail/concatenate.cu @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include @@ -82,13 +84,13 @@ struct compute_children_offsets_fn { } /** - * @brief Return the first keys().type of the dictionary columns. + * @brief Return the first keys() of the dictionary columns. */ - data_type get_keys_type() + column_view get_keys() { auto const view(*std::find_if( columns_ptrs.begin(), columns_ptrs.end(), [](auto pcv) { return pcv->size() > 0; })); - return dictionary_column_view(*view).keys().type(); + return dictionary_column_view(*view).keys(); } /** @@ -214,14 +216,16 @@ std::unique_ptr concatenate(host_span columns, // concatenate the keys (and check the keys match) compute_children_offsets_fn child_offsets_fn{columns}; - auto keys_type = child_offsets_fn.get_keys_type(); + auto expected_keys = child_offsets_fn.get_keys(); std::vector keys_views(columns.size()); - std::transform(columns.begin(), columns.end(), keys_views.begin(), [keys_type](auto cv) { + std::transform(columns.begin(), columns.end(), keys_views.begin(), [expected_keys](auto cv) { auto dict_view = dictionary_column_view(cv); // empty column may not have keys so we create an empty column_view place-holder - if (dict_view.is_empty()) return column_view{keys_type, 0, nullptr, nullptr, 0}; + if (dict_view.is_empty()) return column_view{expected_keys.type(), 0, nullptr, nullptr, 0}; auto keys = dict_view.keys(); - CUDF_EXPECTS(keys.type() == keys_type, "key types of all dictionary columns must match"); + CUDF_EXPECTS(cudf::have_same_types(keys, expected_keys), + "key types of all dictionary columns must match", + cudf::data_type_error); return keys; }); auto all_keys = @@ -275,7 +279,7 @@ std::unique_ptr concatenate(host_span columns, // now recompute the indices values for the new keys_column; // the keys offsets (pair.first) are for mapping to the input keys - auto indices_column = type_dispatcher(keys_type, + auto indices_column = type_dispatcher(expected_keys.type(), dispatch_compute_indices{}, all_keys->view(), // old keys all_indices->view(), // old indices diff --git a/cpp/src/dictionary/remove_keys.cu b/cpp/src/dictionary/remove_keys.cu index 718ca419289..35387efa56b 100644 --- a/cpp/src/dictionary/remove_keys.cu +++ b/cpp/src/dictionary/remove_keys.cu @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include @@ -155,7 +157,9 @@ std::unique_ptr remove_keys(dictionary_column_view const& dictionary_col { CUDF_EXPECTS(!keys_to_remove.has_nulls(), "keys_to_remove must not have nulls"); auto const keys_view = dictionary_column.keys(); - CUDF_EXPECTS(keys_view.type() == keys_to_remove.type(), "keys types must match"); + CUDF_EXPECTS(cudf::have_same_types(keys_view, keys_to_remove), + "keys types must match", + cudf::data_type_error); // locate keys to remove by searching the keys column auto const matches = cudf::detail::contains(keys_to_remove, keys_view, stream, mr); diff --git a/cpp/src/dictionary/replace.cu b/cpp/src/dictionary/replace.cu index bb6b08c243d..bc17dfd4bab 100644 --- a/cpp/src/dictionary/replace.cu +++ b/cpp/src/dictionary/replace.cu @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include @@ -84,7 +86,9 @@ std::unique_ptr replace_nulls(dictionary_column_view const& input, { if (input.is_empty()) { return cudf::empty_like(input.parent()); } if (!input.has_nulls()) { return std::make_unique(input.parent(), stream, mr); } - CUDF_EXPECTS(input.keys().type() == replacement.keys().type(), "keys must match"); + CUDF_EXPECTS(cudf::have_same_types(input.keys(), replacement.keys()), + "keys must match", + cudf::data_type_error); CUDF_EXPECTS(replacement.size() == input.size(), "column sizes must match"); // first combine the keys so both input dictionaries have the same set @@ -119,7 +123,9 @@ std::unique_ptr replace_nulls(dictionary_column_view const& input, if (!input.has_nulls() || !replacement.is_valid(stream)) { return std::make_unique(input.parent(), stream, mr); } - CUDF_EXPECTS(input.keys().type() == replacement.type(), "keys must match scalar type"); + CUDF_EXPECTS(cudf::have_same_types(input.parent(), replacement), + "keys must match scalar type", + cudf::data_type_error); // first add the replacement to the keys so only the indices need to be processed auto input_matched = dictionary::detail::add_keys( diff --git a/cpp/src/dictionary/search.cu b/cpp/src/dictionary/search.cu index 680eadddba8..231619836f9 100644 --- a/cpp/src/dictionary/search.cu +++ b/cpp/src/dictionary/search.cu @@ -19,7 +19,9 @@ #include #include #include +#include #include +#include #include #include @@ -72,10 +74,12 @@ struct find_index_fn { rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) const { - if (!key.is_valid(stream)) + if (!key.is_valid(stream)) { return type_dispatcher(input.indices().type(), dispatch_scalar_index{}, 0, false, stream, mr); - CUDF_EXPECTS(input.keys().type() == key.type(), - "search key type must match dictionary keys type"); + } + CUDF_EXPECTS(cudf::have_same_types(input.parent(), key), + "search key type must match dictionary keys type", + cudf::data_type_error); using ScalarType = cudf::scalar_type_t; auto find_key = static_cast(key).value(stream); @@ -114,10 +118,12 @@ struct find_insert_index_fn { rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) const { - if (!key.is_valid(stream)) + if (!key.is_valid(stream)) { return type_dispatcher(input.indices().type(), dispatch_scalar_index{}, 0, false, stream, mr); - CUDF_EXPECTS(input.keys().type() == key.type(), - "search key type must match dictionary keys type"); + } + CUDF_EXPECTS(cudf::have_same_types(input.parent(), key), + "search key type must match dictionary keys type", + cudf::data_type_error); using ScalarType = cudf::scalar_type_t; auto find_key = static_cast(key).value(stream); diff --git a/cpp/src/dictionary/set_keys.cu b/cpp/src/dictionary/set_keys.cu index b56eec9401a..08a33d40abe 100644 --- a/cpp/src/dictionary/set_keys.cu +++ b/cpp/src/dictionary/set_keys.cu @@ -29,6 +29,8 @@ #include #include #include +#include +#include #include #include @@ -116,7 +118,6 @@ struct dispatch_compute_indices { } // namespace -// std::unique_ptr set_keys(dictionary_column_view const& dictionary_column, column_view const& new_keys, rmm::cuda_stream_view stream, @@ -124,7 +125,8 @@ std::unique_ptr set_keys(dictionary_column_view const& dictionary_column { CUDF_EXPECTS(!new_keys.has_nulls(), "keys parameter must not have nulls"); auto keys = dictionary_column.keys(); - CUDF_EXPECTS(keys.type() == new_keys.type(), "keys types must match"); + CUDF_EXPECTS( + cudf::have_same_types(keys, new_keys), "keys types must match", cudf::data_type_error); // copy the keys -- use cudf::distinct to make sure there are no duplicates, // then sort the results. diff --git a/cpp/src/filling/fill.cu b/cpp/src/filling/fill.cu index c4d786bd73b..1fc9ed31c09 100644 --- a/cpp/src/filling/fill.cu +++ b/cpp/src/filling/fill.cu @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -110,7 +111,7 @@ struct out_of_place_fill_range_dispatch { rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(input.type() == value.type(), "Data type mismatch."); + CUDF_EXPECTS(cudf::have_same_types(input, value), "Data type mismatch.", cudf::data_type_error); auto p_ret = std::make_unique(input, stream, mr); if (end != begin) { // otherwise no fill @@ -137,7 +138,7 @@ std::unique_ptr out_of_place_fill_range_dispatch::operator(); auto p_scalar = static_cast(&value); return cudf::strings::detail::fill( @@ -153,7 +154,8 @@ std::unique_ptr out_of_place_fill_range_dispatch::operator()(input, stream, mr); cudf::dictionary_column_view const target(input); - CUDF_EXPECTS(target.keys().type() == value.type(), "Data type mismatch."); + CUDF_EXPECTS( + cudf::have_same_types(target.parent(), value), "Data type mismatch.", cudf::data_type_error); // if the scalar is invalid, then just copy the column and fill the null mask if (!value.is_valid(stream)) { @@ -219,7 +221,8 @@ void fill_in_place(mutable_column_view& destination, "Range is out of bounds."); CUDF_EXPECTS(destination.nullable() || value.is_valid(stream), "destination should be nullable or value should be non-null."); - CUDF_EXPECTS(destination.type() == value.type(), "Data type mismatch."); + CUDF_EXPECTS( + cudf::have_same_types(destination, value), "Data type mismatch.", cudf::data_type_error); if (end != begin) { // otherwise no-op cudf::type_dispatcher( diff --git a/cpp/src/filling/sequence.cu b/cpp/src/filling/sequence.cu index f7067c3a91b..ee1745b8498 100644 --- a/cpp/src/filling/sequence.cu +++ b/cpp/src/filling/sequence.cu @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -128,7 +129,9 @@ std::unique_ptr sequence(size_type size, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(init.type() == step.type(), "init and step must be of the same type."); + CUDF_EXPECTS(cudf::have_same_types(init, step), + "init and step must be of the same type.", + cudf::data_type_error); CUDF_EXPECTS(size >= 0, "size must be >= 0"); CUDF_EXPECTS(is_numeric(init.type()), "Input scalar types must be numeric"); diff --git a/cpp/src/groupby/groupby.cu b/cpp/src/groupby/groupby.cu index 73cb4efd283..e43dfcb4d98 100644 --- a/cpp/src/groupby/groupby.cu +++ b/cpp/src/groupby/groupby.cu @@ -36,6 +36,7 @@ #include #include #include +#include #include #include @@ -312,12 +313,15 @@ std::pair, std::unique_ptr
> groupby::shift( CUDF_FUNC_RANGE(); CUDF_EXPECTS(values.num_columns() == static_cast(fill_values.size()), "Mismatch number of fill_values and columns."); - CUDF_EXPECTS( - std::all_of(thrust::make_counting_iterator(0), - thrust::make_counting_iterator(values.num_columns()), - [&](auto i) { return values.column(i).type() == fill_values[i].get().type(); }), - "values and fill_value should have the same type."); - + CUDF_EXPECTS(std::equal(values.begin(), + values.end(), + fill_values.cbegin(), + fill_values.cend(), + [](auto const& col, auto const& scalar) { + return cudf::have_same_types(col, scalar.get()); + }), + "values and fill_value should have the same type.", + cudf::data_type_error); auto stream = cudf::get_default_stream(); std::vector> results; auto const& group_offsets = helper().group_offsets(stream); diff --git a/cpp/src/interop/dlpack.cpp b/cpp/src/interop/dlpack.cpp index 3109a36cbcf..78ddd7f5ad5 100644 --- a/cpp/src/interop/dlpack.cpp +++ b/cpp/src/interop/dlpack.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -231,9 +232,9 @@ DLManagedTensor* to_dlpack(table_view const& input, DLDataType const dltype = data_type_to_DLDataType(type); // Ensure all columns are the same type - CUDF_EXPECTS( - std::all_of(input.begin(), input.end(), [type](auto const& col) { return col.type() == type; }), - "All columns required to have same data type"); + CUDF_EXPECTS(cudf::all_have_same_types(input.begin(), input.end()), + "All columns required to have same data type", + cudf::data_type_error); // Ensure none of the columns have nulls CUDF_EXPECTS( diff --git a/cpp/src/join/hash_join.cu b/cpp/src/join/hash_join.cu index fbe16378e8c..b0184ff6a86 100644 --- a/cpp/src/join/hash_join.cu +++ b/cpp/src/join/hash_join.cu @@ -21,6 +21,8 @@ #include #include #include +#include +#include #include #include @@ -569,12 +571,9 @@ hash_join::compute_hash_join(cudf::table_view const& probe, std::make_unique>(0, stream, mr)); } - CUDF_EXPECTS(std::equal(std::cbegin(_build), - std::cend(_build), - std::cbegin(probe), - std::cend(probe), - [](auto const& b, auto const& p) { return b.type() == p.type(); }), - "Mismatch in joining column data types"); + CUDF_EXPECTS(cudf::have_same_types(_build, probe), + "Mismatch in joining column data types", + cudf::data_type_error); return probe_join_indices(probe, join, output_size, stream, mr); } diff --git a/cpp/src/labeling/label_bins.cu b/cpp/src/labeling/label_bins.cu index 1bfa7f39190..7ee1d540831 100644 --- a/cpp/src/labeling/label_bins.cu +++ b/cpp/src/labeling/label_bins.cu @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -208,8 +209,10 @@ std::unique_ptr label_bins(column_view const& input, rmm::device_async_resource_ref mr) { CUDF_FUNC_RANGE() - CUDF_EXPECTS((input.type() == left_edges.type()) && (input.type() == right_edges.type()), - "The input and edge columns must have the same types."); + CUDF_EXPECTS( + cudf::have_same_types(input, left_edges) && cudf::have_same_types(input, right_edges), + "The input and edge columns must have the same types.", + cudf::data_type_error); CUDF_EXPECTS(left_edges.size() == right_edges.size(), "The left and right edge columns must be of the same length."); CUDF_EXPECTS(!left_edges.has_nulls() && !right_edges.has_nulls(), diff --git a/cpp/src/lists/combine/concatenate_rows.cu b/cpp/src/lists/combine/concatenate_rows.cu index 38d299763a1..bc1b48b11cd 100644 --- a/cpp/src/lists/combine/concatenate_rows.cu +++ b/cpp/src/lists/combine/concatenate_rows.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -204,12 +205,11 @@ std::unique_ptr concatenate_rows(table_view const& input, std::all_of(input.begin(), input.end(), [](column_view const& col) { return col.type().id() == cudf::type_id::LIST; }), - "All columns of the input table must be of lists column type."); - CUDF_EXPECTS( - std::all_of(std::next(input.begin()), - input.end(), - [a = *input.begin()](column_view const& b) { return column_types_equal(a, b); }), - "The types of entries in the input columns must be the same."); + "All columns of the input table must be of list column type.", + cudf::data_type_error); + CUDF_EXPECTS(cudf::all_have_same_types(input.begin(), input.end()), + "The types of entries in the input columns must be the same.", + cudf::data_type_error); auto const num_rows = input.num_rows(); auto const num_cols = input.num_columns(); diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu index 4737b077deb..f03d394d6d7 100644 --- a/cpp/src/lists/contains.cu +++ b/cpp/src/lists/contains.cu @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -194,7 +195,7 @@ std::unique_ptr dispatch_index_of(lists_column_view const& lists, // comparisons. auto const child = lists.child(); - CUDF_EXPECTS(child.type() == search_keys.type(), + CUDF_EXPECTS(cudf::have_same_types(child, search_keys), "Type/Scale of search key does not match list column element type.", cudf::data_type_error); CUDF_EXPECTS(search_keys.type().id() != type_id::EMPTY, "Type cannot be empty."); diff --git a/cpp/src/lists/sequences.cu b/cpp/src/lists/sequences.cu index cb14ae7619b..7d57d8ddb60 100644 --- a/cpp/src/lists/sequences.cu +++ b/cpp/src/lists/sequences.cu @@ -23,6 +23,8 @@ #include #include #include +#include +#include #include #include @@ -139,15 +141,18 @@ std::unique_ptr sequences(column_view const& starts, "starts and sizes input columns must not have nulls."); CUDF_EXPECTS(starts.size() == sizes.size(), "starts and sizes input columns must have the same number of rows."); - CUDF_EXPECTS(cudf::is_index_type(sizes.type()), "Input sizes column must be of integer types."); + CUDF_EXPECTS(cudf::is_index_type(sizes.type()), + "Input sizes column must be of integer types.", + cudf::data_type_error); if (steps) { auto const& steps_cv = steps.value(); CUDF_EXPECTS(!steps_cv.has_nulls(), "steps input column must not have nulls."); CUDF_EXPECTS(starts.size() == steps_cv.size(), "starts and steps input columns must have the same number of rows."); - CUDF_EXPECTS(starts.type() == steps_cv.type(), - "starts and steps input columns must have the same type."); + CUDF_EXPECTS(cudf::have_same_types(starts, steps_cv), + "starts and steps input columns must have the same type.", + cudf::data_type_error); } auto const n_lists = starts.size(); diff --git a/cpp/src/lists/set_operations.cu b/cpp/src/lists/set_operations.cu index f3352a3a52d..1d18b8c677c 100644 --- a/cpp/src/lists/set_operations.cu +++ b/cpp/src/lists/set_operations.cu @@ -52,7 +52,7 @@ namespace { void check_compatibility(lists_column_view const& lhs, lists_column_view const& rhs) { CUDF_EXPECTS(lhs.size() == rhs.size(), "The input lists column must have the same size."); - CUDF_EXPECTS(column_types_equal(lhs.child(), rhs.child()), + CUDF_EXPECTS(have_same_types(lhs.child(), rhs.child()), "The input lists columns must have children having the same type structure"); } diff --git a/cpp/src/merge/merge.cu b/cpp/src/merge/merge.cu index 5a3be259ed9..630cf328579 100644 --- a/cpp/src/merge/merge.cu +++ b/cpp/src/merge/merge.cu @@ -34,6 +34,7 @@ #include #include #include +#include #include #include diff --git a/cpp/src/reductions/reductions.cpp b/cpp/src/reductions/reductions.cpp index d764ea7559f..cde0274339a 100644 --- a/cpp/src/reductions/reductions.cpp +++ b/cpp/src/reductions/reductions.cpp @@ -28,6 +28,8 @@ #include #include #include +#include +#include #include #include @@ -154,8 +156,9 @@ std::unique_ptr reduce(column_view const& col, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(!init.has_value() || col.type() == init.value().get().type(), - "column and initial value must be the same type"); + CUDF_EXPECTS(!init.has_value() || cudf::have_same_types(col, init.value().get()), + "column and initial value must be the same type", + cudf::data_type_error); if (init.has_value() && !(agg.kind == aggregation::SUM || agg.kind == aggregation::PRODUCT || agg.kind == aggregation::MIN || agg.kind == aggregation::MAX || agg.kind == aggregation::ANY || agg.kind == aggregation::ALL)) { diff --git a/cpp/src/reductions/segmented/reductions.cpp b/cpp/src/reductions/segmented/reductions.cpp index dee16b3e503..1ae344dcace 100644 --- a/cpp/src/reductions/segmented/reductions.cpp +++ b/cpp/src/reductions/segmented/reductions.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -112,8 +113,9 @@ std::unique_ptr segmented_reduce(column_view const& segmented_values, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(!init.has_value() || segmented_values.type() == init.value().get().type(), - "column and initial value must be the same type"); + CUDF_EXPECTS(!init.has_value() || cudf::have_same_types(segmented_values, init.value().get()), + "column and initial value must be the same type", + cudf::data_type_error); if (init.has_value() && !(agg.kind == aggregation::SUM || agg.kind == aggregation::PRODUCT || agg.kind == aggregation::MIN || agg.kind == aggregation::MAX || agg.kind == aggregation::ANY || agg.kind == aggregation::ALL)) { diff --git a/cpp/src/replace/clamp.cu b/cpp/src/replace/clamp.cu index 31ffc76a4a5..cb3caf9d068 100644 --- a/cpp/src/replace/clamp.cu +++ b/cpp/src/replace/clamp.cu @@ -33,6 +33,8 @@ #include #include #include +#include +#include #include #include @@ -192,7 +194,9 @@ struct dispatch_clamp { rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(lo.type() == input.type(), "mismatching types of scalar and input"); + CUDF_EXPECTS(cudf::have_same_types(input, lo), + "mismatching types of scalar and input", + cudf::data_type_error); auto lo_itr = make_optional_iterator(lo, nullate::YES{}); auto hi_itr = make_optional_iterator(hi, nullate::YES{}); @@ -316,9 +320,14 @@ std::unique_ptr clamp(column_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(lo.type() == hi.type(), "mismatching types of limit scalars"); - CUDF_EXPECTS(lo_replace.type() == hi_replace.type(), "mismatching types of replace scalars"); - CUDF_EXPECTS(lo.type() == lo_replace.type(), "mismatching types of limit and replace scalars"); + CUDF_EXPECTS( + cudf::have_same_types(lo, hi), "mismatching types of limit scalars", cudf::data_type_error); + CUDF_EXPECTS(cudf::have_same_types(lo_replace, hi_replace), + "mismatching types of replace scalars", + cudf::data_type_error); + CUDF_EXPECTS(cudf::have_same_types(lo, lo_replace), + "mismatching types of limit and replace scalars", + cudf::data_type_error); if ((not lo.is_valid(stream) and not hi.is_valid(stream)) or (input.is_empty())) { // There will be no change diff --git a/cpp/src/replace/nulls.cu b/cpp/src/replace/nulls.cu index fe3d20e372e..13e130588c1 100644 --- a/cpp/src/replace/nulls.cu +++ b/cpp/src/replace/nulls.cu @@ -38,6 +38,7 @@ #include #include #include +#include #include #include @@ -216,7 +217,8 @@ struct replace_nulls_scalar_kernel_forwarder { rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(input.type() == replacement.type(), "Data type mismatch"); + CUDF_EXPECTS( + cudf::have_same_types(input, replacement), "Data type mismatch", cudf::data_type_error); std::unique_ptr output = cudf::detail::allocate_like( input, input.size(), cudf::mask_allocation_policy::NEVER, stream, mr); auto output_view = output->mutable_view(); @@ -252,9 +254,10 @@ std::unique_ptr replace_nulls_scalar_kernel_forwarder::operator()< rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(input.type() == replacement.type(), "Data type mismatch"); + CUDF_EXPECTS( + cudf::have_same_types(input, replacement), "Data type mismatch", cudf::data_type_error); cudf::strings_column_view input_s(input); - cudf::string_scalar const& repl = static_cast(replacement); + auto const& repl = static_cast(replacement); return cudf::strings::detail::replace_nulls(input_s, repl, stream, mr); } @@ -318,7 +321,8 @@ std::unique_ptr replace_nulls(cudf::column_view const& input, rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(input.type() == replacement.type(), "Data type mismatch"); + CUDF_EXPECTS( + cudf::have_same_types(input, replacement), "Data type mismatch", cudf::data_type_error); CUDF_EXPECTS(replacement.size() == input.size(), "Column size mismatch"); if (input.is_empty()) { return cudf::empty_like(input); } diff --git a/cpp/src/replace/replace.cu b/cpp/src/replace/replace.cu index 7bc0bd7e0be..c2cd03cd761 100644 --- a/cpp/src/replace/replace.cu +++ b/cpp/src/replace/replace.cu @@ -48,6 +48,7 @@ #include #include #include +#include #include #include @@ -303,9 +304,10 @@ std::unique_ptr find_and_replace_all(cudf::column_view const& inpu CUDF_EXPECTS(values_to_replace.size() == replacement_values.size(), "values_to_replace and replacement_values size mismatch."); - CUDF_EXPECTS( - input_col.type() == values_to_replace.type() && input_col.type() == replacement_values.type(), - "Columns type mismatch"); + CUDF_EXPECTS(cudf::have_same_types(input_col, values_to_replace) && + cudf::have_same_types(input_col, replacement_values), + "Columns type mismatch", + cudf::data_type_error); CUDF_EXPECTS(not values_to_replace.has_nulls(), "values_to_replace must not have nulls"); if (input_col.is_empty() or values_to_replace.is_empty() or replacement_values.is_empty()) { diff --git a/cpp/src/rolling/detail/lead_lag_nested.cuh b/cpp/src/rolling/detail/lead_lag_nested.cuh index 269868910c7..cfedcac8ae4 100644 --- a/cpp/src/rolling/detail/lead_lag_nested.cuh +++ b/cpp/src/rolling/detail/lead_lag_nested.cuh @@ -23,7 +23,9 @@ #include #include #include +#include #include +#include #include #include @@ -99,8 +101,9 @@ std::unique_ptr compute_lead_lag_for_nested(aggregation::Kind op, { CUDF_EXPECTS(op == aggregation::LEAD || op == aggregation::LAG, "Unexpected aggregation type in compute_lead_lag_for_nested"); - CUDF_EXPECTS(default_outputs.type().id() == input.type().id(), - "Defaults column type must match input column."); // Because LEAD/LAG. + CUDF_EXPECTS(cudf::have_same_types(input, default_outputs), + "Defaults column type must match input column.", + cudf::data_type_error); // Because LEAD/LAG. CUDF_EXPECTS(default_outputs.is_empty() || (input.size() == default_outputs.size()), "Number of defaults must match input column."); diff --git a/cpp/src/search/contains_scalar.cu b/cpp/src/search/contains_scalar.cu index 0b344ec347b..e88acf68e28 100644 --- a/cpp/src/search/contains_scalar.cu +++ b/cpp/src/search/contains_scalar.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * Copyright (c) 2019-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ #include #include #include +#include +#include #include #include @@ -62,7 +64,9 @@ struct contains_scalar_dispatch { scalar const& needle, rmm::cuda_stream_view stream) const { - CUDF_EXPECTS(haystack.type() == needle.type(), "Scalar and column types must match"); + CUDF_EXPECTS(cudf::have_same_types(haystack, needle), + "Scalar and column types must match", + cudf::data_type_error); // Don't need to check for needle validity. If it is invalid, it should be handled by the caller // before dispatching to this function. @@ -87,7 +91,9 @@ struct contains_scalar_dispatch { scalar const& needle, rmm::cuda_stream_view stream) const { - CUDF_EXPECTS(haystack.type() == needle.type(), "Scalar and column types must match"); + CUDF_EXPECTS(cudf::have_same_types(haystack, needle), + "Scalar and column types must match", + cudf::data_type_error); // Don't need to check for needle validity. If it is invalid, it should be handled by the caller // before dispatching to this function. // In addition, haystack and needle structure compatibility will be checked later on by diff --git a/cpp/src/search/contains_table.cu b/cpp/src/search/contains_table.cu index 13417fdab63..466f9093194 100644 --- a/cpp/src/search/contains_table.cu +++ b/cpp/src/search/contains_table.cu @@ -22,6 +22,7 @@ #include #include #include +#include #include #include diff --git a/cpp/src/strings/slice.cu b/cpp/src/strings/slice.cu index 2f7564b3b0d..972a4ffd58e 100644 --- a/cpp/src/strings/slice.cu +++ b/cpp/src/strings/slice.cu @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include @@ -228,13 +230,17 @@ std::unique_ptr slice_strings(strings_column_view const& strings, "Parameter starts must have the same number of rows as strings."); CUDF_EXPECTS(stops_column.size() == strings_count, "Parameter stops must have the same number of rows as strings."); - CUDF_EXPECTS(starts_column.type() == stops_column.type(), - "Parameters starts and stops must be of the same type."); + CUDF_EXPECTS(cudf::have_same_types(starts_column, stops_column), + "Parameters starts and stops must be of the same type.", + cudf::data_type_error); CUDF_EXPECTS(starts_column.null_count() == 0, "Parameter starts must not contain nulls."); CUDF_EXPECTS(stops_column.null_count() == 0, "Parameter stops must not contain nulls."); CUDF_EXPECTS(starts_column.type().id() != data_type{type_id::BOOL8}.id(), - "Positions values must not be bool type."); - CUDF_EXPECTS(is_fixed_width(starts_column.type()), "Positions values must be fixed width type."); + "Positions values must not be bool type.", + cudf::data_type_error); + CUDF_EXPECTS(is_fixed_width(starts_column.type()), + "Positions values must be fixed width type.", + cudf::data_type_error); auto strings_column = column_device_view::create(strings.parent(), stream); auto starts_iter = cudf::detail::indexalator_factory::make_input_iterator(starts_column); diff --git a/cpp/src/table/table_view.cpp b/cpp/src/table/table_view.cpp index bcbf2d44139..13832b0d9dc 100644 --- a/cpp/src/table/table_view.cpp +++ b/cpp/src/table/table_view.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * Copyright (c) 2018-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -145,30 +145,21 @@ bool has_nested_nullable_columns(table_view const& input) }); } -bool have_same_types(table_view const& lhs, table_view const& rhs) +namespace detail { + +template +bool is_relationally_comparable(TableView const& lhs, TableView const& rhs) { return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end(), [](column_view const& lcol, column_view const& rcol) { - return cudf::column_types_equal(lcol, rcol); + return cudf::is_relationally_comparable(lcol.type()) and + cudf::have_same_types(lcol, rcol); }); } -namespace detail { - -template -bool is_relationally_comparable(TableView const& lhs, TableView const& rhs) -{ - return std::all_of(thrust::counting_iterator(0), - thrust::counting_iterator(lhs.num_columns()), - [lhs, rhs](auto const i) { - return lhs.column(i).type() == rhs.column(i).type() and - cudf::is_relationally_comparable(lhs.column(i).type()); - }); -} - // Explicit template instantiation for a table of immutable views template bool is_relationally_comparable(table_view const& lhs, table_view const& rhs); diff --git a/cpp/src/transform/one_hot_encode.cu b/cpp/src/transform/one_hot_encode.cu index 570060b3870..723c306da1d 100644 --- a/cpp/src/transform/one_hot_encode.cu +++ b/cpp/src/transform/one_hot_encode.cu @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -61,7 +62,9 @@ std::pair, table_view> one_hot_encode(column_view const& rmm::cuda_stream_view stream, rmm::device_async_resource_ref mr) { - CUDF_EXPECTS(input.type() == categories.type(), "Mismatch type between input and categories."); + CUDF_EXPECTS(cudf::have_same_types(input, categories), + "Mismatch type between input and categories.", + cudf::data_type_error); if (categories.is_empty()) { return {make_empty_column(type_id::BOOL8), table_view{}}; } diff --git a/cpp/src/utilities/type_checks.cpp b/cpp/src/utilities/type_checks.cpp index d6f5c65593a..dac981fb532 100644 --- a/cpp/src/utilities/type_checks.cpp +++ b/cpp/src/utilities/type_checks.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,8 @@ #include #include +#include +#include #include #include @@ -28,15 +30,16 @@ namespace { struct columns_equal_fn { template - bool operator()(column_view const&, column_view const&) + bool operator()(column_view const& lhs, column_view const& rhs) { - return true; + return lhs.type() == rhs.type(); } }; template <> bool columns_equal_fn::operator()(column_view const& lhs, column_view const& rhs) { + if (not cudf::is_dictionary(rhs.type())) { return false; } auto const kidx = dictionary_column_view::keys_column_index; return lhs.num_children() > 0 and rhs.num_children() > 0 ? lhs.child(kidx).type() == rhs.child(kidx).type() @@ -46,33 +49,132 @@ bool columns_equal_fn::operator()(column_view const& lhs, column_v template <> bool columns_equal_fn::operator()(column_view const& lhs, column_view const& rhs) { + if (rhs.type().id() != type_id::LIST) { return false; } auto const& ci = lists_column_view::child_column_index; - return column_types_equal(lhs.child(ci), rhs.child(ci)); + return have_same_types(lhs.child(ci), rhs.child(ci)); } template <> bool columns_equal_fn::operator()(column_view const& lhs, column_view const& rhs) { - return lhs.num_children() == rhs.num_children() and - std::all_of(thrust::make_counting_iterator(0), - thrust::make_counting_iterator(lhs.num_children()), - [&](auto i) { return column_types_equal(lhs.child(i), rhs.child(i)); }); + if (rhs.type().id() != type_id::STRUCT) { return false; } + return std::equal(lhs.child_begin(), + lhs.child_end(), + rhs.child_begin(), + rhs.child_end(), + [](auto const& lhs, auto const& rhs) { return have_same_types(lhs, rhs); }); +} + +struct column_scalar_equal_fn { + template + bool operator()(column_view const& col, scalar const& slr) + { + return col.type() == slr.type(); + } +}; + +template <> +bool column_scalar_equal_fn::operator()(column_view const& col, scalar const& slr) +{ + // It is not possible to have a scalar dictionary, so compare the dictionary + // column keys type to the scalar type. + auto col_keys = cudf::dictionary_column_view(col).keys(); + return have_same_types(col_keys, slr); +} + +template <> +bool column_scalar_equal_fn::operator()(column_view const& col, scalar const& slr) +{ + if (slr.type().id() != type_id::LIST) { return false; } + auto const& ci = lists_column_view::child_column_index; + auto const list_slr = static_cast(&slr); + return have_same_types(col.child(ci), list_slr->view()); +} + +template <> +bool column_scalar_equal_fn::operator()(column_view const& col, scalar const& slr) +{ + if (slr.type().id() != type_id::STRUCT) { return false; } + auto const struct_slr = static_cast(&slr); + auto const slr_tbl = struct_slr->view(); + return std::equal(col.child_begin(), + col.child_end(), + slr_tbl.begin(), + slr_tbl.end(), + [](auto const& lhs, auto const& rhs) { return have_same_types(lhs, rhs); }); +} + +struct scalars_equal_fn { + template + bool operator()(scalar const& lhs, scalar const& rhs) + { + return lhs.type() == rhs.type(); + } +}; + +template <> +bool scalars_equal_fn::operator()(scalar const& lhs, scalar const& rhs) +{ + if (rhs.type().id() != type_id::LIST) { return false; } + auto const list_lhs = static_cast(&lhs); + auto const list_rhs = static_cast(&rhs); + return have_same_types(list_lhs->view(), list_rhs->view()); +} + +template <> +bool scalars_equal_fn::operator()(scalar const& lhs, scalar const& rhs) +{ + if (rhs.type().id() != type_id::STRUCT) { return false; } + auto const tbl_lhs = static_cast(&lhs)->view(); + auto const tbl_rhs = static_cast(&rhs)->view(); + return have_same_types(tbl_lhs, tbl_rhs); } }; // namespace // Implementation note: avoid using double dispatch for this function // as it increases code paths to NxN for N types. -bool column_types_equal(column_view const& lhs, column_view const& rhs) +bool have_same_types(column_view const& lhs, column_view const& rhs) { - if (lhs.type() != rhs.type()) { return false; } return type_dispatcher(lhs.type(), columns_equal_fn{}, lhs, rhs); } +bool column_types_equal(column_view const& lhs, column_view const& rhs) +{ + return have_same_types(lhs, rhs); +} + +bool have_same_types(column_view const& lhs, scalar const& rhs) +{ + return type_dispatcher(lhs.type(), column_scalar_equal_fn{}, lhs, rhs); +} + +bool have_same_types(scalar const& lhs, column_view const& rhs) +{ + return have_same_types(rhs, lhs); +} + +bool have_same_types(scalar const& lhs, scalar const& rhs) +{ + return type_dispatcher(lhs.type(), scalars_equal_fn{}, lhs, rhs); +} + +bool have_same_types(table_view const& lhs, table_view const& rhs) +{ + return std::equal( + lhs.begin(), + lhs.end(), + rhs.begin(), + rhs.end(), + [](column_view const& lcol, column_view const& rcol) { return have_same_types(lcol, rcol); }); +} + bool column_types_equivalent(column_view const& lhs, column_view const& rhs) { - if (lhs.type().id() != rhs.type().id()) { return false; } - return type_dispatcher(lhs.type(), columns_equal_fn{}, lhs, rhs); + // Check if the columns have fixed point types. This is the only case where + // type equality and equivalence differ. + if (cudf::is_fixed_point(lhs.type())) { return lhs.type().id() == rhs.type().id(); } + return have_same_types(lhs, rhs); } } // namespace cudf diff --git a/cpp/tests/copying/concatenate_tests.cpp b/cpp/tests/copying/concatenate_tests.cpp index c2d1e1d9f4f..a9bf22682cf 100644 --- a/cpp/tests/copying/concatenate_tests.cpp +++ b/cpp/tests/copying/concatenate_tests.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include @@ -1226,7 +1227,7 @@ TEST_F(ListsColumnTest, ConcatenateMismatchedHierarchies) cudf::test::lists_column_wrapper b{{{LCW{}}}}; cudf::test::lists_column_wrapper c{{LCW{}}}; - EXPECT_THROW(cudf::concatenate(std::vector({a, b, c})), cudf::logic_error); + EXPECT_THROW(cudf::concatenate(std::vector({a, b, c})), cudf::data_type_error); } { @@ -1235,7 +1236,7 @@ TEST_F(ListsColumnTest, ConcatenateMismatchedHierarchies) cudf::test::lists_column_wrapper b{{{LCW{}}}}; cudf::test::lists_column_wrapper c{{LCW{}}}; - EXPECT_THROW(cudf::concatenate(std::vector({a, b, c})), cudf::logic_error); + EXPECT_THROW(cudf::concatenate(std::vector({a, b, c})), cudf::data_type_error); } { @@ -1243,14 +1244,14 @@ TEST_F(ListsColumnTest, ConcatenateMismatchedHierarchies) cudf::test::lists_column_wrapper b{1, 2, 3}; cudf::test::lists_column_wrapper c{{3, 4, 5}}; - EXPECT_THROW(cudf::concatenate(std::vector({a, b, c})), cudf::logic_error); + EXPECT_THROW(cudf::concatenate(std::vector({a, b, c})), cudf::data_type_error); } { cudf::test::lists_column_wrapper a{{{1, 2, 3}}}; cudf::test::lists_column_wrapper b{{4, 5}}; - EXPECT_THROW(cudf::concatenate(std::vector({a, b})), cudf::logic_error); + EXPECT_THROW(cudf::concatenate(std::vector({a, b})), cudf::data_type_error); } } @@ -1605,7 +1606,7 @@ TEST_F(FixedPointTest, FixedPointScaleMismatch) auto const b = fp_wrapper(vec.begin() + 300, vec.begin() + 700, scale_type{-2}); auto const c = fp_wrapper(vec.begin() + 700, vec.end(), /*****/ scale_type{-3}); - EXPECT_THROW(cudf::concatenate(std::vector{a, b, c}), cudf::logic_error); + EXPECT_THROW(cudf::concatenate(std::vector{a, b, c}), cudf::data_type_error); } struct DictionaryConcatTest : public cudf::test::BaseFixture {}; @@ -1650,7 +1651,7 @@ TEST_F(DictionaryConcatTest, ErrorsTest) cudf::test::fixed_width_column_wrapper integers({10, 30, 20}); auto dictionary2 = cudf::dictionary::encode(integers); std::vector views({dictionary1->view(), dictionary2->view()}); - EXPECT_THROW(cudf::concatenate(views), cudf::logic_error); + EXPECT_THROW(cudf::concatenate(views), cudf::data_type_error); std::vector empty; EXPECT_THROW(cudf::concatenate(empty), cudf::logic_error); } diff --git a/cpp/tests/copying/copy_range_tests.cpp b/cpp/tests/copying/copy_range_tests.cpp index bcc0ac29b3e..223946ddcee 100644 --- a/cpp/tests/copying/copy_range_tests.cpp +++ b/cpp/tests/copying/copy_range_tests.cpp @@ -465,7 +465,7 @@ TEST_F(CopyRangeErrorTestFixture, DTypeMismatch) auto dict_target = cudf::dictionary::encode(target); auto dict_source = cudf::dictionary::encode(source); EXPECT_THROW(cudf::copy_range(dict_source->view(), dict_target->view(), 0, 100, 0), - cudf::logic_error); + cudf::data_type_error); } template diff --git a/cpp/tests/copying/copy_tests.cpp b/cpp/tests/copying/copy_tests.cpp index 138e1935363..f31d8d6f79a 100644 --- a/cpp/tests/copying/copy_tests.cpp +++ b/cpp/tests/copying/copy_tests.cpp @@ -712,7 +712,7 @@ TEST_F(DictionaryCopyIfElseTest, TypeMismatch) cudf::test::dictionary_column_wrapper input2({1.0, 1.0, 1.0, 1.0}); cudf::test::fixed_width_column_wrapper mask({1, 0, 0, 1}); - EXPECT_THROW(cudf::copy_if_else(input1, input2, mask), cudf::logic_error); + EXPECT_THROW(cudf::copy_if_else(input1, input2, mask), cudf::data_type_error); cudf::string_scalar input3{"1"}; EXPECT_THROW(cudf::copy_if_else(input1, input3, mask), cudf::data_type_error); diff --git a/cpp/tests/copying/get_value_tests.cpp b/cpp/tests/copying/get_value_tests.cpp index 2be3c26af1d..99b86c86997 100644 --- a/cpp/tests/copying/get_value_tests.cpp +++ b/cpp/tests/copying/get_value_tests.cpp @@ -542,11 +542,6 @@ struct ListGetStructValueTest : public cudf::test::BaseFixture { return SCW{{field1, field2, field3}, mask}; } - /** - * @brief Create a 0-length structs column - */ - SCW zero_length_struct() { return SCW{}; } - /** * @brief Concatenate structs columns, allow specifying inputs in `initializer_list` */ @@ -653,7 +648,7 @@ TYPED_TEST(ListGetStructValueTest, NonNestedGetNonNullEmpty) cudf::size_type index = 2; // For well-formed list column, an empty list still holds the complete structure of // a 0-length structs column - auto expected_data = this->zero_length_struct(); + auto expected_data = this->make_test_structs_column({}, {}, {}, no_nulls()); auto s = cudf::get_element(list_column->view(), index); auto typed_s = static_cast(s.get()); @@ -757,8 +752,8 @@ TYPED_TEST(ListGetStructValueTest, NestedGetNonNullEmpty) auto list_column_nested = this->make_test_lists_column(3, {0, 1, 1, 2}, std::move(list_column), {1, 1, 1}); - auto expected_data = - this->make_test_lists_column(0, {0}, this->zero_length_struct().release(), {}); + auto expected_data = this->make_test_lists_column( + 0, {0}, this->make_test_structs_column({}, {}, {}, no_nulls()).release(), {}); cudf::size_type index = 1; auto s = cudf::get_element(list_column_nested->view(), index); diff --git a/cpp/tests/dictionary/add_keys_test.cpp b/cpp/tests/dictionary/add_keys_test.cpp index 1314375f383..46bf5468922 100644 --- a/cpp/tests/dictionary/add_keys_test.cpp +++ b/cpp/tests/dictionary/add_keys_test.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -83,7 +84,7 @@ TEST_F(DictionaryAddKeysTest, Errors) auto dictionary = cudf::dictionary::encode(input); cudf::test::fixed_width_column_wrapper new_keys{1.0, 2.0, 3.0}; - EXPECT_THROW(cudf::dictionary::add_keys(dictionary->view(), new_keys), cudf::logic_error); + EXPECT_THROW(cudf::dictionary::add_keys(dictionary->view(), new_keys), cudf::data_type_error); cudf::test::fixed_width_column_wrapper null_keys{{1, 2, 3}, {1, 0, 1}}; EXPECT_THROW(cudf::dictionary::add_keys(dictionary->view(), null_keys), cudf::logic_error); } diff --git a/cpp/tests/dictionary/remove_keys_test.cpp b/cpp/tests/dictionary/remove_keys_test.cpp index 13fe3efd0f4..9950a39d630 100644 --- a/cpp/tests/dictionary/remove_keys_test.cpp +++ b/cpp/tests/dictionary/remove_keys_test.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -119,7 +120,7 @@ TEST_F(DictionaryRemoveKeysTest, Errors) auto const dictionary = cudf::dictionary::encode(input); cudf::test::fixed_width_column_wrapper del_keys{1.0, 2.0, 3.0}; - EXPECT_THROW(cudf::dictionary::remove_keys(dictionary->view(), del_keys), cudf::logic_error); + EXPECT_THROW(cudf::dictionary::remove_keys(dictionary->view(), del_keys), cudf::data_type_error); cudf::test::fixed_width_column_wrapper null_keys{{1, 2, 3}, {1, 0, 1}}; EXPECT_THROW(cudf::dictionary::remove_keys(dictionary->view(), null_keys), cudf::logic_error); } diff --git a/cpp/tests/dictionary/scatter_test.cpp b/cpp/tests/dictionary/scatter_test.cpp index 2a2841827d0..2f77f4ee621 100644 --- a/cpp/tests/dictionary/scatter_test.cpp +++ b/cpp/tests/dictionary/scatter_test.cpp @@ -141,5 +141,5 @@ TEST_F(DictionaryScatterTest, Error) EXPECT_THROW( cudf::scatter( cudf::table_view{{source->view()}}, scatter_map, cudf::table_view{{target->view()}}), - cudf::logic_error); + cudf::data_type_error); } diff --git a/cpp/tests/dictionary/search_test.cpp b/cpp/tests/dictionary/search_test.cpp index 600d00ac186..b49b4ce5aa0 100644 --- a/cpp/tests/dictionary/search_test.cpp +++ b/cpp/tests/dictionary/search_test.cpp @@ -77,9 +77,9 @@ TEST_F(DictionarySearchTest, Errors) { cudf::test::dictionary_column_wrapper dictionary({1, 2, 3}); cudf::numeric_scalar key(7); - EXPECT_THROW(cudf::dictionary::get_index(dictionary, key), cudf::logic_error); + EXPECT_THROW(cudf::dictionary::get_index(dictionary, key), cudf::data_type_error); EXPECT_THROW( cudf::dictionary::detail::get_insert_index( dictionary, key, cudf::get_default_stream(), rmm::mr::get_current_device_resource()), - cudf::logic_error); + cudf::data_type_error); } diff --git a/cpp/tests/dictionary/set_keys_test.cpp b/cpp/tests/dictionary/set_keys_test.cpp index d0c37493cf8..5c9ec3567fe 100644 --- a/cpp/tests/dictionary/set_keys_test.cpp +++ b/cpp/tests/dictionary/set_keys_test.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include @@ -82,7 +83,7 @@ TEST_F(DictionarySetKeysTest, Errors) auto dictionary = cudf::dictionary::encode(input); cudf::test::fixed_width_column_wrapper new_keys{1.0, 2.0, 3.0}; - EXPECT_THROW(cudf::dictionary::set_keys(dictionary->view(), new_keys), cudf::logic_error); + EXPECT_THROW(cudf::dictionary::set_keys(dictionary->view(), new_keys), cudf::data_type_error); cudf::test::fixed_width_column_wrapper null_keys{{1, 2, 3}, {1, 0, 1}}; EXPECT_THROW(cudf::dictionary::set_keys(dictionary->view(), null_keys), cudf::logic_error); } diff --git a/cpp/tests/filling/fill_tests.cpp b/cpp/tests/filling/fill_tests.cpp index 95a27defa4e..26badefe698 100644 --- a/cpp/tests/filling/fill_tests.cpp +++ b/cpp/tests/filling/fill_tests.cpp @@ -359,8 +359,8 @@ TEST_F(FillErrorTestFixture, DTypeMismatch) auto destination_view = cudf::mutable_column_view{destination}; - EXPECT_THROW(cudf::fill_in_place(destination_view, 0, 10, *p_val), cudf::logic_error); - EXPECT_THROW(auto p_ret = cudf::fill(destination, 0, 10, *p_val), cudf::logic_error); + EXPECT_THROW(cudf::fill_in_place(destination_view, 0, 10, *p_val), cudf::data_type_error); + EXPECT_THROW(auto p_ret = cudf::fill(destination, 0, 10, *p_val), cudf::data_type_error); } template diff --git a/cpp/tests/filling/sequence_tests.cpp b/cpp/tests/filling/sequence_tests.cpp index cf619aace5a..5651a26f192 100644 --- a/cpp/tests/filling/sequence_tests.cpp +++ b/cpp/tests/filling/sequence_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * Copyright (c) 2020-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -102,15 +102,15 @@ TEST_F(SequenceTestFixture, MismatchedInputs) { cudf::numeric_scalar init(0); cudf::numeric_scalar step(-5); - EXPECT_THROW(cudf::sequence(10, init, step), cudf::logic_error); + EXPECT_THROW(cudf::sequence(10, init, step), cudf::data_type_error); cudf::numeric_scalar init2(0); cudf::numeric_scalar step2(-5); - EXPECT_THROW(cudf::sequence(10, init2, step2), cudf::logic_error); + EXPECT_THROW(cudf::sequence(10, init2, step2), cudf::data_type_error); cudf::numeric_scalar init3(0); cudf::numeric_scalar step3(-5); - EXPECT_THROW(cudf::sequence(10, init3, step3), cudf::logic_error); + EXPECT_THROW(cudf::sequence(10, init3, step3), cudf::data_type_error); } TYPED_TEST(SequenceTypedTestFixture, DefaultStep) diff --git a/cpp/tests/groupby/shift_tests.cpp b/cpp/tests/groupby/shift_tests.cpp index d2ecb667eca..1a6abf2e734 100644 --- a/cpp/tests/groupby/shift_tests.cpp +++ b/cpp/tests/groupby/shift_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -507,7 +507,7 @@ TEST_F(groupby_shift_fixed_point_type_test, MismatchScaleType) EXPECT_THROW(test_groupby_shift_multi( key, cudf::table_view{{v1}}, offset, {*slr1}, cudf::table_view{{stub}}), - cudf::logic_error); + cudf::data_type_error); } TEST_F(groupby_shift_fixed_point_type_test, MismatchRepType) @@ -525,5 +525,5 @@ TEST_F(groupby_shift_fixed_point_type_test, MismatchRepType) EXPECT_THROW(test_groupby_shift_multi( key, cudf::table_view{{v1}}, offset, {*slr1}, cudf::table_view{{stub}}), - cudf::logic_error); + cudf::data_type_error); } diff --git a/cpp/tests/interop/dlpack_test.cpp b/cpp/tests/interop/dlpack_test.cpp index 895887ee348..ecc8558243d 100644 --- a/cpp/tests/interop/dlpack_test.cpp +++ b/cpp/tests/interop/dlpack_test.cpp @@ -20,6 +20,7 @@ #include #include +#include #include @@ -98,7 +99,7 @@ TEST_F(DLPackUntypedTests, MultipleTypesToDlpack) cudf::test::fixed_width_column_wrapper col1({1, 2, 3, 4}); cudf::test::fixed_width_column_wrapper col2({1, 2, 3, 4}); cudf::table_view input({col1, col2}); - EXPECT_THROW(cudf::to_dlpack(input), cudf::logic_error); + EXPECT_THROW(cudf::to_dlpack(input), cudf::data_type_error); } TEST_F(DLPackUntypedTests, InvalidNullsToDlpack) diff --git a/cpp/tests/io/parquet_writer_test.cpp b/cpp/tests/io/parquet_writer_test.cpp index 3a8763ed9f3..fd8484bc70f 100644 --- a/cpp/tests/io/parquet_writer_test.cpp +++ b/cpp/tests/io/parquet_writer_test.cpp @@ -567,9 +567,7 @@ TEST_F(ParquetWriterTest, EmptyList) auto result = cudf::io::read_parquet( cudf::io::parquet_reader_options_builder(cudf::io::source_info(filepath))); - using lcw = cudf::test::lists_column_wrapper; - auto expected = lcw{lcw{}, lcw{}, lcw{}}; - CUDF_TEST_EXPECT_COLUMNS_EQUAL(result.tbl->view().column(0), expected); + CUDF_TEST_EXPECT_COLUMNS_EQUAL(result.tbl->view().column(0), L0->view()); } TEST_F(ParquetWriterTest, DeepEmptyList) diff --git a/cpp/tests/labeling/label_bins_tests.cpp b/cpp/tests/labeling/label_bins_tests.cpp index 2ac6ad5dd0d..1a9e74df9be 100644 --- a/cpp/tests/labeling/label_bins_tests.cpp +++ b/cpp/tests/labeling/label_bins_tests.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -64,7 +65,7 @@ TEST(BinColumnErrorTests, TestInvalidLeft) EXPECT_THROW( cudf::label_bins(input, left_edges, cudf::inclusive::YES, right_edges, cudf::inclusive::NO), - cudf::logic_error); + cudf::data_type_error); }; // Right edges type check. @@ -76,7 +77,7 @@ TEST(BinColumnErrorTests, TestInvalidRight) EXPECT_THROW( cudf::label_bins(input, left_edges, cudf::inclusive::YES, right_edges, cudf::inclusive::NO), - cudf::logic_error); + cudf::data_type_error); }; // Input type check. @@ -88,7 +89,7 @@ TEST(BinColumnErrorTests, TestInvalidInput) EXPECT_THROW( cudf::label_bins(input, left_edges, cudf::inclusive::YES, right_edges, cudf::inclusive::NO), - cudf::logic_error); + cudf::data_type_error); }; // Number of left and right edges must match. diff --git a/cpp/tests/lists/combine/concatenate_rows_tests.cpp b/cpp/tests/lists/combine/concatenate_rows_tests.cpp index 008003a08a1..bf088eb855a 100644 --- a/cpp/tests/lists/combine/concatenate_rows_tests.cpp +++ b/cpp/tests/lists/combine/concatenate_rows_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include #include +#include using namespace cudf::test::iterators; @@ -53,7 +54,7 @@ TEST_F(ListConcatenateRowsTest, InvalidInput) auto const col1 = IntListsCol{}.release(); auto const col2 = StrListsCol{}.release(); EXPECT_THROW(cudf::lists::concatenate_rows(TView{{col1->view(), col2->view()}}), - cudf::logic_error); + cudf::data_type_error); } } diff --git a/cpp/tests/lists/sequences_tests.cpp b/cpp/tests/lists/sequences_tests.cpp index e97600a76d3..74545903eb3 100644 --- a/cpp/tests/lists/sequences_tests.cpp +++ b/cpp/tests/lists/sequences_tests.cpp @@ -22,6 +22,7 @@ #include #include +#include using namespace cudf::test::iterators; @@ -200,8 +201,8 @@ TEST_F(NumericSequencesTest, InvalidSizesInput) auto const steps = IntsCol{}; auto const sizes = FWDCol{}; - EXPECT_THROW(cudf::lists::sequences(starts, sizes), cudf::logic_error); - EXPECT_THROW(cudf::lists::sequences(starts, steps, sizes), cudf::logic_error); + EXPECT_THROW(cudf::lists::sequences(starts, sizes), cudf::data_type_error); + EXPECT_THROW(cudf::lists::sequences(starts, steps, sizes), cudf::data_type_error); } TEST_F(NumericSequencesTest, MismatchedColumnSizesInput) @@ -220,7 +221,7 @@ TEST_F(NumericSequencesTest, MismatchedColumnTypesInput) auto const steps = FWDCol{1, 2, 3}; auto const sizes = IntsCol{1, 2, 3}; - EXPECT_THROW(cudf::lists::sequences(starts, steps, sizes), cudf::logic_error); + EXPECT_THROW(cudf::lists::sequences(starts, steps, sizes), cudf::data_type_error); } TEST_F(NumericSequencesTest, InputHasNulls) diff --git a/cpp/tests/replace/clamp_test.cpp b/cpp/tests/replace/clamp_test.cpp index bb33de1f1e7..239c9ce6ddd 100644 --- a/cpp/tests/replace/clamp_test.cpp +++ b/cpp/tests/replace/clamp_test.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include @@ -41,7 +42,7 @@ TEST_F(ClampErrorTest, MisMatchingScalarTypes) cudf::test::fixed_width_column_wrapper input({1, 2, 3, 4, 5, 6}); - EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::logic_error); + EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::data_type_error); } TEST_F(ClampErrorTest, MisMatchingInputAndScalarTypes) @@ -53,7 +54,7 @@ TEST_F(ClampErrorTest, MisMatchingInputAndScalarTypes) cudf::test::fixed_width_column_wrapper input({1, 2, 3, 4, 5, 6}); - EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::logic_error); + EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::data_type_error); } TEST_F(ClampErrorTest, MisMatchingReplaceScalarTypes) @@ -69,7 +70,7 @@ TEST_F(ClampErrorTest, MisMatchingReplaceScalarTypes) cudf::test::fixed_width_column_wrapper input({1, 2, 3, 4, 5, 6}); - EXPECT_THROW(cudf::clamp(input, *lo, *lo_replace, *hi, *hi_replace), cudf::logic_error); + EXPECT_THROW(cudf::clamp(input, *lo, *lo_replace, *hi, *hi_replace), cudf::data_type_error); } TEST_F(ClampErrorTest, InValidCase1) @@ -640,7 +641,7 @@ TYPED_TEST(FixedPointTest, MismatchedScalarScales) auto const hi = cudf::make_fixed_point_scalar(8, scale); auto const input = fp_wrapper{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, scale}; - EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::logic_error); + EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::data_type_error); } TYPED_TEST(FixedPointTest, MismatchedColumnScalarScale) @@ -655,7 +656,7 @@ TYPED_TEST(FixedPointTest, MismatchedColumnScalarScale) auto const hi = cudf::make_fixed_point_scalar(8, scale); auto const input = fp_wrapper{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, scale_type{-4}}; - EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::logic_error); + EXPECT_THROW(cudf::clamp(input, *lo, *hi), cudf::data_type_error); } CUDF_TEST_PROGRAM_MAIN() diff --git a/cpp/tests/replace/replace_nulls_tests.cpp b/cpp/tests/replace/replace_nulls_tests.cpp index 6c23dd6bdc8..9603ea44a76 100644 --- a/cpp/tests/replace/replace_nulls_tests.cpp +++ b/cpp/tests/replace/replace_nulls_tests.cpp @@ -58,7 +58,7 @@ TEST_F(ReplaceErrorTest, TypeMismatch) cudf::test::fixed_width_column_wrapper values_to_replace_column{ {10, 11, 12, 13, 14, 15, 16, 17}}; - EXPECT_THROW(cudf::replace_nulls(input_column, values_to_replace_column), cudf::logic_error); + EXPECT_THROW(cudf::replace_nulls(input_column, values_to_replace_column), cudf::data_type_error); } // Error: column type mismatch @@ -68,7 +68,7 @@ TEST_F(ReplaceErrorTest, TypeMismatchScalar) {0, 0, 1, 1, 1, 1, 1, 1}}; cudf::numeric_scalar replacement(1); - EXPECT_THROW(cudf::replace_nulls(input_column, replacement), cudf::logic_error); + EXPECT_THROW(cudf::replace_nulls(input_column, replacement), cudf::data_type_error); } struct ReplaceNullsStringsTest : public cudf::test::BaseFixture {}; @@ -659,14 +659,14 @@ TEST_F(ReplaceDictionaryTest, ReplaceNullsError) cudf::test::fixed_width_column_wrapper replacement_w({1, 2, 3, 4}); auto replacement = cudf::dictionary::encode(replacement_w); - EXPECT_THROW(cudf::replace_nulls(input->view(), replacement->view()), cudf::logic_error); - EXPECT_THROW(cudf::replace_nulls(input->view(), cudf::string_scalar("x")), cudf::logic_error); + EXPECT_THROW(cudf::replace_nulls(input->view(), replacement->view()), cudf::data_type_error); + EXPECT_THROW(cudf::replace_nulls(input->view(), cudf::string_scalar("x")), cudf::data_type_error); cudf::test::fixed_width_column_wrapper input_one_w({1}, {0}); auto input_one = cudf::dictionary::encode(input_one_w); auto dict_input = cudf::dictionary_column_view(input_one->view()); auto dict_repl = cudf::dictionary_column_view(replacement->view()); - EXPECT_THROW(cudf::replace_nulls(input->view(), replacement->view()), cudf::logic_error); + EXPECT_THROW(cudf::replace_nulls(input->view(), replacement->view()), cudf::data_type_error); } TEST_F(ReplaceDictionaryTest, ReplaceNullsEmpty) diff --git a/cpp/tests/replace/replace_tests.cpp b/cpp/tests/replace/replace_tests.cpp index 613034efc12..1858cd7782e 100644 --- a/cpp/tests/replace/replace_tests.cpp +++ b/cpp/tests/replace/replace_tests.cpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -63,7 +64,7 @@ TEST_F(ReplaceErrorTest, TypeMismatch) EXPECT_THROW( cudf::find_and_replace_all(input_column, values_to_replace_column, replacement_values_column), - cudf::logic_error); + cudf::data_type_error); } // Error: nulls in old-values diff --git a/cpp/tests/transform/one_hot_encode_tests.cpp b/cpp/tests/transform/one_hot_encode_tests.cpp index 1015370fe4b..8384cb3480b 100644 --- a/cpp/tests/transform/one_hot_encode_tests.cpp +++ b/cpp/tests/transform/one_hot_encode_tests.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ #include #include +#include #include @@ -198,7 +199,7 @@ TEST_F(OneHotEncodingTest, MismatchTypes) auto input = cudf::test::strings_column_wrapper{"xx", "yy", "xx"}; auto category = cudf::test::fixed_width_column_wrapper{1}; - EXPECT_THROW(cudf::one_hot_encode(input, category), cudf::logic_error); + EXPECT_THROW(cudf::one_hot_encode(input, category), cudf::data_type_error); } TEST_F(OneHotEncodingTest, List) diff --git a/cpp/tests/utilities/column_utilities.cu b/cpp/tests/utilities/column_utilities.cu index 047b096a283..7cc2777972e 100644 --- a/cpp/tests/utilities/column_utilities.cu +++ b/cpp/tests/utilities/column_utilities.cu @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -238,11 +239,6 @@ std::unique_ptr generate_child_row_indices(lists_column_view const& c, template struct column_property_comparator { - bool types_equivalent(cudf::data_type const& lhs, cudf::data_type const& rhs) - { - return is_fixed_point(lhs) ? lhs.id() == rhs.id() : lhs == rhs; - } - bool compare_common(cudf::column_view const& lhs, cudf::column_view const& rhs, cudf::column_view const& lhs_row_indices, @@ -252,9 +248,9 @@ struct column_property_comparator { bool result = true; if (check_exact_equality) { - PROP_EXPECT_EQ(lhs.type(), rhs.type()); + PROP_EXPECT_EQ(cudf::have_same_types(lhs, rhs), true); } else { - PROP_EXPECT_EQ(types_equivalent(lhs.type(), rhs.type()), true); + PROP_EXPECT_EQ(cudf::column_types_equivalent(lhs, rhs), true); } auto const lhs_size = check_exact_equality ? lhs.size() : lhs_row_indices.size(); diff --git a/cpp/tests/utilities_tests/type_check_tests.cpp b/cpp/tests/utilities_tests/type_check_tests.cpp index 9c23798fce6..fecb896f95a 100644 --- a/cpp/tests/utilities_tests/type_check_tests.cpp +++ b/cpp/tests/utilities_tests/type_check_tests.cpp @@ -19,13 +19,11 @@ #include #include +#include #include #include #include -namespace cudf { -namespace test { - template struct ColumnTypeCheckTestTyped : public cudf::test::BaseFixture {}; @@ -35,56 +33,56 @@ TYPED_TEST_SUITE(ColumnTypeCheckTestTyped, cudf::test::FixedWidthTypes); TYPED_TEST(ColumnTypeCheckTestTyped, SameFixedWidth) { - fixed_width_column_wrapper lhs{1, 1}, rhs{2}; - EXPECT_TRUE(column_types_equal(lhs, rhs)); + cudf::test::fixed_width_column_wrapper lhs{1, 1}, rhs{2}; + EXPECT_TRUE(cudf::have_same_types(lhs, rhs)); } TEST_F(ColumnTypeCheckTest, SameString) { - strings_column_wrapper lhs{{'a', 'a'}}, rhs{{'b'}}; - EXPECT_TRUE(column_types_equal(lhs, rhs)); + cudf::test::strings_column_wrapper lhs{{'a', 'a'}}, rhs{{'b'}}; + EXPECT_TRUE(cudf::have_same_types(lhs, rhs)); - strings_column_wrapper lhs2{}, rhs2{{'b'}}; - EXPECT_TRUE(column_types_equal(lhs2, rhs2)); + cudf::test::strings_column_wrapper lhs2{}, rhs2{{'b'}}; + EXPECT_TRUE(cudf::have_same_types(lhs2, rhs2)); - strings_column_wrapper lhs3{}, rhs3{}; - EXPECT_TRUE(column_types_equal(lhs3, rhs3)); + cudf::test::strings_column_wrapper lhs3{}, rhs3{}; + EXPECT_TRUE(cudf::have_same_types(lhs3, rhs3)); } TEST_F(ColumnTypeCheckTest, SameList) { - using LCW = lists_column_wrapper; + using LCW = cudf::test::lists_column_wrapper; LCW lhs{}, rhs{}; - EXPECT_TRUE(column_types_equal(lhs, rhs)); + EXPECT_TRUE(cudf::have_same_types(lhs, rhs)); LCW lhs2{{1, 2, 3}}, rhs2{{4, 5}}; - EXPECT_TRUE(column_types_equal(lhs2, rhs2)); + EXPECT_TRUE(cudf::have_same_types(lhs2, rhs2)); LCW lhs3{{LCW{1}, LCW{2, 3}}}, rhs3{{LCW{4, 5}}}; - EXPECT_TRUE(column_types_equal(lhs3, rhs3)); + EXPECT_TRUE(cudf::have_same_types(lhs3, rhs3)); LCW lhs4{{LCW{1}, LCW{}, LCW{2, 3}}}, rhs4{{LCW{4, 5}, LCW{}}}; - EXPECT_TRUE(column_types_equal(lhs4, rhs4)); + EXPECT_TRUE(cudf::have_same_types(lhs4, rhs4)); } TYPED_TEST(ColumnTypeCheckTestTyped, SameDictionary) { - using DCW = dictionary_column_wrapper; + using DCW = cudf::test::dictionary_column_wrapper; DCW lhs{1, 1, 2, 3}, rhs{5, 5}; - EXPECT_TRUE(column_types_equal(lhs, rhs)); + EXPECT_TRUE(cudf::have_same_types(lhs, rhs)); DCW lhs2{}, rhs2{}; - EXPECT_TRUE(column_types_equal(lhs2, rhs2)); + EXPECT_TRUE(cudf::have_same_types(lhs2, rhs2)); } TEST_F(ColumnTypeCheckTest, SameStruct) { - using SCW = structs_column_wrapper; - using FCW = fixed_width_column_wrapper; - using StringCW = strings_column_wrapper; - using LCW = lists_column_wrapper; - using DCW = dictionary_column_wrapper; + using SCW = cudf::test::structs_column_wrapper; + using FCW = cudf::test::fixed_width_column_wrapper; + using StringCW = cudf::test::strings_column_wrapper; + using LCW = cudf::test::lists_column_wrapper; + using DCW = cudf::test::dictionary_column_wrapper; FCW lf1{1, 2, 3}, rf1{0, 1}; StringCW lf2{"a", "bb", ""}, rf2{"cc", "d"}; @@ -92,127 +90,158 @@ TEST_F(ColumnTypeCheckTest, SameStruct) DCW lf4{5, 5, 5}, rf4{9, 9}; SCW lhs{lf1, lf2, lf3, lf4}, rhs{rf1, rf2, rf3, rf4}; - EXPECT_TRUE(column_types_equal(lhs, rhs)); + EXPECT_TRUE(cudf::have_same_types(lhs, rhs)); } TEST_F(ColumnTypeCheckTest, DifferentBasics) { - fixed_width_column_wrapper lhs1{1, 1}; - strings_column_wrapper rhs1{"a", "bb"}; + cudf::test::fixed_width_column_wrapper lhs1{1, 1}; + cudf::test::strings_column_wrapper rhs1{"a", "bb"}; - EXPECT_FALSE(column_types_equal(lhs1, rhs1)); + EXPECT_FALSE(cudf::have_same_types(lhs1, rhs1)); - lists_column_wrapper lhs2{{"hello"}, {"world", "!"}}; - strings_column_wrapper rhs2{"", "kk"}; + cudf::test::lists_column_wrapper lhs2{{"hello"}, {"world", "!"}}; + cudf::test::strings_column_wrapper rhs2{"", "kk"}; - EXPECT_FALSE(column_types_equal(lhs2, rhs2)); + EXPECT_FALSE(cudf::have_same_types(lhs2, rhs2)); - fixed_width_column_wrapper lhs3{1, 1}; - dictionary_column_wrapper rhs3{2, 2}; + cudf::test::fixed_width_column_wrapper lhs3{1, 1}; + cudf::test::dictionary_column_wrapper rhs3{2, 2}; - EXPECT_FALSE(column_types_equal(lhs3, rhs3)); + EXPECT_FALSE(cudf::have_same_types(lhs3, rhs3)); - lists_column_wrapper lhs4{{8, 8, 8}, {10, 10}}; - structs_column_wrapper rhs4{rhs2, rhs3}; + cudf::test::lists_column_wrapper lhs4{{8, 8, 8}, {10, 10}}; + cudf::test::structs_column_wrapper rhs4{rhs2, rhs3}; - EXPECT_FALSE(column_types_equal(lhs4, rhs4)); + EXPECT_FALSE(cudf::have_same_types(lhs4, rhs4)); } TEST_F(ColumnTypeCheckTest, DifferentFixedWidth) { - fixed_width_column_wrapper lhs1{1, 1}; - fixed_width_column_wrapper rhs1{2}; + cudf::test::fixed_width_column_wrapper lhs1{1, 1}; + cudf::test::fixed_width_column_wrapper rhs1{2}; - EXPECT_FALSE(column_types_equal(lhs1, rhs1)); + EXPECT_FALSE(cudf::have_same_types(lhs1, rhs1)); - fixed_width_column_wrapper lhs2{1, 1}; - fixed_width_column_wrapper rhs2{2}; + cudf::test::fixed_width_column_wrapper lhs2{1, 1}; + cudf::test::fixed_width_column_wrapper rhs2{2}; - EXPECT_FALSE(column_types_equal(lhs2, rhs2)); + EXPECT_FALSE(cudf::have_same_types(lhs2, rhs2)); - fixed_width_column_wrapper lhs3{1, 1}; - fixed_width_column_wrapper rhs3{2}; + cudf::test::fixed_width_column_wrapper lhs3{1, 1}; + cudf::test::fixed_width_column_wrapper rhs3{2}; - EXPECT_FALSE(column_types_equal(lhs3, rhs3)); + EXPECT_FALSE(cudf::have_same_types(lhs3, rhs3)); - fixed_width_column_wrapper lhs4{}; - fixed_width_column_wrapper rhs4{42}; + cudf::test::fixed_width_column_wrapper lhs4{}; + cudf::test::fixed_width_column_wrapper rhs4{42}; - EXPECT_FALSE(column_types_equal(lhs4, rhs4)); + EXPECT_FALSE(cudf::have_same_types(lhs4, rhs4)); // Same rep, different scale - fixed_point_column_wrapper lhs5({10000}, numeric::scale_type{-3}); - fixed_point_column_wrapper rhs5({10000}, numeric::scale_type{0}); + cudf::test::fixed_point_column_wrapper lhs5({10000}, numeric::scale_type{-3}); + cudf::test::fixed_point_column_wrapper rhs5({10000}, numeric::scale_type{0}); - EXPECT_FALSE(column_types_equal(lhs5, rhs5)); - EXPECT_TRUE(column_types_equivalent(lhs5, rhs5)); + EXPECT_FALSE(cudf::have_same_types(lhs5, rhs5)); + EXPECT_TRUE(cudf::column_types_equivalent(lhs5, rhs5)); // Different rep, same scale - fixed_point_column_wrapper lhs6({10000}, numeric::scale_type{-1}); - fixed_point_column_wrapper rhs6({4200}, numeric::scale_type{-1}); + cudf::test::fixed_point_column_wrapper lhs6({10000}, numeric::scale_type{-1}); + cudf::test::fixed_point_column_wrapper rhs6({4200}, numeric::scale_type{-1}); - EXPECT_FALSE(column_types_equal(lhs6, rhs6)); + EXPECT_FALSE(cudf::have_same_types(lhs6, rhs6)); } TEST_F(ColumnTypeCheckTest, DifferentDictionary) { - dictionary_column_wrapper lhs1{1, 1, 1, 2, 2, 3}; - dictionary_column_wrapper rhs1{0, 0, 42, 42}; + cudf::test::dictionary_column_wrapper lhs1{1, 1, 1, 2, 2, 3}; + cudf::test::dictionary_column_wrapper rhs1{0, 0, 42, 42}; - EXPECT_FALSE(column_types_equal(lhs1, rhs1)); + EXPECT_FALSE(cudf::have_same_types(lhs1, rhs1)); - dictionary_column_wrapper lhs2{3.14, 3.14, 5.00}; - dictionary_column_wrapper rhs2{0, 0, 42, 42}; + cudf::test::dictionary_column_wrapper lhs2{3.14, 3.14, 5.00}; + cudf::test::dictionary_column_wrapper rhs2{0, 0, 42, 42}; - EXPECT_FALSE(column_types_equal(lhs2, rhs2)); + EXPECT_FALSE(cudf::have_same_types(lhs2, rhs2)); - dictionary_column_wrapper lhs3{1, 1, 1, 2, 2, 3}; - dictionary_column_wrapper rhs3{8, 8}; + cudf::test::dictionary_column_wrapper lhs3{1, 1, 1, 2, 2, 3}; + cudf::test::dictionary_column_wrapper rhs3{8, 8}; - EXPECT_FALSE(column_types_equal(lhs3, rhs3)); + EXPECT_FALSE(cudf::have_same_types(lhs3, rhs3)); - dictionary_column_wrapper lhs4{1, 1, 2, 3}, rhs4{}; - EXPECT_FALSE(column_types_equal(lhs4, rhs4)); + cudf::test::dictionary_column_wrapper lhs4{1, 1, 2, 3}, rhs4{}; + EXPECT_FALSE(cudf::have_same_types(lhs4, rhs4)); } TEST_F(ColumnTypeCheckTest, DifferentLists) { - using LCW_i = lists_column_wrapper; - using LCW_f = lists_column_wrapper; + using LCW_i = cudf::test::lists_column_wrapper; + using LCW_f = cudf::test::lists_column_wrapper; // Different nested level LCW_i lhs1{LCW_i{1, 1, 2, 3}, LCW_i{}, LCW_i{42, 42}}; LCW_i rhs1{LCW_i{LCW_i{8, 8, 8}, LCW_i{9, 9}}, LCW_i{LCW_i{42, 42}}}; - EXPECT_FALSE(column_types_equal(lhs1, rhs1)); + EXPECT_FALSE(cudf::have_same_types(lhs1, rhs1)); // Different base column type LCW_i lhs2{LCW_i{1, 1, 2, 3}, LCW_i{}, LCW_i{42, 42}}; LCW_f rhs2{LCW_f{9.0, 9.1}, LCW_f{3.14}, LCW_f{}}; - EXPECT_FALSE(column_types_equal(lhs2, rhs2)); + EXPECT_FALSE(cudf::have_same_types(lhs2, rhs2)); } TEST_F(ColumnTypeCheckTest, DifferentStructs) { - fixed_width_column_wrapper lf1{1, 1, 1}; - fixed_width_column_wrapper rf1{2, 2}; + cudf::test::fixed_width_column_wrapper lf1{1, 1, 1}; + cudf::test::fixed_width_column_wrapper rf1{2, 2}; + + cudf::test::structs_column_wrapper lhs1{lf1}; + cudf::test::structs_column_wrapper rhs1{rf1}; - structs_column_wrapper lhs1{lf1}; - structs_column_wrapper rhs1{rf1}; + EXPECT_FALSE(cudf::have_same_types(lhs1, rhs1)); - EXPECT_FALSE(column_types_equal(lhs1, rhs1)); + cudf::test::fixed_width_column_wrapper lf2{1, 1, 1}; + cudf::test::fixed_width_column_wrapper rf2{2, 2}; - fixed_width_column_wrapper lf2{1, 1, 1}; - fixed_width_column_wrapper rf2{2, 2}; + cudf::test::strings_column_wrapper lf3{"a", "b", "c"}; - strings_column_wrapper lf3{"a", "b", "c"}; + cudf::test::structs_column_wrapper lhs2{lf2, lf3}; + cudf::test::structs_column_wrapper rhs2{rf2}; - structs_column_wrapper lhs2{lf2, lf3}; - structs_column_wrapper rhs2{rf2}; + EXPECT_FALSE(cudf::have_same_types(lhs2, rhs2)); +} - EXPECT_FALSE(column_types_equal(lhs2, rhs2)); +TYPED_TEST(ColumnTypeCheckTestTyped, AllTypesEqual) +{ + { + // An empty table + cudf::table_view tbl{}; + EXPECT_TRUE(cudf::all_have_same_types(tbl.begin(), tbl.end())); + } + + { + // A table with one column + cudf::test::fixed_width_column_wrapper col1{1, 2, 3}; + cudf::table_view tbl{{col1}}; + EXPECT_TRUE(cudf::all_have_same_types(tbl.begin(), tbl.end())); + } + + { + // A table with all the same types + cudf::test::fixed_width_column_wrapper col1{1, 2, 3}; + cudf::test::fixed_width_column_wrapper col2{4, 5, 6}; + cudf::test::fixed_width_column_wrapper col3{7, 8, 9}; + cudf::table_view tbl{{col1, col2, col3}}; + EXPECT_TRUE(cudf::all_have_same_types(tbl.begin(), tbl.end())); + } } -} // namespace test -} // namespace cudf +TEST_F(ColumnTypeCheckTest, AllTypesNotEqual) +{ + // A table with different types + cudf::test::fixed_width_column_wrapper col1{1, 2, 3}; + cudf::test::fixed_width_column_wrapper col2{3.14, 1.57, 2.71}; + cudf::table_view tbl{{col1, col2}}; + EXPECT_FALSE(cudf::all_have_same_types(tbl.begin(), tbl.end())); +}