Skip to content

Commit c546762

Browse files
committed
Initial partially working nvptx ballot_group algs.
Signed-off-by: JackAKirk <jack.kirk@codeplay.com> cluster/ballot/opportunistic_group cuda support. Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
1 parent 56e05ce commit c546762

File tree

10 files changed

+73
-8
lines changed

10 files changed

+73
-8
lines changed

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#pragma push_macro("PTX42")
4343
#pragma push_macro("PTX60")
4444
#pragma push_macro("PTX61")
45+
#pragma push_macro("PTX62")
4546
#pragma push_macro("PTX63")
4647
#pragma push_macro("PTX64")
4748
#pragma push_macro("PTX65")
@@ -66,7 +67,8 @@
6667
#define PTX65 "ptx65|" PTX70
6768
#define PTX64 "ptx64|" PTX65
6869
#define PTX63 "ptx63|" PTX64
69-
#define PTX61 "ptx61|" PTX63
70+
#define PTX62 "ptx62|" PTX63
71+
#define PTX61 "ptx61|" PTX62
7072
#define PTX60 "ptx60|" PTX61
7173
#define PTX42 "ptx42|" PTX60
7274

@@ -594,6 +596,9 @@ TARGET_BUILTIN(__nvvm_vote_any_sync, "bUib", "", PTX60)
594596
TARGET_BUILTIN(__nvvm_vote_uni_sync, "bUib", "", PTX60)
595597
TARGET_BUILTIN(__nvvm_vote_ballot_sync, "UiUib", "", PTX60)
596598

599+
// Activemask
600+
TARGET_BUILTIN(__nvvm_activemask, "Ui", "", PTX62)
601+
597602
// Match
598603
TARGET_BUILTIN(__nvvm_match_any_sync_i32, "UiUiUi", "", AND(SM_70,PTX60))
599604
TARGET_BUILTIN(__nvvm_match_any_sync_i64, "UiUiWi", "", AND(SM_70,PTX60))

libclc/ptx-nvidiacl/libspirv/SOURCES

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ images/image_helpers.ll
9393
images/image.cl
9494
group/collectives_helpers.ll
9595
group/collectives.cl
96-
group/group_ballot.cl
96+
group/group_non_uniform.cl
9797
atomic/atomic_add.cl
9898
atomic/atomic_and.cl
9999
atomic/atomic_cmpxchg.cl

libclc/ptx-nvidiacl/libspirv/group/group_ballot.cl renamed to libclc/ptx-nvidiacl/libspirv/group/group_non_uniform.cl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "membermask.h"
10+
#include <integer/popcount.h>
1011

1112
#include <spirv/spirv.h>
1213
#include <spirv/spirv_types.h>
@@ -30,7 +31,12 @@ _Z29__spirv_GroupNonUniformBallotjb(unsigned flag, bool predicate) {
3031
unsigned threads = __clc__membermask();
3132

3233
// run the ballot operation
33-
res[0] = __nvvm_vote_ballot_sync(threads, predicate);
34+
res[0] = __nvvm_vote_ballot_sync(threads, predicate); // couldnt call this within intel impl because undefined behaviour if not all reach it?
3435

3536
return res;
3637
}
38+
39+
// Ballot bit count for the NVPTX target.
// Returns the population count of mask[0] ANDed with the value read from
// __nvvm_read_ptx_sreg_lanemask_lt() (presumably the PTX %lanemask_lt special
// register, i.e. the lanes below the caller -- confirm).
// NOTE(review): `scope` and `flag` are never read, and only the first of the
// four mask words is used. Every call therefore yields the same "masked by
// lanemask_lt" count regardless of the requested operation -- presumably only
// the exclusive-scan case is covered so far; confirm before relying on other
// modes.
_CLC_DEF _CLC_CONVERGENT uint _Z37__spirv_GroupNonUniformBallotBitCountN5__spv5Scope4FlagEiDv4_j(uint scope, uint flag, __clc_vec4_uint32_t mask) {
40+
41+
return __clc_native_popcount(__nvvm_read_ptx_sreg_lanemask_lt() & mask[0]);
42+
}

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4628,6 +4628,11 @@ def int_nvvm_match_all_sync_i64p :
46284628
Intrinsic<[llvm_i32_ty, llvm_i1_ty], [llvm_i32_ty, llvm_i64_ty],
46294629
[IntrInaccessibleMemOnly, IntrConvergent, IntrNoCallback], "llvm.nvvm.match.all.sync.i64p">;
46304630

4631+
// activemask.b32 d;
// Zero-argument intrinsic producing a single i32 result, exposed to Clang as
// the builtin __nvvm_activemask. It is marked IntrConvergent and
// IntrInaccessibleMemOnly.
// NOTE(review): unlike int_nvvm_match_all_sync_i64p just above, this def has no
// IntrNoCallback attribute and no explicit "llvm.nvvm..." name string --
// confirm that the difference is intentional.
4632+
def int_nvvm_activemask_ui : ClangBuiltin<"__nvvm_activemask">,
4633+
Intrinsic<[llvm_i32_ty], [],
4634+
[IntrConvergent, IntrInaccessibleMemOnly]>;
4635+
46314636
//
46324637
// REDUX.SYNC
46334638
//

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,13 @@ defm MATCH_ALLP_SYNC_32 : MATCH_ALLP_SYNC<Int32Regs, "b32", int_nvvm_match_all_s
274274
defm MATCH_ALLP_SYNC_64 : MATCH_ALLP_SYNC<Int64Regs, "b64", int_nvvm_match_all_sync_i64p,
275275
i64imm>;
276276

