Skip to content

Commit

Permalink
Migrate correlation kernel to DPC++ (#1435)
Browse files Browse the repository at this point in the history
  • Loading branch information
rlnx authored Feb 19, 2021
1 parent 263dd64 commit 6e5376c
Show file tree
Hide file tree
Showing 29 changed files with 1,264 additions and 249 deletions.
3 changes: 2 additions & 1 deletion .bazelrc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ build --flag_alias=cpu=@config//:cpu

# Always pass this env variable to test rules, because SYCL
# OpenCL backend uses it to determine available devices
test --test_env=OCL_ICD_FILENAMES
test --test_env=OCL_ICD_FILENAMES \
--test_env=DAAL_DATASETS

# Configuration: 'host'
# Build & run all host tests
Expand Down
19 changes: 6 additions & 13 deletions cpp/oneapi/dal/algo/pca/test/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,14 @@ class pca_batch_test : public te::algo_fixture {

void check_eigenvalues_order(const table& eigenvalues) const {
const auto W = la::matrix<double>::wrap(eigenvalues);
bool is_descinding = true;
bool is_descending = true;
la::enumerate_linear(W, [&](std::int64_t i, double) {
if (i > 0) {
CAPTURE(i, W.get(i - 1), W.get(i));
is_descinding = is_descinding && (W.get(i - 1) >= W.get(i));
is_descending = is_descending && (W.get(i - 1) >= W.get(i));
}
});
CHECK(is_descinding);
CHECK(is_descending);
}

void check_eigenvectors_orthogonality(const table& eigenvectors) {
Expand Down Expand Up @@ -217,19 +217,12 @@ TEMPLATE_LIST_TEST_M(pca_batch_test,
pca_types) {
SKIP_IF(this->not_available_on_device());

const std::string higgs = "higgs/dataset/higgs_100t_train.csv";

const te::dataframe data = GENERATE_DATAFRAME(te::dataframe_builder{ higgs });
const std::int64_t component_count = 0;
const te::dataframe data =
GENERATE_DATAFRAME(te::dataframe_builder{ "workloads/higgs/dataset/higgs_100t_train.csv" });

// Homogen floating point type is the same as algorithm's floating point type
const auto data_table_id = this->get_homogen_table_id();

const std::int64_t component_count = GENERATE_COPY(0,
1,
data.get_column_count(),
data.get_column_count() - 1,
data.get_column_count() / 2);

this->general_checks(data, component_count, data_table_id);
}

Expand Down
37 changes: 37 additions & 0 deletions cpp/oneapi/dal/backend/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,45 @@

namespace oneapi::dal::backend {

/// Finds the smallest multiple of `multiple` not smaller than `x`
/// Return `x`, if `x` is already multiple of `multiple`
/// Example: down_multiple(10, 4) == 8
/// Example: down_multiple(10, 5) == 10
template <typename Integer>
inline constexpr Integer down_multiple(Integer x, Integer multiple) {
static_assert(std::is_integral_v<Integer>);
ONEDAL_ASSERT(x > 0);
ONEDAL_ASSERT(multiple > 0);
return (x / multiple) * multiple;
}

/// Finds the smallest multiple of `multiple` larger than `x`.
/// Return `x`, if `x` is already multiple of `multiple`
/// Example: up_multiple(10, 4) == 12
/// Example: up_multiple(10, 5) == 10
template <typename Integer>
inline constexpr Integer up_multiple(Integer x, Integer multiple) {
static_assert(std::is_integral_v<Integer>);
ONEDAL_ASSERT(x > 0);
ONEDAL_ASSERT(multiple > 0);
const Integer y = down_multiple<Integer>(x, multiple);
const Integer z = multiple * Integer((x % multiple) != 0);
ONEDAL_ASSERT_SUM_OVERFLOW(Integer, y, z);
return y + z;
}

#ifdef ONEDAL_DATA_PARALLEL

using event_vector = std::vector<sycl::event>;

/// Creates `nd_range`, where global size is multiple of local size
inline sycl::nd_range<1> make_multiple_nd_range_1d(std::int64_t global_size,
std::int64_t local_size) {
const std::int64_t g = dal::detail::integral_cast<std::size_t>(global_size);
const std::int64_t l = dal::detail::integral_cast<std::size_t>(local_size);
return { up_multiple(g, l), l };
}

#endif

} // namespace oneapi::dal::backend
2 changes: 2 additions & 0 deletions cpp/oneapi/dal/backend/primitives/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dal_module(
auto = True,
dal_deps = [
"@onedal//cpp/oneapi/dal:common",
"@onedal//cpp/oneapi/dal:table",
],
)

Expand Down Expand Up @@ -44,6 +45,7 @@ dal_collect_test_suites(
modules = [
"blas",
"reduction",
"stat",
],
tests = [
":common_tests",
Expand Down
6 changes: 3 additions & 3 deletions cpp/oneapi/dal/backend/primitives/blas/gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ template <typename Float, ndorder ao, ndorder bo, ndorder co>
sycl::event gemm(sycl::queue& queue,
const ndview<Float, 2, ao>& a,
const ndview<Float, 2, bo>& b,
const ndview<Float, 2, co>& c,
ndview<Float, 2, co>& c,
Float alpha = Float(1),
Float beta = Float(0),
const event_vector& deps = {});
Expand All @@ -35,9 +35,9 @@ template <typename Float, ndorder ao, ndorder bo, ndorder co>
inline sycl::event gemm(sycl::queue& queue,
const ndview<Float, 2, ao>& a,
const ndview<Float, 2, bo>& b,
const ndview<Float, 2, co>& c,
ndview<Float, 2, co>& c,
const event_vector& deps = {}) {
return gemm(queue, a, b, c, Float(1), Float(0), deps);
return gemm<Float>(queue, a, b, c, Float(1), Float(0), deps);
}

#endif
Expand Down
9 changes: 5 additions & 4 deletions cpp/oneapi/dal/backend/primitives/blas/gemm_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,14 @@ template <typename Float, ndorder ao, ndorder bo, ndorder co>
sycl::event gemm(sycl::queue& queue,
const ndview<Float, 2, ao>& a,
const ndview<Float, 2, bo>& b,
const ndview<Float, 2, co>& c,
ndview<Float, 2, co>& c,
Float alpha,
Float beta,
const event_vector& deps) {
ONEDAL_ASSERT(a.get_dimension(0) == c.get_dimension(0));
ONEDAL_ASSERT(a.get_dimension(1) == b.get_dimension(0));
ONEDAL_ASSERT(b.get_dimension(1) == c.get_dimension(1));
ONEDAL_ASSERT(c.has_mutable_data());

constexpr bool is_c_trans = (co == ndorder::c);
if constexpr (is_c_trans) {
Expand All @@ -55,7 +56,7 @@ sycl::event gemm(sycl::queue& queue,
a.get_data(),
a.get_leading_stride(),
beta,
c.get_data(),
c.get_mutable_data(),
c.get_leading_stride(),
deps);
}
Expand All @@ -72,7 +73,7 @@ sycl::event gemm(sycl::queue& queue,
b.get_data(),
b.get_leading_stride(),
beta,
c.get_data(),
c.get_mutable_data(),
c.get_leading_stride(),
deps);
}
Expand All @@ -82,7 +83,7 @@ sycl::event gemm(sycl::queue& queue,
template ONEDAL_EXPORT sycl::event gemm<F, ao, bo, co>(sycl::queue & queue, \
const ndview<F, 2, ao>& a, \
const ndview<F, 2, bo>& b, \
const ndview<F, 2, co>& c, \
ndview<F, 2, co>& c, \
F alpha, \
F beta, \
const event_vector& deps);
Expand Down
2 changes: 1 addition & 1 deletion cpp/oneapi/dal/backend/primitives/blas/test/gemm_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class gemm_test : public te::policy_fixture {
check_if_initialized();
REQUIRE(mat.get_shape() == ndshape<2>{ m_, n_ });

float_t* mat_ptr = mat.get_data();
const float_t* mat_ptr = mat.get_data();
for (std::int64_t i = 0; i < mat.get_count(); i++) {
if (std::int64_t(mat_ptr[i]) != k_) {
CAPTURE(i, mat_ptr[i]);
Expand Down
149 changes: 149 additions & 0 deletions cpp/oneapi/dal/backend/primitives/loops.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*******************************************************************************
* Copyright 2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#pragma once

#include "oneapi/dal/table/common.hpp"
#include "oneapi/dal/table/row_accessor.hpp"
#include "oneapi/dal/backend/primitives/ndarray.hpp"

namespace oneapi::dal::backend::primitives {

class row_block_info {
public:
row_block_info() : block_index_(0), row_start_index_(0), row_count_(0), column_count_(0) {}

std::int64_t get_start_row_index() const {
return row_start_index_;
}

std::int64_t get_end_row_index() const {
return row_start_index_ + row_count_;
}

range get_row_range() const {
return { get_start_row_index(), get_end_row_index() };
}

std::int64_t get_row_count() const {
return row_count_;
}

std::int64_t get_column_count() const {
return column_count_;
}

ndshape<2> get_shape() const {
return { row_count_, column_count_ };
}

std::int64_t get_block_index() const {
return block_index_;
}

const row_block_info& update(std::int64_t block_index,
std::int64_t row_start_index,
std::int64_t row_count,
std::int64_t column_count) {
block_index_ = block_index;
row_start_index_ = row_start_index;
row_count_ = row_count;
column_count_ = column_count;
return *this;
}

private:
std::int64_t block_index_;
std::int64_t row_start_index_;
std::int64_t row_count_;
std::int64_t column_count_;
};

/// Helper function that simplifies looping over the blocked data.
/// See detailed description below.
template <typename T, typename Body>
inline void for_each_block(const ndview<T, 2>& data,
std::int64_t block_max_row_count,
Body&& body) {
ONEDAL_ASSERT(data.has_data());
ONEDAL_ASSERT(block_max_row_count > 0);

for_each_block(data.get_shape(0),
data.get_shape(1),
block_max_row_count,
std::forward<Body>(body));
}

/// Helper function that simplifies looping over the blocked data.
/// See detailed description below.
template <typename Body>
inline void for_each_block(std::int64_t row_count,
std::int64_t column_count,
std::int64_t block_max_row_count,
Body&& body) {
ONEDAL_ASSERT(row_count > 0);
ONEDAL_ASSERT(column_count > 0);
ONEDAL_ASSERT(block_max_row_count > 0);

const std::int64_t block_count = row_count / block_max_row_count;
const std::int64_t tail_block_row_count = row_count % block_max_row_count;

row_block_info info;

for (std::int64_t i = 0; i < block_count; i++) {
body(info.update(i, i * block_max_row_count, block_max_row_count, column_count));
}

if (tail_block_row_count > 0) {
const std::int64_t i = block_count;
body(info.update(i, i * block_max_row_count, tail_block_row_count, column_count));
}
}

/// Helper function that simplifies looping over the blocked data
///
/// Example of recommended usage:
/// @code
/// array<T> block_flat;
/// const auto acc = row_accessor<const T>{ x };
/// const std::int64_t block_row_count = 2048;
///
/// for_each_block(x, block_row_count, [&](const row_block_info& bi) mutable {
/// const T* block_ptr = acc.pull(queue, block_flat, bi.get_range());
/// const auto block = ndview<T, 2>::wrap(block_ptr, bi.get_shape());
/// });
/// @endcode
///
/// @tparam Body The user's block handler, must be a functor that accepts `row_block_info`
///
/// @param data The data needs to be blocked
/// @param block_max_row_count The maximal row count in each block. `body` is not
/// guarantied to be called with the provided `block_max_row_count`.
/// The "tail" block (if data row count is not mutiple of
/// `block_max_row_count`) always contains less rows.
/// @param body The user-provided lambda
template <typename Body>
inline void for_each_block(const table& data, std::int64_t block_max_row_count, Body&& body) {
ONEDAL_ASSERT(data.has_data());
ONEDAL_ASSERT(block_max_row_count > 0);

for_each_block(data.get_row_count(),
data.get_column_count(),
block_max_row_count,
std::forward<Body>(body));
}

} // namespace oneapi::dal::backend::primitives
Loading

0 comments on commit 6e5376c

Please sign in to comment.