Expose linalg::dot in public API #968

Closes #805

New file: raft/linalg/dot.cuh
@@ -0,0 +1,70 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef __DOT_H
#define __DOT_H

#pragma once

#include <raft/linalg/detail/cublas_wrappers.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>

namespace raft::linalg {

/**
 * @brief Computes the dot product of two vectors.
 * @tparam InputType1 raft::device_mdspan for the first input vector
 * @tparam InputType2 raft::device_mdspan for the second input vector
 * @tparam OutputType Either a host_scalar_view or device_scalar_view for the output
 * @param[in] handle raft::handle_t
 * @param[in] x First input vector
 * @param[in] y Second input vector
 * @param[out] out The output dot product between the x and y vectors
 */
template <typename InputType1,
          typename InputType2,
          typename OutputType,
          typename = raft::enable_if_input_device_mdspan<InputType1>,

Review thread (on the enable_if template parameters):

Reviewer: I brought this up with the axpy as well, but it seems weird to accept a general mdspan for this when what we are really looking for is a 1d vector. Do you see value in accepting a matrix or a dense tensor with 3+ dimensional extents? If not, we should just accept the vector_view directly (which is aliased to be any mdspan with 1d extents). If we accepted a device_vector_view directly, we wouldn't need the enable_if statements at all. I think we should go ahead and do the same for the axpy to keep things consistent.

Author: Agreed - made the changes here so that both axpy and dot take device_vector_views.
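
A minimal sketch of the signature this suggestion implies (the template and parameter names here are assumptions for illustration, not the merged code):

// Sketch only: device_vector_view constrains each input to 1d extents by
// construction, so the enable_if guards below would no longer be needed.
template <typename ElementType, typename IndexType>
void dot(const raft::handle_t& handle,
         raft::device_vector_view<const ElementType, IndexType> x,
         raft::device_vector_view<const ElementType, IndexType> y,
         raft::device_scalar_view<ElementType> out);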

          typename = raft::enable_if_input_device_mdspan<InputType2>,
          typename = raft::enable_if_output_mdspan<OutputType>>
void dot(const raft::handle_t& handle, InputType1 x, InputType2 y, OutputType out)
{
  RAFT_EXPECTS(x.size() == y.size(),
               "Size mismatch between x and y input vectors in raft::linalg::dot");

  // Right now the inputs and outputs all need to have the same value_type (float/double etc).
  // Try to output a meaningful compiler error if mismatched types are passed here.
  // Note: In the future we could remove this restriction using the cublasDotEx function
  // in the cublas wrapper call, instead of the cublasSdot and cublasDdot functions.

Review thread (on the cublasDotEx note):

Reviewer: Should we just go ahead and wrap the cublasEx functions?

Author: I created an issue so we can discuss further: #977. Reading the docs a little closer, it looks like even with cublasDotEx, having different dtypes for the inputs/outputs isn't currently supported: https://docs.nvidia.com/cuda/cublas/index.html#cublas-dotEx - so it won't have much value for the dot API (though I could see a use for it myself with the gemm api w/ implicit and the mixed precision work I was talking about last week).
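
For context, a hedged sketch of what wrapping cublasDotEx directly could look like for float data (illustrative only, not part of this PR; per the cuBLAS docs linked above, the x/y/result dtypes would still have to match):

  // Illustrative sketch, not this PR's code path: bind the stream, then call
  // cublasDotEx with float inputs, a float result, and float execution type.
  RAFT_CUBLAS_TRY(cublasSetStream(handle.get_cublas_handle(), handle.get_stream()));
  RAFT_CUBLAS_TRY(cublasDotEx(handle.get_cublas_handle(),
                              x.size(),
                              x.data_handle(), CUDA_R_32F, x.stride(0),
                              y.data_handle(), CUDA_R_32F, y.stride(0),
                              out.data_handle(), CUDA_R_32F,
                              CUDA_R_32F));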

  static_assert(std::is_same_v<typename InputType1::value_type, typename InputType2::value_type>,
                "Both input vectors need to have the same value_type in raft::linalg::dot call");
  static_assert(
    std::is_same_v<typename InputType1::value_type, typename OutputType::value_type>,
    "Input vectors and output scalar need to have the same value_type in raft::linalg::dot call");

  RAFT_CUBLAS_TRY(detail::cublasdot(handle.get_cublas_handle(),
                                    x.size(),
                                    x.data_handle(),
                                    x.stride(0),
                                    y.data_handle(),
                                    y.stride(0),
                                    out.data_handle(),
                                    handle.get_stream()));
}
}  // namespace raft::linalg
#endif
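
For reference, a minimal usage sketch of the new public API (the setup around the dot call is illustrative, not taken from the PR):

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/linalg/dot.cuh>

#include <rmm/device_uvector.hpp>

void dot_example(const raft::handle_t& handle)
{
  int n = 1024;
  rmm::device_uvector<float> x(n, handle.get_stream());
  rmm::device_uvector<float> y(n, handle.get_stream());
  // ... fill x and y on handle.get_stream() ...

  float result = 0.f;
  raft::linalg::dot(handle,
                    raft::make_device_vector_view<const float>(x.data(), n),
                    raft::make_device_vector_view<const float>(y.data(), n),
                    raft::make_host_scalar_view(&result));
  handle.sync_stream();  // result is valid once the stream has completed
}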

New file: test for raft::linalg::dot
@@ -0,0 +1,152 @@
/*
 * Copyright (c) 2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <raft/linalg/dot.cuh>

#include "../test_utils.h"
#include <gtest/gtest.h>
#include <raft/random/rng.cuh>
#include <raft/util/cuda_utils.cuh>
#include <rmm/device_scalar.hpp>
#include <rmm/device_uvector.hpp>  // added: device_uvector is used below

namespace raft {
namespace linalg {

// Reference dot implementation: each thread accumulates a partial sum over a
// grid-stride loop, then adds it to *out with atomicAdd.
template <typename T>
__global__ void naiveDot(const int n, const T* x, int incx, const T* y, int incy, T* out)
{
  T sum = 0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    sum += x[i * incx] * y[i * incy];
  }
  atomicAdd(out, sum);
}

template <typename InType, typename IndexType = int, typename OutType = InType>
struct DotInputs {
  OutType tolerance;
  IndexType len;
  IndexType incx;
  IndexType incy;
  unsigned long long int seed;
};

template <typename T>
class DotTest : public ::testing::TestWithParam<DotInputs<T>> {
 protected:
  raft::handle_t handle;
  DotInputs<T> params;
  rmm::device_scalar<T> output;
  rmm::device_scalar<T> refoutput;

 public:
  DotTest()
    : testing::TestWithParam<DotInputs<T>>(),
      output(0, handle.get_stream()),
      refoutput(0, handle.get_stream())
  {
    handle.sync_stream();
  }

 protected:
  void SetUp() override
  {
    params = ::testing::TestWithParam<DotInputs<T>>::GetParam();

    cudaStream_t stream = handle.get_stream();

    raft::random::RngState r(params.seed);

    int x_len = params.len * params.incx;
    int y_len = params.len * params.incy;

    rmm::device_uvector<T> x(x_len, stream);
    rmm::device_uvector<T> y(y_len, stream);
    uniform(handle, r, x.data(), x_len, T(-1.0), T(1.0));
    uniform(handle, r, y.data(), y_len, T(-1.0), T(1.0));

    naiveDot<<<256, 256, 0, stream>>>(
      params.len, x.data(), params.incx, y.data(), params.incy, refoutput.data());

    auto out_view = make_device_scalar_view<T, int>(output.data());

    // Choose strided or contiguous vector views depending on the increments.
    if ((params.incx > 1) && (params.incy > 1)) {
      dot(handle,
          make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
          make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
          out_view);
    } else if (params.incx > 1) {
      dot(handle,
          make_strided_device_vector_view<const T>(x.data(), params.len, params.incx),
          make_device_vector_view<const T>(y.data(), params.len),
          out_view);
    } else if (params.incy > 1) {
      dot(handle,
          make_device_vector_view<const T>(x.data(), params.len),
          make_strided_device_vector_view<const T>(y.data(), params.len, params.incy),
          out_view);
    } else {
      dot(handle,
          make_device_vector_view<const T>(x.data(), params.len),
          make_device_vector_view<const T>(y.data(), params.len),
          out_view);
    }
    handle.sync_stream();
  }

  void TearDown() override {}
};

const std::vector<DotInputs<float>> inputsf = {
  {0.0001f, 1024 * 1024, 1, 1, 1234ULL},
  {0.0001f, 16 * 1024 * 1024, 1, 1, 1234ULL},
  {0.0001f, 98689, 1, 1, 1234ULL},
  {0.0001f, 4 * 1024 * 1024, 1, 1, 1234ULL},
  {0.0001f, 1024 * 1024, 4, 1, 1234ULL},
  {0.0001f, 1024 * 1024, 1, 3, 1234ULL},
  {0.0001f, 1024 * 1024, 4, 3, 1234ULL},
};

const std::vector<DotInputs<double>> inputsd = {
  {0.000001, 1024 * 1024, 1, 1, 1234ULL},
  {0.000001, 16 * 1024 * 1024, 1, 1, 1234ULL},
  {0.000001, 98689, 1, 1, 1234ULL},
  {0.000001, 4 * 1024 * 1024, 1, 1, 1234ULL},
  {0.000001, 1024 * 1024, 4, 1, 1234ULL},
  {0.000001, 1024 * 1024, 1, 3, 1234ULL},
  {0.000001, 1024 * 1024, 4, 3, 1234ULL},
};

typedef DotTest<float> DotTestF;
TEST_P(DotTestF, Result)
{
  ASSERT_TRUE(raft::devArrMatch(
    refoutput.data(), output.data(), 1, raft::CompareApprox<float>(params.tolerance)));
}

typedef DotTest<double> DotTestD;
TEST_P(DotTestD, Result)
{
  ASSERT_TRUE(raft::devArrMatch(
    refoutput.data(), output.data(), 1, raft::CompareApprox<double>(params.tolerance)));
}

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestF, ::testing::ValuesIn(inputsf));

INSTANTIATE_TEST_SUITE_P(DotTests, DotTestD, ::testing::ValuesIn(inputsd));

}  // end namespace linalg
}  // end namespace raft

Review thread (on make_strided_device_vector_view):

Reviewer: Rather than adding another factory function for a strided vector, why not just allow a strided layout to be configured in make_device_vector_view and make_host_vector_view? Right now the make_*_vector_view factories automatically configure a row-major layout, but the layout should really be configurable (and potentially strided, or col-major if desired).

Author: I've updated make_device_vector_view to allow strided input here - let me know what you think.
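
A rough sketch of the direction agreed on above (the layout helper and template arguments here are assumptions for illustration, not necessarily the merged API):

// Hypothetical sketch: make_device_vector_view with a configurable layout
// policy, replacing the separate make_strided_device_vector_view factory.
auto x_view = raft::make_device_vector_view<const float, int, raft::layout_stride>(
  x.data(), raft::make_vector_strided_layout(n, incx));  // n elements, stride incx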