277+
// Requires PTX ISA 6.2 and sm_30.
// NOTE(review): no Requires<[...]> predicate is attached to this def, so the
// requirement stated above is not enforced here -- confirm whether one should
// be added.
278+
// activemask.b32 d;
// Selects the int_nvvm_activemask_ui intrinsic (no inputs) and emits
// "activemask.b32" writing its result to a 32-bit integer register.
279+
def INT_ACTIVEMASK :
280+
NVPTXInst<(outs Int32Regs:$dest), (ins),
281+
"activemask.b32 \t$dest;",
282+
[(set Int32Regs:$dest, (int_nvvm_activemask_ui))]>;
283+
277284
multiclass REDUX_SYNC<string BinOp, string PTXType, Intrinsic Intrin> {
278285
def : NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$mask),
279286
"redux.sync." # BinOp # "." # PTXType # " $dst, $src, $mask;",

sycl/include/sycl/detail/spirv.hpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ template <typename Group> bool GroupAll(Group g, bool pred) {
109109
template <typename ParentGroup>
110110
bool GroupAll(ext::oneapi::experimental::ballot_group<ParentGroup> g,
111111
bool pred) {
112+
#if defined (__SPIR__)
112113
// ballot_group partitions its parent into two groups (0 and 1)
113114
// We have to force each group down different control flow
114115
// Work-items in the "false" group (0) may still be active
@@ -117,6 +118,10 @@ bool GroupAll(ext::oneapi::experimental::ballot_group<ParentGroup> g,
117118
} else {
118119
return __spirv_GroupNonUniformAll(group_scope<ParentGroup>::value, pred);
119120
}
121+
#elif defined (__NVPTX__)
122+
sycl::vec<unsigned, 4> MemberMask = detail::ExtractMask(detail::GetMask(g));
123+
return __nvvm_vote_all_sync(MemberMask[0], pred);
124+
#endif
120125
}
121126

122127
template <typename Group> bool GroupAny(Group g, bool pred) {
@@ -125,6 +130,7 @@ template <typename Group> bool GroupAny(Group g, bool pred) {
125130
template <typename ParentGroup>
126131
bool GroupAny(ext::oneapi::experimental::ballot_group<ParentGroup> g,
127132
bool pred) {
133+
#if defined (__SPIR__)
128134
// ballot_group partitions its parent into two groups (0 and 1)
129135
// We have to force each group down different control flow
130136
// Work-items in the "false" group (0) may still be active
@@ -133,6 +139,10 @@ bool GroupAny(ext::oneapi::experimental::ballot_group<ParentGroup> g,
133139
} else {
134140
return __spirv_GroupNonUniformAny(group_scope<ParentGroup>::value, pred);
135141
}
142+
#elif defined (__NVPTX__)
143+
sycl::vec<unsigned, 4> MemberMask = detail::ExtractMask(detail::GetMask(g));
144+
return __nvvm_vote_any_sync(MemberMask[0], pred);
145+
#endif
136146
}
137147

138148
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
@@ -219,13 +229,18 @@ GroupBroadcast(sycl::ext::oneapi::experimental::ballot_group<ParentGroup> g,
219229
// ballot_group partitions its parent into two groups (0 and 1)
220230
// We have to force each group down different control flow
221231
// Work-items in the "false" group (0) may still be active
232+
#if defined(__SPIR__)
222233
if (g.get_group_id() == 1) {
223234
return __spirv_GroupNonUniformBroadcast(group_scope<ParentGroup>::value,
224235
OCLX, OCLId);
225236
} else {
226237
return __spirv_GroupNonUniformBroadcast(group_scope<ParentGroup>::value,
227238
OCLX, OCLId);
228239
}
240+
#elif defined(__NVPTX__)
241+
sycl::vec<unsigned, 4> MemberMask = detail::ExtractMask(detail::GetMask(g));
242+
return __nvvm_shfl_sync_idx_i32(MemberMask[0], x, LocalId, 31); //31 not 32 as docs suggest.
243+
#endif
229244
}
230245

231246
template <typename Group, typename T, typename IdT>
@@ -886,7 +901,7 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
886901
template <typename Group>
887902
typename std::enable_if_t<
888903
ext::oneapi::experimental::is_user_constructed_group_v<Group>>
889-
ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
904+
ControlBarrier(Group g, memory_scope FenceScope, memory_order Order) {
890905
#if defined(__SPIR__)
891906
// SPIR-V does not define an instruction to synchronize partial groups.
892907
// However, most (possibly all?) of the current SPIR-V targets execute
@@ -899,6 +914,7 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
899914
__spv::MemorySemanticsMask::CrossWorkgroupMemory);
900915
#elif defined(__NVPTX__)
901916
// Warp-level sync using the member mask extracted from the group.
// TODO: FenceScope and Order are not used on this path.
917+
__nvvm_bar_warp_sync(detail::ExtractMask(detail::GetMask(g))[0]);
902918
#endif
903919
}
904920

sycl/include/sycl/ext/oneapi/experimental/ballot_group.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ template <typename ParentGroup> class ballot_group {
121121
friend ballot_group<ParentGroup>
122122
get_ballot_group<ParentGroup>(ParentGroup g, bool predicate);
123123

124-
friend uint32_t sycl::detail::IdToMaskPosition<ballot_group<ParentGroup>>(
125-
ballot_group<ParentGroup> Group, uint32_t Id);
124+
friend sub_group_mask sycl::detail::GetMask<ballot_group<ParentGroup>>(ballot_group<ParentGroup> Group);
125+
126126
};
127127

128128
template <typename Group>

sycl/include/sycl/ext/oneapi/experimental/cluster_group.hpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#pragma once
1010

1111
#include <sycl/ext/oneapi/experimental/non_uniform_groups.hpp>
12+
#include <sycl/ext/oneapi/sub_group_mask.hpp>
1213

1314
namespace sycl {
1415
__SYCL_INLINE_VER_NAMESPACE(_V1) {
@@ -111,8 +112,17 @@ template <size_t ClusterSize, typename ParentGroup> class cluster_group {
111112
#endif
112113
}
113114

115+
#if defined (__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
116+
private:
117+
sub_group_mask Mask;
118+
#endif
119+
114120
protected:
121+
// Constructor selection depends on the compilation target:
//  - NVPTX device code: the group is built from a sub_group_mask, which is
//    stored in the Mask member.
//  - everything else: the group carries no state and is default-constructed.
#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__)
122+
cluster_group(ext::oneapi::sub_group_mask mask):Mask(mask) {}
123+
#else
115124
cluster_group() {}
125+
#endif
116126

117127
friend cluster_group<ClusterSize, ParentGroup>
118128
get_cluster_group<ClusterSize, ParentGroup>(ParentGroup g);
@@ -125,7 +135,16 @@ inline std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
125135
get_cluster_group(Group group) {
126136
(void)group;
127137
#ifdef __SYCL_DEVICE_ONLY__
138+
#if defined(__NVPTX__)
139+
uint32_t loc_id = group.get_local_linear_id();
140+
uint32_t loc_size = group.get_local_linear_range();
141+
uint32_t bits = (1 << ClusterSize) - 1;
142+
143+
return cluster_group<ClusterSize, sycl::sub_group>(sycl::detail::Builder::createSubGroupMask<ext::oneapi::sub_group_mask>(
144+
bits << ((loc_id / ClusterSize) * ClusterSize), loc_size));
145+
#else
128146
return cluster_group<ClusterSize, sycl::sub_group>();
147+
#endif
129148
#else
130149
throw runtime_error("Non-uniform groups are not supported on host device.",
131150
PI_ERROR_INVALID_DEVICE);

sycl/include/sycl/ext/oneapi/experimental/non_uniform_groups.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,16 @@ inline uint32_t CallerPositionInMask(ext::oneapi::sub_group_mask Mask) {
3939
}
4040
#endif
4141

42+
// TODO: Confirm that declaring this function template `inline` works as
// intended.
//
// Returns a copy of the `Mask` member of the given non-uniform group.
// The member is accessed directly, so NonUniformGroup must make `Mask`
// accessible to this function (for example through a friend declaration).
43+
template <typename NonUniformGroup>
44+
inline ext::oneapi::sub_group_mask GetMask(NonUniformGroup Group) {
45+
return Group.Mask;
46+
}
47+
4248
template <typename NonUniformGroup>
4349
inline uint32_t IdToMaskPosition(NonUniformGroup Group, uint32_t Id) {
4450
// TODO: This will need to be optimized
45-
sycl::vec<unsigned, 4> MemberMask = ExtractMask(Group.Mask);
51+
sycl::vec<unsigned, 4> MemberMask = ExtractMask(GetMask(Group));
4652
uint32_t Count = 0;
4753
for (int i = 0; i < 4; ++i) {
4854
for (int b = 0; b < 32; ++b) {

sycl/include/sycl/ext/oneapi/experimental/opportunistic_group.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ inline opportunistic_group get_opportunistic_group() {
130130
sub_group_mask mask = sycl::ext::oneapi::group_ballot(sg, true);
131131
return opportunistic_group(mask);
132132
#elif defined(__NVPTX__)
133-
// TODO: Construct from __activemask
133+
sub_group_mask mask = sycl::detail::Builder::createSubGroupMask<ext::oneapi::sub_group_mask>(__nvvm_activemask(), 32);
134+
return opportunistic_group(mask);
134135
#endif
135136
#else
136137
throw runtime_error("Non-uniform groups are not supported on host device.",

0 commit comments

Comments
 (0)