Skip to content

Commit

Permalink
Fix logloss hessian product primitive private CI error (#2429)
Browse files Browse the repository at this point in the history
* Initial commit

* Add regularization

* Fix compute_without_fit_intercept function

* Add deps to mkl gemv

* Minor
  • Loading branch information
avolkov-intel authored Jun 29, 2023
1 parent 62d0e66 commit cfbb628
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
6 changes: 4 additions & 2 deletions cpp/oneapi/dal/backend/primitives/blas/gemv_dpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ sycl::event gemv(sycl::queue& queue,
std::int64_t(1),
beta,
y.get_mutable_data(),
std::int64_t(1));
std::int64_t(1),
deps);
}
else {
ONEDAL_ASSERT(lda >= m);
Expand All @@ -71,7 +72,8 @@ sycl::event gemv(sycl::queue& queue,
std::int64_t(1),
beta,
y.get_mutable_data(),
std::int64_t(1));
std::int64_t(1),
deps);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -523,14 +523,13 @@ sycl::event logloss_hessian_product<Float>::compute_with_fit_intercept(const ndv
auto out_bias = out.get_slice(0, 1);
auto vec_suf = vec.get_slice(1, p_ + 1);

sycl::event fill_out_event = copy(q_, out_suf, vec_suf, deps);
sycl::event fill_out0_event = fill<Float>(q_, out_bias, Float(0), deps);
sycl::event fill_out_event = fill<Float>(q_, out, Float(0), deps);

Float v0 = vec.at_device(q_, 0, deps);
sycl::event event_xv = gemv(q_, data_, vec_suf, buffer_, Float(1), v0, { fill_buffer_event });

sycl::event event_dxv = q_.submit([&](sycl::handler& cgh) {
cgh.depends_on({ event_xv, fill_out_event, fill_out0_event });
cgh.depends_on({ event_xv, fill_out_event });
const auto range = make_range_1d(n_);
auto sum_reduction = sycl::reduction(out_ptr, sycl::plus<>());
cgh.parallel_for(range, sum_reduction, [=](sycl::id<1> idx, auto& sum_v0) {
Expand All @@ -539,8 +538,17 @@ sycl::event logloss_hessian_product<Float>::compute_with_fit_intercept(const ndv
});
});
auto event_xtdxv =
gemv(q_, data_.t(), buffer_, out_suf, Float(1), L2_ * 2, { event_dxv, fill_out_event });
return event_xtdxv;
gemv(q_, data_.t(), buffer_, out_suf, Float(1), Float(0), { event_dxv, fill_out_event });

const Float regularization_factor = L2_ * 2;

const auto kernel_regularization = [=](const Float a, const Float param) {
return a + param * regularization_factor;
};

auto add_regularization_event =
element_wise(q_, kernel_regularization, out_suf, vec_suf, out_suf, { event_xtdxv });
return add_regularization_event;
}

template <typename Float>
Expand All @@ -551,7 +559,7 @@ sycl::event logloss_hessian_product<Float>::compute_without_fit_intercept(
ONEDAL_ASSERT(vec.get_dimension(0) == p_);
ONEDAL_ASSERT(out.get_dimension(0) == p_);

sycl::event fill_out_event = copy(q_, out, vec, deps);
sycl::event fill_out_event = fill<Float>(q_, out, Float(0), deps);

auto event_xv = gemv(q_, data_, vec, buffer_, Float(1), Float(0), deps);

Expand All @@ -562,8 +570,18 @@ sycl::event logloss_hessian_product<Float>::compute_without_fit_intercept(
element_wise(q_, kernel_mul, buf_ndview, hess_ndview, buf_ndview, { event_xv });

auto event_xtdxv =
gemv(q_, data_.t(), buffer_, out, Float(1), L2_ * 2, { event_dxv, fill_out_event });
return event_xtdxv;
gemv(q_, data_.t(), buffer_, out, Float(1), Float(0), { event_dxv, fill_out_event });

const Float regularization_factor = L2_ * 2;

const auto kernel_regularization = [=](const Float a, const Float param) {
return a + param * regularization_factor;
};

auto add_regularization_event =
element_wise(q_, kernel_regularization, out, vec, out, { event_xtdxv });

return add_regularization_event;
}

template <typename Float>
Expand Down

0 comments on commit cfbb628

Please sign in to comment.