diff --git a/include/interface/blas3_interface.hpp b/include/interface/blas3_interface.hpp
index c03668618..8aec2e5b7 100644
--- a/include/interface/blas3_interface.hpp
+++ b/include/interface/blas3_interface.hpp
@@ -67,19 +67,19 @@ typename Executor::Return_Type _select_gemm(
         T(_alpha), T(_beta));                                \
     ret = ex.gemm_executor(gemm);                            \
   } else {                                                   \
-    auto gemm = make_gemm_no_local_mem(                      \
+    auto gemm = make_gemm_no_local_mem(                      \
         buffer_a, buffer_b, buffer_c, T(_alpha), T(_beta));  \
     ret = ex.gemm_executor(gemm);                            \
   }                                                          \
   return ret;                                                \
 }
 #else
-#define ENABLE_GEMM_TRANSPOSE(_trans_a, _trans_b)            \
-  if (_TransA == _trans_a && _TransB == _trans_b) {          \
-    auto gemm = make_gemm_no_local_mem(                      \
-        buffer_a, buffer_b, buffer_c, T(_alpha), T(_beta));  \
-    ret = ex.gemm_executor(gemm);                            \
-    return ret;                                              \
+#define ENABLE_GEMM_TRANSPOSE(_trans_a, _trans_b)            \
+  if (_TransA == _trans_a && _TransB == _trans_b) {          \
+    auto gemm = make_gemm_reference(                         \
+        buffer_a, buffer_b, buffer_c, T(_alpha), T(_beta));  \
+    ret = ex.gemm_executor(gemm);                            \
+    return ret;                                              \
   }
 #endif
 const bool NoTrans = false;
@@ -92,7 +92,7 @@ typename Executor::Return_Type _select_gemm(
 #undef ENABLE_GEMM_TRANSPOSE
 
   return ret;
-}
+}  // namespace blas
 
 /*!
  * @brief This is a top-level wrapper for GemmFactory, which provides a
@@ -171,7 +171,7 @@ cl::sycl::event _gemm(Executor& ex, char _TransA, char _TransB, IndexType _M,
 #elif defined(INTEL_GPU)
   BIND_DATA_SIZE(1024, 4096, 1024) TO_TPARAMS(128, false, 64, 4, 4, 16, 16);
   BIND_DATA_SIZE(10, 1024, 1024) TO_TPARAMS(128, false, 64, 2, 2, 8, 8);
-  BIND_DEFAULT TO_TPARAMS(128, false, 64, 8, 8, 8, 8);
+  BIND_DEFAULT TO_TPARAMS(128, false, 64, 8, 8, 16, 16);
 #elif defined(RCAR)
   if (_M < 512 && _N < 512) {
     BIND_DEFAULT TO_TPARAMS(32, false, 128, 4, 8, 8, 4);
diff --git a/include/operations/blas3_trees.hpp b/include/operations/blas3_trees.hpp
index 396d12077..f6137e3cb 100644
--- a/include/operations/blas3_trees.hpp
+++ b/include/operations/blas3_trees.hpp
@@ -49,7 +49,7 @@ ENABLE_TYPE_STRING(double)
 #undef ENABLE_TYPE_STRING
 
 /*!
- * @brief This factory generates reference gemm implementations.
+ * @brief This factory generates reference GEMM implementations.
  *
  * These implementations use a naive approach of mapping one value of the
  * output matrix to each work item, and are highly memory bound.
@@ -144,6 +144,7 @@ class ReferenceGemmFactory {
     C[0] = alpha * reg_res + beta * C[0];
   }
 
+
   void bind(cl::sycl::handler &h) {
     _A.bind(h);
     _B.bind(h);
@@ -169,6 +170,322 @@ inline bool do_check(bool) { return true; }
 
+/*!
+ * @brief NoLocalGemmFactory is a template class whose instantiations provide
+ * different implementations of the GEMM kernel for devices where there is no
+ * local memory available.
+ *
+ * To use the function, each item of a kernel launched with an nd_range given
+ * by NoLocalGemmFactory::get_nd_range() should call eval().
+ *
+ * @tparam RHS0 the view type of the input matrices A and B
+ * @tparam RHS1 the view type of the output matrix C
+ * @tparam ClSize the size of the cache line of the architecture
+ *         This parameter has been reserved for further optimisation
+ *         (If the value passed in is smaller than the actual cache
+ *         line size, some values fetched will be wasted, which can
+ *         significantly reduce performance. It can be set to a
+ *         multiple of the physical cache line size. In this case, it
+ *         will significantly increase local memory usage, but
+ *         will result in fewer local barriers.)
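+ *         (Worked example, illustrative only: with ClSize = 64 and
+ *         T = float, cl_elems below is 64 / sizeof(float) = 16, i.e.
+ *         sixteen elements share a single cache line fetch.)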
+ * @tparam tile_type determines the size of the local, work group, and top
+ *         level tiles to use, see Tile
+ * @tparam TransA iff true, matrix A will be transposed on the fly
+ * @tparam TransB iff true, matrix B will be transposed on the fly
+ * @tparam T type of matrix elements
+ */
+template <typename RHS0, typename RHS1, int ClSize, typename tile_type,
+          bool TransA, bool TransB, typename T>
+class NoLocalGemmFactory {
+ public:
+  using value_type = T;
+  using IndexType = typename RHS0::IndexType;
+  static constexpr int version = 3;
+  static constexpr int scratch_size = 0;
+
+  /*! @brief The number of rows processed by each work item */
+  static constexpr IndexType item_rows = tile_type::item_rows;
+  /*! @brief The number of cols processed by each work item */
+  static constexpr IndexType item_cols = tile_type::item_cols;
+  /*! @brief The number of work items in each row of a work group */
+  static constexpr IndexType wg_rows = tile_type::wg_rows;
+  /*! @brief The number of work items in each column of a work group */
+  static constexpr IndexType wg_cols = tile_type::wg_cols;
+  /*! @brief Number of rows within a work-group level tile */
+  static constexpr IndexType block_rows = wg_rows * item_rows;
+  /*! @brief Number of columns within a work-group level tile */
+  static constexpr IndexType block_cols = wg_cols * item_cols;
+  /*! @brief The size of the tile processed by a work group */
+  static constexpr IndexType tile_size = block_rows * block_cols;
+  /*! @brief A boolean parameter representing whether or not matrix A is
+   * transposed */
+  static constexpr bool trans_a = TransA;
+  /*! @brief A boolean parameter representing whether or not matrix B is
+   * transposed */
+  static constexpr bool trans_b = TransB;
+  /*! @brief The device cache line size */
+  static constexpr IndexType cl_size = ClSize;
+  /*! @brief Number of elements which fit within a cache line */
+  static constexpr IndexType cl_elems = cl_size / sizeof(T);
+  /*! @brief Number of work items within a work group */
+  static constexpr IndexType wg_size = wg_rows * wg_cols;
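+
+  /*
+   * Worked example (illustrative values only): with wg_rows = wg_cols = 8
+   * and item_rows = item_cols = 4, a work group contains wg_size = 64 work
+   * items and computes one block_rows x block_cols = 32 x 32 tile of C,
+   * each work item accumulating a 4 x 4 sub-block in registers. These
+   * values also satisfy the static_assert below (8 * 4 == 4 * 8).
+   */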
+
+  static_assert(wg_cols * item_cols == item_rows * wg_rows,
+                "Work group size should be a multiple "
+                "of the number of rows in a block\n"
+                " --- this is ensured iff: item_rows | wg_cols");
+
+  RHS0 _A;
+  RHS0 _B;
+  RHS1 _C;
+  T alpha;
+  T beta;
+  IndexType m;
+  IndexType n;
+  IndexType k;
+  IndexType lda;
+  IndexType ldb;
+  IndexType ldc;
+
+  inline NoLocalGemmFactory(RHS0 A, RHS0 B, RHS1 C, T alpha, T beta)
+      : _A(A),
+        _B(B),
+        _C(C),
+        alpha(alpha),
+        beta(beta),
+        m(_A.getSizeR()),
+        n(_B.getSizeC()),
+        k(_A.getSizeC()),
+        lda(_A.getSizeL()),
+        ldb(_B.getSizeL()),
+        ldc(_C.getSizeL()) {}
+
+  static inline std::string get_type_string() noexcept {
+    return std::string("NoLocalGemmFactory<") + std::to_string(wg_size) +
+           ", " + type_string<value_type>::get_value() + ">";
+  }
+
+  static inline cl::sycl::nd_range<1> get_nd_range(IndexType m,
+                                                   IndexType n) noexcept {
+    const cl::sycl::range<1> nwg(((m - 1) / (item_rows * wg_rows) + 1) *
+                                 ((n - 1) / (item_cols * wg_cols) + 1));
+    const cl::sycl::range<1> wgs(wg_size);
+
+    return cl::sycl::nd_range<1>(nwg * wgs, wgs);
+  }
+
+  inline IndexType getSize() const { return m * n; }
+
+  inline bool valid_thread(cl::sycl::nd_item<1> ndItem) const {
+    return ((ndItem.get_global_id(0) < getSize()));
+  }
+
+  inline void eval(cl::sycl::nd_item<1> id) noexcept {
+    auto A = _A.getData().get_pointer().get();
+    auto B = _B.getData().get_pointer().get();
+    auto C = _C.getData().get_pointer().get();
+    const auto number_of_block_per_row = ((m - 1) / block_rows) + 1;
+
+    /* linear work group id */
+    const auto wg_id = id.get_group(0);
+    /* linear work item id */
+    const auto item_id = id.get_local_id(0);
+    /* row tile id per work group */
+    const auto tile_id_row = wg_id % number_of_block_per_row;
+    /* column tile id per work group */
+    const auto tile_id_col = wg_id / number_of_block_per_row;
+    /* work item id per row */
+    const auto local_item_id_row = item_id % wg_rows;
+    /* work item id per column */
+    const auto local_item_id_col = item_id / wg_rows;
+    /* the start position of the tile-row per work group */
+    const auto wg_row = tile_id_row * block_rows;
+    /* the start position of the tile-column per work group */
+    const auto wg_col = tile_id_col * block_cols;
+    /* 2D register array used to store the result C */
+    value_type reg_res[item_rows][item_cols] = {};
+    /* temporary register array used to prefetch columns of A */
+    value_type reg_a[item_rows];
+    /* temporary register used to prefetch elements of B */
+    value_type reg_b[item_cols];
+
+    /* exit any work item that falls outside of the m and n boundaries */
+    if ((local_item_id_row + wg_row >= m) ||
+        (local_item_id_col + wg_col >= n)) {
+      return;
+    }
+    /*
+     * dim_m_a_start and dim_n_b_start are used to adjust the start position
+     * of each work item for the A, B and C matrices.
+     */
+    int dim_m_a_start = (local_item_id_row + wg_row);
+    int dim_n_b_start = (local_item_id_col + wg_col);
+
+    /*! @brief Adjusting the start position of A, B, and C */
+    A = A + dim_m_a_start * (trans_a ? lda : 1);
+    B = B + dim_n_b_start * (trans_b ? 1 : ldb);
+    C = C + dim_m_a_start + (dim_n_b_start * ldc);
+
+    /*!
+     * @brief is_internal_block_m and is_internal_block_n are used to
+     * distinguish internal blocks. Work items processing these blocks
+     * don't need to check for boundaries.
+     */
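+    /*
+     * Example (illustrative): with m = 100 and block_rows = 32, the work
+     * groups covering rows [0, 31], [32, 63] and [64, 95] are internal in
+     * the m dimension, while the group covering rows [96, 99] must check
+     * every access against the matrix boundary.
+     */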
+ */ + const bool is_internal_block_m = (m - wg_row >= block_rows); + const bool is_internal_block_n = (n - wg_col >= block_cols); + + /* + * The following lambdas: boundary_check_m, boundary_check_n, and + * boundary_check_c are used to check the A, B , and C boundaries + * respectively. + */ + auto boundary_check_m = [&](int dim_m_a_start) { + return dim_m_a_start < m; + }; + auto boundary_check_n = [&](int dim_n_b_start) { + return dim_n_b_start < n; + }; + auto boundary_check_c = [&](int dim_m_c_start, int dim_n_c_start) { + return (dim_m_c_start < m && dim_n_c_start < n); + }; + + /* + * computing the gemm block + */ + while (k > 0) { + /* + * Loading a corresponding block of matrix A into reg_a + */ + (is_internal_block_m) + ? load(A, reg_a, (trans_a ? lda : 1), + dim_m_a_start, boundary_check_m) + : load(A, reg_a, (trans_a ? lda : 1), + dim_m_a_start, boundary_check_m); + /* + * Loading a corresponding block of matrix B into reg_b + */ + (is_internal_block_n) + ? load(B, reg_b, (trans_b ? 1 : ldb), + dim_n_b_start, boundary_check_n) + : load(B, reg_b, (trans_b ? 1 : ldb), + dim_n_b_start, boundary_check_n); + + /* + * Computing a the partial GEMM for the loaded block of reg_a andd + * reg_b and adding the result into reg_res + */ + compute_block_gemm_no_shared(reg_a, reg_b, reg_res); + /* + * Moving forward to the next block + */ + --k; + A = A + (trans_a ? 1 : lda); + B = B + (trans_b ? ldb : 1); + } + /* + * Storing the reg_res into C matrix + */ + (is_internal_block_m && is_internal_block_n) + ? store(C, reg_res, alpha, beta, ldc, dim_m_a_start, + dim_n_b_start, boundary_check_c) + : store(C, reg_res, alpha, beta, ldc, dim_m_a_start, + dim_n_b_start, boundary_check_c); + } + /*! + * @brief binding the placeholder accessors to the SYCL command group handler + * @param h: SYCL command group handler. */ + void bind(cl::sycl::handler &h) { + _A.bind(h); + _B.bind(h); + _C.bind(h); + } + + private: + /*! + * @brief Following function load a block of row_items/col_items elements from + * A/B matrix into reg_a/reg_b. + * @tparam item_size it is the size of private register: either row_items or + * column_item + * @tparam next_element : is the stride to acces the next element of A or B + * matrix. it is either wg_rows or wg_cols. + * @tparam check_block: determined whether or not the requested block is + * internal. false means no need to check the boundaries + * @tparam pointerType: the type of the input matrix + * @tparam check_boundary: the type of a function used for checking the + * boundary for blocks of data located at the edge of the input matrix + * @param ptr : the input matrix, either A or B. + * @param reg[item_size] the private array containing the input block per + * work-item: it is either reg_a or reg_b. + * @param ld : the leading dimension of the input matrix. + * @param index: the start position of the block of data to be loaded. + * @param chk_boundary: an instance of the check_boundary function + */ + + template + static inline void load(PointerType ptr, T (®)[item_size], + const IndexType &ld, int index, + const check_boundary &chk_boundary) noexcept { +#pragma unroll + for (int i = 0; i < item_size; i++) { + reg[i] = do_check(chk_boundary(index)) ? ptr[0] : T(0); + ptr = ptr + (next_element * ld); + index = index + next_element; + } + } + + /*! 
+  /*!
+   * @brief The following function computes the partial GEMM for the input
+   * blocks reg_a and reg_b and adds the result to reg_res.
+   * @param reg_a temporary register array used to prefetch columns of A
+   * @param reg_b temporary register used to prefetch elements of B
+   * @param reg_res 2D register array used to store the result C
+   */
+  static inline void compute_block_gemm_no_shared(
+      T (&reg_a)[item_rows], T (&reg_b)[item_cols],
+      T (&reg_res)[item_rows][item_cols]) noexcept {
+#pragma unroll
+    for (int j = 0; j < item_cols; j++) {
+#pragma unroll
+      for (int i = 0; i < item_rows; i++) {
+        reg_res[i][j] = reg_res[i][j] + reg_a[i] * reg_b[j];
+      }
+    }
+  }
+
+  /*!
+   * @brief For each work item, the following function stores the computed
+   * block of GEMM reg_res into the output matrix C.
+   * @tparam check_block determines whether or not the requested block is
+   * internal. false means there is no need to check the boundaries
+   * @tparam PointerType the type of the matrix C
+   * @tparam check_boundary the type of a function used for checking the
+   * boundary for blocks of data located at the edge of the input matrix
+   * @param C the output matrix C
+   * @param reg_res 2D register array used to store the result C
+   * @param chk_boundary an instance of the check_boundary function
+   * @param alpha and beta are the scalars used in the GEMM computation
+   * @param ldc is the leading dimension of C
+   * @param dim_m_c_start and dim_n_c_start are indices used to check the
+   * boundary of C
+   */
+  template <bool check_block, typename PointerType, typename check_boundary>
+  static inline void store(PointerType C, T (&reg_res)[item_rows][item_cols],
+                           const T &alpha, const T &beta, const IndexType &ldc,
+                           int dim_m_c_start, int dim_n_c_start,
+                           const check_boundary &chk_boundary) noexcept {
+#pragma unroll
+    for (int j = 0; j < item_cols; j++) {
+#pragma unroll
+      for (int i = 0; i < item_rows; i++) {
+        if (do_check<check_block>(chk_boundary(dim_m_c_start + i * wg_rows,
+                                               dim_n_c_start + j * wg_cols))) {
+          C[i * wg_rows] = alpha * reg_res[i][j] + beta * C[i * wg_rows];
+        }
+      }
+      C = C + (wg_cols * ldc);
+    }
+  }
+};  // end class NoLocalGemmFactory
+
 /*!
  * @brief The Tile structure determines the tiling configuration of a gemm
  * implementation.
@@ -449,10 +766,6 @@ class GemmFactory {
     const auto wg_row = (tile_row + tile_local_id % tl_rows) * block_rows;
     const auto wg_col = (tile_col + tile_local_id / tl_rows) * block_rows;
 
-    /* printf(" g_id %ld, tile_size %ld, tile_id %ld, tile_local_id %ld,
-       tiles_per_col %ld, tile_row %ld, tile_col %ld, wg_row %ld, wg_col %ld\n",
-       id.get_global_id(0), tile_size, tile_id, tile_local_id, tiles_per_col,
-       tile_row, tile_col, wg_row, wg_col );*/
     if (wg_row >= m || wg_col >= n) {
       return;
     }
@@ -752,11 +1065,20 @@
 make_gemm(RHS1 buffer_a, RHS1 buffer_b, RHS2 buffer_c, T alpha, T beta) {
       alpha, beta);
 }
 
+template <int ClSize, typename TileType, bool TransA, bool TransB, typename T,
+          typename RHS1, typename RHS2>
+inline NoLocalGemmFactory<RHS1, RHS2, ClSize, TileType, TransA, TransB, T>
+make_gemm_no_local_mem(RHS1 buffer_a, RHS1 buffer_b, RHS2 buffer_c, T alpha,
+                       T beta) {
+  return NoLocalGemmFactory<RHS1, RHS2, ClSize, TileType, TransA, TransB, T>(
+      buffer_a, buffer_b, buffer_c, alpha, beta);
+}
+
 template <bool TransA, bool TransB, typename RHS1, typename RHS2, typename T>
 inline ReferenceGemmFactory<RHS1, RHS2, TransA, TransB, T>
-make_gemm_no_local_mem(RHS1 buffer_a, RHS1 buffer_b, RHS2 buffer_c, T alpha,
-                       T beta) {
+make_gemm_reference(RHS1 buffer_a, RHS1 buffer_b, RHS2 buffer_c, T alpha,
+                    T beta) {
   return ReferenceGemmFactory<RHS1, RHS2, TransA, TransB, T>(
       buffer_a, buffer_b, buffer_c, alpha, beta);
 }
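+
+/*
+ * Minimal usage sketch (illustrative only; the tile values, the views
+ * view_a/view_b/view_c, and the executor ex are assumptions, not part of
+ * this patch):
+ *
+ *   auto gemm = make_gemm_no_local_mem<64, Tile<4, 4, 8, 8>, false, false>(
+ *       view_a, view_b, view_c, alpha, beta);
+ *   auto event = ex.gemm_executor(gemm);
+ */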