Skip to content

Commit 2bdc4c4

Browse files
authored
[ESIMD] Add support for an arbitrary number of elements to simd::copy_from/to (#5135)
This patch adds support for simd objects with any number of elements to simd::copy_from/to methods. Signed-off-by: Sergey Dmitriev <serguei.n.dmitriev@intel.com>
1 parent 2e798df commit 2bdc4c4

File tree

3 files changed

+243
-94
lines changed

3 files changed

+243
-94
lines changed

sycl/include/sycl/ext/intel/experimental/esimd/detail/simd_obj_impl.hpp

Lines changed: 208 additions & 94 deletions
Original file line number · Diff line number · Diff line change
@@ -577,7 +577,7 @@ template <typename Ty, int N, class Derived, class SFINAE> class simd_obj_impl {
577577
/// vector_aligned_tag, \p addr must be aligned by simd_obj_impl's vector_type
578578
/// alignment. If Flags is overaligned_tag<N>, \p addr must be aligned by N.
579579
/// Program not meeting alignment requirements results in undefined behavior.
580-
template <typename Flags = element_aligned_tag,
580+
template <typename Flags = element_aligned_tag, int ChunkSize = 32,
581581
typename = std::enable_if_t<is_simd_flag_type_v<Flags>>>
582582
ESIMD_INLINE void copy_from(const Ty *addr, Flags = {}) SYCL_ESIMD_FUNCTION;
583583

@@ -593,6 +593,7 @@ template <typename Ty, int N, class Derived, class SFINAE> class simd_obj_impl {
593593
/// alignment. If Flags is overaligned_tag<N>, offset must be aligned by N.
594594
/// Program not meeting alignment requirements results in undefined behavior.
595595
template <typename AccessorT, typename Flags = element_aligned_tag,
596+
int ChunkSize = 32,
596597
typename = std::enable_if_t<is_simd_flag_type_v<Flags>>>
597598
ESIMD_INLINE EnableIfAccessor<AccessorT, accessor_mode_cap::can_read,
598599
sycl::access::target::global_buffer, void>
@@ -606,7 +607,7 @@ template <typename Ty, int N, class Derived, class SFINAE> class simd_obj_impl {
606607
/// vector_aligned_tag, \p addr must be aligned by simd_obj_impl's vector_type
607608
/// alignment. If Flags is overaligned_tag<N>, \p addr must be aligned by N.
608609
/// Program not meeting alignment requirements results in undefined behavior.
609-
template <typename Flags = element_aligned_tag,
610+
template <typename Flags = element_aligned_tag, int ChunkSize = 32,
610611
typename = std::enable_if_t<is_simd_flag_type_v<Flags>>>
611612
ESIMD_INLINE void copy_to(Ty *addr, Flags = {}) const SYCL_ESIMD_FUNCTION;
612613

@@ -621,6 +622,7 @@ template <typename Ty, int N, class Derived, class SFINAE> class simd_obj_impl {
621622
/// alignment. If Flags is overaligned_tag<N>, offset must be aligned by N.
622623
/// Program not meeting alignment requirements results in undefined behavior.
623624
template <typename AccessorT, typename Flags = element_aligned_tag,
625+
int ChunkSize = 32,
624626
typename = std::enable_if_t<is_simd_flag_type_v<Flags>>>
625627
ESIMD_INLINE EnableIfAccessor<AccessorT, accessor_mode_cap::can_write,
626628
sycl::access::target::global_buffer, void>
@@ -733,144 +735,256 @@ template <typename Ty, int N, class Derived, class SFINAE> class simd_obj_impl {
733735
// ----------- Outlined implementations of simd_obj_impl class APIs.
734736

735737
template <typename T, int N, class T1, class SFINAE>
736-
template <typename Flags, typename>
738+
template <typename Flags, int ChunkSize, typename>
737739
void simd_obj_impl<T, N, T1, SFINAE>::copy_from(const T *Addr,
738740
Flags) SYCL_ESIMD_FUNCTION {
739741
constexpr unsigned Size = sizeof(T) * N;
740742
constexpr unsigned Align = Flags::template alignment<T1>;
741743

742-
simd<T, N> Tmp;
744+
constexpr unsigned BlockSize = OperandSize::OWORD * 8;
745+
constexpr unsigned NumBlocks = Size / BlockSize;
746+
constexpr unsigned RemSize = Size % BlockSize;
747+
743748
if constexpr (Align >= OperandSize::DWORD && Size % OperandSize::OWORD == 0 &&
744-
detail::isPowerOf2(Size / OperandSize::OWORD)) {
745-
Tmp = block_load<T, N, Flags>(Addr, Flags{});
749+
detail::isPowerOf2(RemSize / OperandSize::OWORD)) {
750+
if constexpr (NumBlocks > 0) {
751+
constexpr unsigned BlockN = BlockSize / sizeof(T);
752+
ForHelper<NumBlocks>::unroll([BlockN, Addr, this](unsigned Block) {
753+
select<BlockN, 1>(Block * BlockN) =
754+
block_load<T, BlockN, Flags>(Addr + (Block * BlockN), Flags{});
755+
});
756+
}
757+
if constexpr (RemSize > 0) {
758+
constexpr unsigned RemN = RemSize / sizeof(T);
759+
constexpr unsigned BlockN = BlockSize / sizeof(T);
760+
select<RemN, 1>(NumBlocks * BlockN) =
761+
block_load<T, RemN, Flags>(Addr + (NumBlocks * BlockN), Flags{});
762+
}
746763
} else if constexpr (sizeof(T) == 8) {
747-
constexpr unsigned AlignUH =
748-
(N * 4) % Align == 0 ? Align : std::min(Align, 4u);
749-
simd<int32_t, N> LH(reinterpret_cast<const int32_t *>(Addr), Flags{});
750-
simd<int32_t, N> UH(reinterpret_cast<const int32_t *>(Addr) + N,
751-
overaligned<AlignUH>);
752-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(0) = LH;
753-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(N) = UH;
754-
} else if constexpr (N == 1) {
755-
Tmp = *Addr;
756-
} else if constexpr (N == 8 || N == 16 || N == 32) {
757-
simd<uint32_t, N> Offsets(0u, sizeof(T));
758-
Tmp = gather<T, N>(Addr, Offsets);
764+
simd<int32_t, N * 2> BC(reinterpret_cast<const int32_t *>(Addr), Flags{});
765+
bit_cast_view<int32_t>() = BC;
759766
} else {
760-
constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32;
761-
simd_mask_type<N1> Pred(0);
762-
Pred.template select<N, 1>() = 1;
763-
simd<uint32_t, N1> Offsets(0u, sizeof(T));
764-
simd<T, N1> Vals = gather<T, N1>(Addr, Offsets, Pred);
765-
Tmp = Vals.template select<N, 1>();
766-
}
767-
*this = Tmp.data();
767+
constexpr unsigned NumChunks = N / ChunkSize;
768+
if constexpr (NumChunks > 0) {
769+
simd<uint32_t, ChunkSize> Offsets(0u, sizeof(T));
770+
ForHelper<NumChunks>::unroll([Addr, &Offsets, this](unsigned Block) {
771+
select<ChunkSize, 1>(Block * ChunkSize) =
772+
gather<T, ChunkSize>(Addr + (Block * ChunkSize), Offsets);
773+
});
774+
}
775+
constexpr unsigned RemN = N % ChunkSize;
776+
if constexpr (RemN > 0) {
777+
if constexpr (RemN == 1) {
778+
select<1, 1>(NumChunks * ChunkSize) = Addr[NumChunks * ChunkSize];
779+
} else if constexpr (RemN == 8 || RemN == 16) {
780+
simd<uint32_t, RemN> Offsets(0u, sizeof(T));
781+
select<RemN, 1>(NumChunks * ChunkSize) =
782+
gather<T, RemN>(Addr + (NumChunks * ChunkSize), Offsets);
783+
} else {
784+
constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32;
785+
simd_mask_type<N1> Pred(0);
786+
Pred.template select<RemN, 1>() = 1;
787+
simd<uint32_t, N1> Offsets(0u, sizeof(T));
788+
simd<T, N1> Vals =
789+
gather<T, N1>(Addr + (NumChunks * ChunkSize), Offsets, Pred);
790+
select<RemN, 1>(NumChunks * ChunkSize) =
791+
Vals.template select<RemN, 1>();
792+
}
793+
}
794+
}
768795
}
769796

770797
template <typename T, int N, class T1, class SFINAE>
771-
template <typename AccessorT, typename Flags, typename>
798+
template <typename AccessorT, typename Flags, int ChunkSize, typename>
772799
ESIMD_INLINE EnableIfAccessor<AccessorT, accessor_mode_cap::can_read,
773800
sycl::access::target::global_buffer, void>
774801
simd_obj_impl<T, N, T1, SFINAE>::copy_from(AccessorT acc, uint32_t offset,
775802
Flags) SYCL_ESIMD_FUNCTION {
776803
constexpr unsigned Size = sizeof(T) * N;
777804
constexpr unsigned Align = Flags::template alignment<T1>;
778805

779-
simd<T, N> Tmp;
806+
constexpr unsigned BlockSize = OperandSize::OWORD * 8;
807+
constexpr unsigned NumBlocks = Size / BlockSize;
808+
constexpr unsigned RemSize = Size % BlockSize;
809+
780810
if constexpr (Align >= OperandSize::DWORD && Size % OperandSize::OWORD == 0 &&
781-
detail::isPowerOf2(Size / OperandSize::OWORD)) {
782-
Tmp = block_load<T, N, AccessorT, Flags>(acc, offset, Flags{});
811+
detail::isPowerOf2(RemSize / OperandSize::OWORD)) {
812+
if constexpr (NumBlocks > 0) {
813+
constexpr unsigned BlockN = BlockSize / sizeof(T);
814+
ForHelper<NumBlocks>::unroll([BlockN, acc, offset, this](unsigned Block) {
815+
select<BlockN, 1>(Block * BlockN) =
816+
block_load<T, BlockN, AccessorT, Flags>(
817+
acc, offset + (Block * BlockSize), Flags{});
818+
});
819+
}
820+
if constexpr (RemSize > 0) {
821+
constexpr unsigned RemN = RemSize / sizeof(T);
822+
constexpr unsigned BlockN = BlockSize / sizeof(T);
823+
select<RemN, 1>(NumBlocks * BlockN) =
824+
block_load<T, RemN, AccessorT, Flags>(
825+
acc, offset + (NumBlocks * BlockSize), Flags{});
826+
}
783827
} else if constexpr (sizeof(T) == 8) {
784-
constexpr unsigned AlignUH =
785-
(N * 4) % Align == 0 ? Align : std::min(Align, 4u);
786-
simd<int32_t, N> LH(acc, offset, Flags{});
787-
simd<int32_t, N> UH(acc, offset + N * 4, overaligned<AlignUH>);
788-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(0) = LH;
789-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(N) = UH;
790-
} else if constexpr (N == 1 || N == 8 || N == 16 || N == 32) {
791-
simd<uint32_t, N> Offsets(0u, sizeof(T));
792-
Tmp = gather<T, N, AccessorT>(acc, Offsets, offset);
828+
simd<int32_t, N * 2> BC(acc, offset, Flags{});
829+
bit_cast_view<int32_t>() = BC;
793830
} else {
794-
constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32;
795-
simd_mask_type<N1> Pred(0);
796-
Pred.template select<N, 1>() = 1;
797-
simd<uint32_t, N1> Offsets(0u, sizeof(T));
798-
simd<T, N1> Vals = gather<T, N1>(acc, Offsets, offset, Pred);
799-
Tmp = Vals.template select<N, 1>();
800-
}
801-
*this = Tmp.data();
831+
constexpr unsigned NumChunks = N / ChunkSize;
832+
if constexpr (NumChunks > 0) {
833+
simd<uint32_t, ChunkSize> Offsets(0u, sizeof(T));
834+
ForHelper<NumChunks>::unroll(
835+
[acc, offset, &Offsets, this](unsigned Block) {
836+
select<ChunkSize, 1>(Block * ChunkSize) =
837+
gather<T, ChunkSize, AccessorT>(
838+
acc, Offsets, offset + (Block * ChunkSize * sizeof(T)));
839+
});
840+
}
841+
constexpr unsigned RemN = N % ChunkSize;
842+
if constexpr (RemN > 0) {
843+
if constexpr (RemN == 1 || RemN == 8 || RemN == 16) {
844+
simd<uint32_t, RemN> Offsets(0u, sizeof(T));
845+
select<RemN, 1>(NumChunks * ChunkSize) = gather<T, RemN, AccessorT>(
846+
acc, Offsets, offset + (NumChunks * ChunkSize * sizeof(T)));
847+
} else {
848+
constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32;
849+
simd_mask_type<N1> Pred(0);
850+
Pred.template select<RemN, 1>() = 1;
851+
simd<uint32_t, N1> Offsets(0u, sizeof(T));
852+
simd<T, N1> Vals = gather<T, N1>(
853+
acc, Offsets, offset + (NumChunks * ChunkSize * sizeof(T)), Pred);
854+
select<RemN, 1>(NumChunks * ChunkSize) =
855+
Vals.template select<RemN, 1>();
856+
}
857+
}
858+
}
802859
}
803860

804861
template <typename T, int N, class T1, class SFINAE>
805-
template <typename Flags, typename>
862+
template <typename Flags, int ChunkSize, typename>
806863
void simd_obj_impl<T, N, T1, SFINAE>::copy_to(T *addr,
807864
Flags) const SYCL_ESIMD_FUNCTION {
808865
constexpr unsigned Size = sizeof(T) * N;
809866
constexpr unsigned Align = Flags::template alignment<T1>;
810867

868+
constexpr unsigned BlockSize = OperandSize::OWORD * 8;
869+
constexpr unsigned NumBlocks = Size / BlockSize;
870+
constexpr unsigned RemSize = Size % BlockSize;
871+
872+
simd<T, N> Tmp = data();
811873
if constexpr (Align >= OperandSize::OWORD && Size % OperandSize::OWORD == 0 &&
812-
detail::isPowerOf2(Size / OperandSize::OWORD)) {
813-
block_store<T, N>(addr, cast_this_to_derived());
874+
detail::isPowerOf2(RemSize / OperandSize::OWORD)) {
875+
if constexpr (NumBlocks > 0) {
876+
constexpr unsigned BlockN = BlockSize / sizeof(T);
877+
ForHelper<NumBlocks>::unroll([BlockN, addr, &Tmp](unsigned Block) {
878+
block_store<T, BlockN>(addr + (Block * BlockN),
879+
Tmp.template select<BlockN, 1>(Block * BlockN));
880+
});
881+
}
882+
if constexpr (RemSize > 0) {
883+
constexpr unsigned RemN = RemSize / sizeof(T);
884+
constexpr unsigned BlockN = BlockSize / sizeof(T);
885+
block_store<T, RemN>(addr + (NumBlocks * BlockN),
886+
Tmp.template select<RemN, 1>(NumBlocks * BlockN));
887+
}
814888
} else if constexpr (sizeof(T) == 8) {
815-
constexpr unsigned AlignUH =
816-
(N * 4) % Align == 0 ? Align : std::min(Align, 4u);
817-
simd<T, N> Tmp = data();
818-
simd<int32_t, N> LH =
819-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(0);
820-
simd<int32_t, N> UH =
821-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(N);
822-
LH.copy_to(reinterpret_cast<int32_t *>(addr), Flags{});
823-
UH.copy_to(reinterpret_cast<int32_t *>(addr) + N, overaligned<AlignUH>);
824-
} else if constexpr (N == 1) {
825-
*addr = data()[0];
826-
} else if constexpr (N == 8 || N == 16 || N == 32) {
827-
simd<uint32_t, N> offsets(0u, sizeof(T));
828-
scatter<T, N>(addr, offsets, cast_this_to_derived().data());
889+
simd<int32_t, N * 2> BC = Tmp.template bit_cast_view<int32_t>();
890+
BC.copy_to(reinterpret_cast<int32_t *>(addr), Flags{});
829891
} else {
830-
constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32;
831-
simd_mask_type<N1> pred(0);
832-
pred.template select<N, 1>() = 1;
833-
simd<T, N1> vals(0);
834-
vals.template select<N, 1>() = cast_this_to_derived().data();
835-
simd<uint32_t, N1> offsets(0u, sizeof(T));
836-
scatter<T, N1>(addr, offsets, vals, pred);
892+
constexpr unsigned NumChunks = N / ChunkSize;
893+
if constexpr (NumChunks > 0) {
894+
simd<uint32_t, ChunkSize> Offsets(0u, sizeof(T));
895+
ForHelper<NumChunks>::unroll([addr, &Offsets, &Tmp](unsigned Block) {
896+
scatter<T, ChunkSize>(
897+
addr + (Block * ChunkSize), Offsets,
898+
Tmp.template select<ChunkSize, 1>(Block * ChunkSize));
899+
});
900+
}
901+
constexpr unsigned RemN = N % ChunkSize;
902+
if constexpr (RemN > 0) {
903+
if constexpr (RemN == 1) {
904+
addr[NumChunks * ChunkSize] = Tmp[NumChunks * ChunkSize];
905+
} else if constexpr (RemN == 8 || RemN == 16) {
906+
simd<uint32_t, RemN> Offsets(0u, sizeof(T));
907+
scatter<T, RemN>(addr + (NumChunks * ChunkSize), Offsets,
908+
Tmp.template select<RemN, 1>(NumChunks * ChunkSize));
909+
} else {
910+
constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32;
911+
simd_mask_type<N1> Pred(0);
912+
Pred.template select<RemN, 1>() = 1;
913+
simd<T, N1> Vals(0);
914+
Vals.template select<RemN, 1>() =
915+
Tmp.template select<RemN, 1>(NumChunks * ChunkSize);
916+
simd<uint32_t, N1> Offsets(0u, sizeof(T));
917+
scatter<T, N1>(addr + (NumChunks * ChunkSize), Offsets, Vals, Pred);
918+
}
919+
}
837920
}
838921
}
839922

840923
template <typename T, int N, class T1, class SFINAE>
841-
template <typename AccessorT, typename Flags, typename>
924+
template <typename AccessorT, typename Flags, int ChunkSize, typename>
842925
ESIMD_INLINE EnableIfAccessor<AccessorT, accessor_mode_cap::can_write,
843926
sycl::access::target::global_buffer, void>
844927
simd_obj_impl<T, N, T1, SFINAE>::copy_to(AccessorT acc, uint32_t offset,
845928
Flags) const SYCL_ESIMD_FUNCTION {
846929
constexpr unsigned Size = sizeof(T) * N;
847930
constexpr unsigned Align = Flags::template alignment<T1>;
848931

932+
constexpr unsigned BlockSize = OperandSize::OWORD * 8;
933+
constexpr unsigned NumBlocks = Size / BlockSize;
934+
constexpr unsigned RemSize = Size % BlockSize;
935+
936+
simd<T, N> Tmp = data();
849937
if constexpr (Align >= OperandSize::OWORD && Size % OperandSize::OWORD == 0 &&
850-
detail::isPowerOf2(Size / OperandSize::OWORD)) {
851-
block_store<T, N, AccessorT>(acc, offset, cast_this_to_derived());
938+
detail::isPowerOf2(RemSize / OperandSize::OWORD)) {
939+
if constexpr (NumBlocks > 0) {
940+
constexpr unsigned BlockN = BlockSize / sizeof(T);
941+
ForHelper<NumBlocks>::unroll([BlockN, acc, offset, &Tmp](unsigned Block) {
942+
block_store<T, BlockN, AccessorT>(
943+
acc, offset + (Block * BlockSize),
944+
Tmp.template select<BlockN, 1>(Block * BlockN));
945+
});
946+
}
947+
if constexpr (RemSize > 0) {
948+
constexpr unsigned RemN = RemSize / sizeof(T);
949+
constexpr unsigned BlockN = BlockSize / sizeof(T);
950+
block_store<T, RemN, AccessorT>(
951+
acc, offset + (NumBlocks * BlockSize),
952+
Tmp.template select<RemN, 1>(NumBlocks * BlockN));
953+
}
852954
} else if constexpr (sizeof(T) == 8) {
853-
constexpr unsigned AlignUH =
854-
(N * 4) % Align == 0 ? Align : std::min(Align, 4u);
855-
simd<T, N> Tmp = data();
856-
simd<int32_t, N> LH =
857-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(0);
858-
simd<int32_t, N> UH =
859-
Tmp.template bit_cast_view<int32_t>().template select<N, 1>(N);
860-
LH.copy_to(acc, offset, Flags{});
861-
UH.copy_to(acc, offset + N * 4, overaligned<AlignUH>);
862-
} else if constexpr (N == 1 || N == 8 || N == 16 || N == 32) {
863-
simd<uint32_t, N> offsets(0u, sizeof(T));
864-
scatter<T, N, AccessorT>(acc, offsets, cast_this_to_derived().data(),
865-
offset);
955+
simd<int32_t, N * 2> BC = Tmp.template bit_cast_view<int32_t>();
956+
BC.copy_to(acc, offset, Flags{});
866957
} else {
867-
constexpr int N1 = N < 8 ? 8 : N < 16 ? 16 : 32;
868-
simd_mask_type<N1> pred(0);
869-
pred.template select<N, 1>() = 1;
870-
simd<T, N1> vals(0);
871-
vals.template select<N, 1>() = cast_this_to_derived().data();
872-
simd<uint32_t, N1> offsets(0u, sizeof(T));
873-
scatter<T, N1, AccessorT>(acc, offsets, vals, offset, pred);
958+
constexpr unsigned NumChunks = N / ChunkSize;
959+
if constexpr (NumChunks > 0) {
960+
simd<uint32_t, ChunkSize> Offsets(0u, sizeof(T));
961+
ForHelper<NumChunks>::unroll([acc, offset, &Offsets,
962+
&Tmp](unsigned Block) {
963+
scatter<T, ChunkSize, AccessorT>(
964+
acc, Offsets, Tmp.template select<ChunkSize, 1>(Block * ChunkSize),
965+
offset + (Block * ChunkSize * sizeof(T)));
966+
});
967+
}
968+
constexpr unsigned RemN = N % ChunkSize;
969+
if constexpr (RemN > 0) {
970+
if constexpr (RemN == 1 || RemN == 8 || RemN == 16) {
971+
simd<uint32_t, RemN> Offsets(0u, sizeof(T));
972+
scatter<T, RemN, AccessorT>(
973+
acc, Offsets, Tmp.template select<RemN, 1>(NumChunks * ChunkSize),
974+
offset + (NumChunks * ChunkSize * sizeof(T)));
975+
} else {
976+
constexpr int N1 = RemN < 8 ? 8 : RemN < 16 ? 16 : 32;
977+
simd_mask_type<N1> Pred(0);
978+
Pred.template select<RemN, 1>() = 1;
979+
simd<T, N1> Vals(0);
980+
Vals.template select<RemN, 1>() =
981+
Tmp.template select<RemN, 1>(NumChunks * ChunkSize);
982+
simd<uint32_t, N1> Offsets(0u, sizeof(T));
983+
scatter<T, N1, AccessorT>(acc, Offsets, Vals,
984+
offset + (NumChunks * ChunkSize * sizeof(T)),
985+
Pred);
986+
}
987+
}
874988
}
875989
}
876990
} // namespace detail

0 commit comments

Comments (0)