Add complex group algorithms #7120

Merged: 1 commit, Mar 24, 2023

12 changes: 12 additions & 0 deletions libclc/generic/include/spirv/spirv_types.h
@@ -46,4 +46,16 @@ enum GroupOperation {
ExclusiveScan = 2,
};

Contributor: Would it make sense to add bfloat16 here? It is still 'experimental', but will be moved out soon. We can wait for that to happen. Thanks.

Contributor Author: This has already gotten rather long in the tooth, so I'd like to defer that change for later.

typedef struct {
float real, imag;
} complex_float;

typedef struct {
double real, imag;
} complex_double;

typedef struct {
half real, imag;
} complex_half;

#endif // CLC_SPIRV_TYPES
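
For reference, the scalar operation that a complex group multiply combines element by element is the ordinary complex product; a minimal sketch using the same {real, imag} layout as the structs above (illustrative only, not code from this change):

struct cfloat { float real, imag; };  // hypothetical mirror of complex_float

cfloat cmul(cfloat a, cfloat b) {
  // (a.real + a.imag*i) * (b.real + b.imag*i)
  return {a.real * b.real - a.imag * b.imag,   // real part
          a.real * b.imag + a.imag * b.real};  // imaginary part
}
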
253 changes: 210 additions & 43 deletions libclc/ptx-nvidiacl/libspirv/group/collectives.cl

Large diffs are not rendered by default.

58 changes: 47 additions & 11 deletions libclc/ptx-nvidiacl/libspirv/group/collectives_helpers.ll
@@ -1,61 +1,97 @@
-; 64 storage locations is sufficient for all current-generation NVIDIA GPUs
-; 64 bits per warp is sufficient for all fundamental data types
+; 32 storage locations is sufficient for all current-generation NVIDIA GPUs
+; 128 bits per warp is sufficient for all fundamental data types and complex
; Reducing storage for small data types or increasing it for user-defined types
; will likely require an additional pass to track group algorithm usage
-@__clc__group_scratch = internal addrspace(3) global [64 x i64] undef, align 1
+@__clc__group_scratch = internal addrspace(3) global [128 x i64] undef, align 1

define i8 addrspace(3)* @__clc__get_group_scratch_bool() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to i8 addrspace(3)*
ret i8 addrspace(3)* %cast
}

define i8 addrspace(3)* @__clc__get_group_scratch_char() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to i8 addrspace(3)*
ret i8 addrspace(3)* %cast
}

define i16 addrspace(3)* @__clc__get_group_scratch_short() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to i16 addrspace(3)*
ret i16 addrspace(3)* %cast
}

define i32 addrspace(3)* @__clc__get_group_scratch_int() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to i32 addrspace(3)*
ret i32 addrspace(3)* %cast
}

define i64 addrspace(3)* @__clc__get_group_scratch_long() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to i64 addrspace(3)*
ret i64 addrspace(3)* %cast
}

define half addrspace(3)* @__clc__get_group_scratch_half() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to half addrspace(3)*
ret half addrspace(3)* %cast
}

define float addrspace(3)* @__clc__get_group_scratch_float() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to float addrspace(3)*
ret float addrspace(3)* %cast
}

define double addrspace(3)* @__clc__get_group_scratch_double() nounwind alwaysinline {
entry:
-%ptr = getelementptr inbounds [64 x i64], [64 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
+%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to double addrspace(3)*
ret double addrspace(3)* %cast
}

%complex_half = type {
half,
half
}

%complex_float = type {
float,
float
}

%complex_double = type {
double,
double
}

define %complex_half addrspace(3)* @__clc__get_group_scratch_complex_half() nounwind alwaysinline {
entry:
%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to %complex_half addrspace(3)*
ret %complex_half addrspace(3)* %cast
}

define %complex_float addrspace(3)* @__clc__get_group_scratch_complex_float() nounwind alwaysinline {
entry:
%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to %complex_float addrspace(3)*
ret %complex_float addrspace(3)* %cast
}

define %complex_double addrspace(3)* @__clc__get_group_scratch_complex_double() nounwind alwaysinline {
entry:
%ptr = getelementptr inbounds [128 x i64], [128 x i64] addrspace(3)* @__clc__group_scratch, i64 0, i64 0
%cast = bitcast i64 addrspace(3)* %ptr to %complex_double addrspace(3)*
ret %complex_double addrspace(3)* %cast
}
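
All of these helpers return a typed view of the same [128 x i64] scratch buffer at offset zero, which the collectives (whose diff is not rendered above) use to stage per-warp partial results before a final combine. A quick host-side sizing check, as a hedged sketch rather than code from this change:

#include <cstddef>
#include <cstdint>

constexpr std::size_t kSlots = 128;                              // [128 x i64]
constexpr std::size_t kBytes = kSlots * sizeof(std::int64_t);    // local scratch per work-group
constexpr std::size_t kComplexDoubleBytes = 2 * sizeof(double);  // widest element: complex_double, 128 bits

static_assert(kBytes == 1024, "1 KiB of addrspace(3) scratch per work-group");
static_assert(kBytes / kComplexDoubleBytes == 64, "room for 64 complex_double partials");
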
14 changes: 14 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
@@ -8,6 +8,7 @@

#pragma once
#include <CL/__spirv/spirv_types.hpp>
#include <complex>
#include <cstddef>
#include <cstdint>
#include <sycl/detail/defines.hpp>
@@ -1051,6 +1052,19 @@ __CLC_BF16_SCAL_VEC(uint32_t)
#undef __CLC_BF16_SCAL_VEC
#undef __CLC_BF16

__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
__SYCL_EXPORT __spv::complex_half
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
__spv::complex_half) noexcept;
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
__SYCL_EXPORT __spv::complex_float
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
__spv::complex_float) noexcept;
__SYCL_CONVERGENT__ extern __DPCPP_SYCL_EXTERNAL
__SYCL_EXPORT __spv::complex_double
__spirv_GroupCMulINTEL(unsigned int, unsigned int,
__spv::complex_double) noexcept;

extern __DPCPP_SYCL_EXTERNAL int32_t __spirv_BuiltInGlobalHWThreadIDINTEL();
extern __DPCPP_SYCL_EXTERNAL int32_t __spirv_BuiltInSubDeviceIDINTEL();

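For orientation, the two leading unsigned parameters follow the convention of the other Group* builtins in this header: an execution scope and a group operation, passed as SPIR-V enumerant values. That mapping is an assumption here, and the builtin itself only resolves in device compilation, so the call below is a sketch rather than runnable host code:

// SPIR-V enumerant values, for reference:
enum class SpirvScope : unsigned { CrossDevice = 0, Device = 1, Workgroup = 2, Subgroup = 3 };
enum class SpirvGroupOp : unsigned { Reduce = 0, InclusiveScan = 1, ExclusiveScan = 2 };

// A work-group multiply-reduction over complex floats would then lower to a call shaped like:
//   __spv::complex_float x = std::complex<float>{2.0f, 1.0f};
//   auto r = __spirv_GroupCMulINTEL(static_cast<unsigned>(SpirvScope::Workgroup),
//                                   static_cast<unsigned>(SpirvGroupOp::Reduce), x);
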
24 changes: 24 additions & 0 deletions sycl/include/CL/__spirv/spirv_types.hpp
@@ -8,9 +8,12 @@

#pragma once

#include "sycl/half_type.hpp"
#include <sycl/detail/defines.hpp>
#include <sycl/detail/defines_elementary.hpp>
#include <sycl/half_type.hpp>

#include <complex>
#include <cstddef>
#include <cstdint>

@@ -128,6 +131,27 @@ enum class MatrixLayout : uint32_t {

enum class MatrixUse : uint32_t { MatrixA = 0, MatrixB = 1, Accumulator = 2 };

struct complex_float {
complex_float() = default;
complex_float(std::complex<float> x) : real(x.real()), imag(x.imag()) {}
operator std::complex<float>() { return {real, imag}; }
float real, imag;
};

struct complex_double {
complex_double() = default;
complex_double(std::complex<double> x) : real(x.real()), imag(x.imag()) {}
operator std::complex<double>() { return {real, imag}; }
double real, imag;
};

struct complex_half {
complex_half() = default;
complex_half(std::complex<sycl::half> x) : real(x.real()), imag(x.imag()) {}
operator std::complex<sycl::half>() { return {real, imag}; }
sycl::half real, imag;
};

#if (SYCL_EXT_ONEAPI_MATRIX_VERSION > 1)
template <typename T, std::size_t R, std::size_t C, MatrixLayout L,
Scope::Flag S = Scope::Flag::Subgroup,
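The converting constructor and conversion operator on these wrappers let std::complex values cross the builtin boundary transparently. A small self-contained illustration; the struct below mirrors __spv::complex_float rather than including the internal header:

#include <complex>

struct complex_float {  // mirror of __spv::complex_float, for illustration only
  complex_float() = default;
  complex_float(std::complex<float> x) : real(x.real()), imag(x.imag()) {}
  operator std::complex<float>() { return {real, imag}; }
  float real, imag;
};

int main() {
  std::complex<float> a{1.0f, 2.0f};
  complex_float wrapped = a;           // implicit std::complex<float> -> wrapper
  std::complex<float> back = wrapped;  // implicit wrapper -> std::complex<float>
  return back == a ? 0 : 1;            // round-trip preserves the value
}
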
23 changes: 17 additions & 6 deletions sycl/include/sycl/detail/generic_type_traits.hpp
@@ -8,6 +8,7 @@

#pragma once

#include <CL/__spirv/spirv_types.hpp>
#include <sycl/access/access.hpp>
#include <sycl/aliases.hpp>
#include <sycl/detail/common.hpp>
@@ -16,6 +17,7 @@
#include <sycl/half_type.hpp>
#include <sycl/multi_ptr.hpp>

#include <complex>
#include <limits>

namespace sycl {
Expand Down Expand Up @@ -444,6 +446,14 @@ using select_cl_scalar_float_t =
select_apply_cl_scalar_t<T, std::false_type, sycl::opencl::cl_half,
sycl::opencl::cl_float, sycl::opencl::cl_double>;

template <typename T>
using select_cl_scalar_complex_or_T_t = std::conditional_t<
std::is_same<T, std::complex<float>>::value, __spv::complex_float,
std::conditional_t<
std::is_same<T, std::complex<double>>::value, __spv::complex_double,
std::conditional_t<std::is_same<T, std::complex<half>>::value,
__spv::complex_half, T>>>;

template <typename T>
using select_cl_scalar_integral_t =
conditional_t<std::is_signed<T>::value,
@@ -455,12 +465,13 @@ using select_cl_scalar_integral_t =
template <typename T>
using select_cl_scalar_t = conditional_t<
std::is_integral<T>::value, select_cl_scalar_integral_t<T>,
-conditional_t<
-std::is_floating_point<T>::value, select_cl_scalar_float_t<T>,
-// half is a special case: it is implemented differently on host and
-// device and therefore, might lower to different types
-conditional_t<std::is_same<T, half>::value,
-sycl::detail::half_impl::BIsRepresentationT, T>>>;
+conditional_t<std::is_floating_point<T>::value, select_cl_scalar_float_t<T>,
+// half is a special case: it is implemented differently on
+// host and device and therefore, might lower to different
+// types
+conditional_t<std::is_same<T, half>::value,
+sycl::detail::half_impl::BIsRepresentationT,
+select_cl_scalar_complex_or_T_t<T>>>>;

// select_cl_vector_or_scalar_or_ptr does cl_* type selection for element type
// of a vector type T, pointer type substitution, and scalar type substitution.
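In effect, the new alias swaps in the __spv wrapper for std::complex element types and passes every other type through unchanged, so the existing scalar-selection machinery keeps working. A self-contained restatement of that mapping (the names and empty stand-in structs below are illustrative, not the real header):

#include <complex>
#include <type_traits>

struct spv_complex_float {};   // stand-in for __spv::complex_float
struct spv_complex_double {};  // stand-in for __spv::complex_double

template <typename T>
using select_complex_or_T_t = std::conditional_t<
    std::is_same_v<T, std::complex<float>>, spv_complex_float,
    std::conditional_t<std::is_same_v<T, std::complex<double>>, spv_complex_double, T>>;

static_assert(std::is_same_v<select_complex_or_T_t<std::complex<float>>, spv_complex_float>);
static_assert(std::is_same_v<select_complex_or_T_t<std::complex<double>>, spv_complex_double>);
static_assert(std::is_same_v<select_complex_or_T_t<int>, int>);  // non-complex types pass through
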
12 changes: 12 additions & 0 deletions sycl/include/sycl/ext/oneapi/functional.hpp
@@ -8,7 +8,9 @@

#pragma once
#include <sycl/functional.hpp>
#include <sycl/half_type.hpp>

#include <complex>
#include <functional>

namespace sycl {
@@ -31,6 +33,7 @@ namespace detail {
struct GroupOpISigned {};
struct GroupOpIUnsigned {};
struct GroupOpFP {};
struct GroupOpC {};

template <typename T, typename = void> struct GroupOpTag;

@@ -49,6 +52,14 @@ struct GroupOpTag<T, detail::enable_if_t<detail::is_sgenfloat<T>::value>> {
using type = GroupOpFP;
};

template <typename T>
struct GroupOpTag<
T, detail::enable_if_t<std::is_same<T, std::complex<half>>::value ||
std::is_same<T, std::complex<float>>::value ||
std::is_same<T, std::complex<double>>::value>> {
using type = GroupOpC;
};

#define __SYCL_CALC_OVERLOAD(GroupTag, SPIRVOperation, BinaryOperation) \
template <typename T, __spv::GroupOperation O, __spv::Scope::Flag S> \
static T calc(GroupTag, T x, BinaryOperation) { \
@@ -83,6 +94,7 @@ __SYCL_CALC_OVERLOAD(GroupOpFP, FAdd, sycl::plus<T>)
__SYCL_CALC_OVERLOAD(GroupOpISigned, IMulKHR, sycl::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, IMulKHR, sycl::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpFP, FMulKHR, sycl::multiplies<T>)
__SYCL_CALC_OVERLOAD(GroupOpC, CMulINTEL, sycl::multiplies<T>)

__SYCL_CALC_OVERLOAD(GroupOpISigned, BitwiseOrKHR, sycl::bit_or<T>)
__SYCL_CALC_OVERLOAD(GroupOpIUnsigned, BitwiseOrKHR, sycl::bit_or<T>)
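Putting the functional.hpp pieces together: the new GroupOpTag specialization routes std::complex element types to the GroupOpC tag, and the GroupOpC/CMulINTEL overload then maps sycl::multiplies onto the new builtin. A self-contained mirror of that dispatch rule (names echo the header, but this is an illustration, not the real code):

#include <complex>
#include <type_traits>

struct GroupOpFP {};
struct GroupOpC {};

template <typename T, typename = void> struct GroupOpTag;

template <typename T>
struct GroupOpTag<T, std::enable_if_t<std::is_floating_point_v<T>>> {
  using type = GroupOpFP;  // pre-existing path for plain floating-point types
};

template <typename T>
struct GroupOpTag<T, std::enable_if_t<std::is_same_v<T, std::complex<float>> ||
                                      std::is_same_v<T, std::complex<double>>>> {
  using type = GroupOpC;   // new path: complex element types select CMulINTEL
};

static_assert(std::is_same_v<GroupOpTag<std::complex<float>>::type, GroupOpC>);
static_assert(std::is_same_v<GroupOpTag<float>::type, GroupOpFP>);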