Skip to content

Commit

Permalink
All deterministic tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
anbenali committed Mar 27, 2022
1 parent 831f8e8 commit 173d22e
Show file tree
Hide file tree
Showing 5 changed files with 372 additions and 31 deletions.
163 changes: 162 additions & 1 deletion src/Platforms/OMPTarget/ompBLAS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,23 @@ ompBLAS_status gemv_batched_impl(ompBLAS_handle& handle,
}
else
{
throw std::runtime_error("trans = 'N' not implemented in gemv_batched_impl!");
if (incx !=1 || incy != 1)
throw std::runtime_error("incx !=1 or incy != 1 are not implemented in ompBLAS::gemv_batched_impl!");

PRAGMA_OFFLOAD("omp target teams distribute collapse(2) num_teams(batch_count * n) is_device_ptr(A, x, y, alpha, beta)")
for(size_t ib = 0; ib < batch_count; ib++)
for(size_t i = 0; i < n; i++)
{
T dot_sum(0);
PRAGMA_OFFLOAD("omp parallel for simd reduction(+: dot_sum)")
for(size_t j = 0; j < m; j++)
dot_sum += x[ib][j] * A[ib][j * lda + i];
if (beta[ib] == T(0))
y[ib][i] = alpha[ib] * dot_sum; // protecting NaN from y
else
y[ib][i] = alpha[ib] * dot_sum + beta[ib] * y[ib][i];
}
return 0;
}
}

Expand Down Expand Up @@ -411,5 +427,150 @@ ompBLAS_status ger_batched(ompBLAS_handle& handle,
return ger_batched_impl(handle, m, n, alpha, x, incx, y, incy, A, lda, batch_count);
}
#endif


template<typename T>
ompBLAS_status copy_batched_impl(ompBLAS_handle& handle,
const int n,
const T* const x[],
const int incx,
T* const y[],
const int incy,
const int batch_count)
{
if (batch_count == 0) return 0;

//if (incx !=1 || incy != 1)
// throw std::runtime_error("incx !=1 or incy != 1 are not implemented in ompBLAS::copy_batched_impl!");

PRAGMA_OFFLOAD("omp target teams distribute parallel for collapse(2) is_device_ptr(x, y)")
for (size_t ib = 0; ib < batch_count; ib++)
for (size_t i = 0; i < n; i++)
y[ib][i * incy] = x[ib][i * incx];
return 0;
}

ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const float* const x[],
const int incx,
float* const y[],
const int incy,
const int batch_count)
{
return copy_batched_impl(handle, n, x, incx, y, incy, batch_count);
}

ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const double* const x[],
const int incx,
double* const y[],
const int incy,
const int batch_count)
{
return copy_batched_impl(handle, n, x, incx, y, incy, batch_count);
}

#if !defined(OPENMP_NO_COMPLEX)
ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const std::complex<float>* const x[],
const int incx,
std::complex<float>* const y[],
const int incy,
const int batch_count)
{
return copy_batched_impl(handle, n, x, incx, y, incy, batch_count);
}

ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const std::complex<double>* const x[],
const int incx,
std::complex<double>* const y[],
const int incy,
const int batch_count)
{
return copy_batched_impl(handle, n, x, incx, y, incy, batch_count);
}
#endif

template<typename T>
ompBLAS_status copy_batched_offset_impl(ompBLAS_handle& handle,
const int n,
const T* const x[],
const int x_offset,
const int incx,
T* const y[],
const int y_offset,
const int incy,
const int batch_count)
{
if (batch_count == 0) return 0;

//if (incx !=1 || incy != 1)
// throw std::runtime_error("incx !=1 or incy != 1 are not implemented in ompBLAS::copy_batched_impl!");

PRAGMA_OFFLOAD("omp target teams distribute parallel for collapse(2) is_device_ptr(x, y)")
for (size_t ib = 0; ib < batch_count; ib++)
for (size_t i = 0; i < n; i++)
y[ib][y_offset + i * incy] = x[ib][x_offset + i * incx];
return 0;
}

ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const float* const x[],
const int x_offset,
const int incx,
float* const y[],
const int y_offset,
const int incy,
const int batch_count)
{
return copy_batched_offset_impl(handle, n, x, x_offset, incx, y, y_offset, incy, batch_count);
}

ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const double* const x[],
const int x_offset,
const int incx,
double* const y[],
const int y_offset,
const int incy,
const int batch_count)
{
return copy_batched_offset_impl(handle, n, x, x_offset, incx, y, y_offset, incy, batch_count);
}

#if !defined(OPENMP_NO_COMPLEX)
ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const std::complex<float>* const x[],
const int x_offset,
const int incx,
std::complex<float>* const y[],
const int y_offset,
const int incy,
const int batch_count)
{
return copy_batched_offset_impl(handle, n, x, x_offset, incx, y, y_offset, incy, batch_count);
}

ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const std::complex<double>* const x[],
const int x_offset,
const int incx,
std::complex<double>* const y[],
const int y_offset,
const int incy,
const int batch_count)
{
return copy_batched_offset_impl(handle, n, x, x_offset, incx, y, y_offset, incy, batch_count);
}
#endif
} // namespace ompBLAS
} // namespace qmcplusplus
72 changes: 72 additions & 0 deletions src/Platforms/OMPTarget/ompBLAS.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,78 @@ ompBLAS_status ger_batched(ompBLAS_handle& handle,
const int lda,
const int batch_count);

ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const float* const x[],
const int incx,
float* const y[],
const int incy,
const int batch_count);

ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const double* const x[],
const int incx,
double* const y[],
const int incy,
const int batch_count);

ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const std::complex<float>* const x[],
const int incx,
std::complex<float>* const y[],
const int incy,
const int batch_count);

ompBLAS_status copy_batched(ompBLAS_handle& handle,
const int n,
const std::complex<double>* const x[],
const int incx,
std::complex<double>* const y[],
const int incy,
const int batch_count);

ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const float* const x[],
const int x_offset,
const int incx,
float* const y[],
const int y_offset,
const int incy,
const int batch_count);

ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const double* const x[],
const int x_offset,
const int incx,
double* const y[],
const int y_offset,
const int incy,
const int batch_count);

ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const std::complex<float>* const x[],
const int x_offset,
const int incx,
std::complex<float>* const y[],
const int y_offset,
const int incy,
const int batch_count);

ompBLAS_status copy_batched_offset(ompBLAS_handle& handle,
const int n,
const std::complex<double>* const x[],
const int x_offset,
const int incx,
std::complex<double>* const y[],
const int y_offset,
const int incy,
const int batch_count);

} // namespace ompBLAS

} // namespace qmcplusplus
Expand Down
Loading

0 comments on commit 173d22e

Please sign in to comment.