Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions paddle/phi/kernels/primitive/datamover_primitives_xpu2.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,7 @@ __device__ __forceinline__ void WriteData(T _global_ptr_* dst,
T* src,
int num) {
if (num > 0) {
mfence_local();
LM2GM(src, dst, num * sizeof(T));
}
}
Expand Down Expand Up @@ -387,6 +388,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idy * stride_ny, in_temp, sizeof(Tx));
dst[idy] = static_cast<Ty>(in_temp[0]);
}
Expand All @@ -398,6 +400,7 @@ __device__ __inline__ void ReadData(Ty* dst,
break;
}
}
mfence_local();
GM2LM(src + thread_offset + idx * stride_nx, in_temp, sizeof(Tx));
dst[idx] = static_cast<Ty>(in_temp[0]);
}
Expand All @@ -412,6 +415,7 @@ __device__ __inline__ void ReadData(Ty* dst,
}
}
int fix = thread_offset + idx * stride_nx + idy * stride_ny;
mfence_local();
GM2LM(src + fix, in_temp, sizeof(Tx));
dst[idy * NX + idx] = static_cast<Ty>(in_temp[0]);
}
Expand Down Expand Up @@ -484,14 +488,13 @@ template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ __inline__ void ReadData(T* dst,
const T _global_ptr_* src,
int num) {
mfence_local();
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
Copy link
Contributor

@tiancaitzp tiancaitzp Jul 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

402行, in_temp在for循环中,且看代码NX应该有可能>1,那么在403行scalar read之后,下一次循环则发生GM2LM, 所以应该在402行之前是否应该mfence一下

Copy link
Contributor

@tiancaitzp tiancaitzp Jul 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

391,416, 494, 515, 571,627, 720行的in_temp看着也是同样,最好用模拟器的mfence检查工具跑一下,这样最保险

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加

}
} else { // core_num() * NX < num
Expand All @@ -505,13 +508,12 @@ __device__ __inline__ void ReadData(T* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();
if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
GM2LM(src + thread_offset + idx, in_temp, sizeof(T));
dst[idx] = in_temp[0];
GM2LM(src + thread_offset + idx, dst + idx, sizeof(T));
}
}
} else { // core_num() * read_lens < num
Expand Down Expand Up @@ -607,8 +609,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
int stride_ny) {
uint32_t thread_offset = block_offset + core_id();
uint32_t index_src = 0;
__local__ T in_temp[1];

mfence_local();
#pragma unroll
for (int ny = 0; ny < NY; ++ny) {
#pragma unroll
Expand All @@ -621,8 +622,7 @@ __device__ __inline__ void ReadDataBc(T* dst,
}
}
index_src = config(index_output);
GM2LM(src + index_src, in_temp, sizeof(T));
dst[nx + ny * NX] = in_temp[0];
GM2LM(src + index_src, dst + nx + ny * NX, sizeof(T));
}
}
}
Expand Down Expand Up @@ -698,8 +698,10 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[ny] = static_cast<Ty>(func(in_temp[0]));

thread_offset += stride_ny;
}
} else {
Expand All @@ -714,6 +716,7 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
mfence_local();
GM2LM(src + index_src, in_temp, sizeof(Tx));
dst[nx + ny * NX] = static_cast<Ty>(func(in_temp[0]));
thread_offset += stride_ny;
Expand Down Expand Up @@ -749,37 +752,34 @@ __device__ void WriteData(T _global_ptr_* dst,
int num,
int read_lens) {
int thread_offset = core_id() * read_lens;
__local__ T in_temp[1];
mfence_local();

if (IsBoundary) { // core_num() * read_lens > num
#pragma unroll
for (int idx = 0; idx < read_lens; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
mfence();
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * read_lens < num
mfence();
LM2GM(src, dst + thread_offset, read_lens * sizeof(T));
}
}

template <typename T, int NX, int NY, int BlockSize, bool IsBoundary>
__device__ void WriteData(T _global_ptr_* dst, const T* src, int num) {
int thread_offset = core_id() * NX;
__local__ T in_temp[1];
mfence_local();

if (IsBoundary) { // core_num() * NX > num
#pragma unroll
for (int idx = 0; idx < NX; ++idx) {
if (idx + thread_offset < num) {
in_temp[0] = src[idx];
LM2GM(in_temp, dst + idx + thread_offset, sizeof(T));
LM2GM(src + idx, dst + idx + thread_offset, sizeof(T));
}
}
} else { // core_num() * NX < num
mfence_local();
LM2GM(src, dst + thread_offset, NX * sizeof(T));
}
}
Expand Down Expand Up @@ -831,10 +831,12 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
if (IsBoundary) {
if (left_size_nx > 0) {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else {
in_temp[0] = static_cast<Ty>(src[0]);
mfence_local();
LM2GM(in_temp, dst + thread_offset, sizeof(Ty));
}
} else if (NX == 1) {
Expand All @@ -847,6 +849,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}

in_temp[0] = static_cast<Ty>(src[idy]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idy * stride_ny, sizeof(Ty));
}
} else if (NY == 1) { // for NY == 1 and NX != 1
Expand All @@ -859,6 +862,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}

in_temp[0] = static_cast<Ty>(src[idx]);
mfence_local();
LM2GM(in_temp, dst + thread_offset + idx * stride_nx, sizeof(Ty));
}
} else { // for NX != 1 and NY != 1
Expand All @@ -877,6 +881,7 @@ __device__ __inline__ void WriteData(Ty _global_ptr_* dst,
}
}
in_temp[0] = static_cast<Ty>(src[idx + idy * NX]);
mfence_local();
LM2GM(in_temp,
dst + thread_offset + idx * stride_nx + idy * stride_ny,
sizeof(Ty));
Expand Down Expand Up @@ -1029,6 +1034,7 @@ __device__ __inline__ void ReadDataBc1NMn(
for (int i = 0; i < last_col; i++) {
dst[i] = in_temp;
}
mfence_local();
GM2LM(src + index_base + 1, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
Expand Down Expand Up @@ -1083,6 +1089,7 @@ __device__ __inline__ void ReadDataBc1N1Mnk(
} else {
next_part_index = 0;
}
mfence_local();
GM2LM(src + next_part_index, &in_temp, sizeof(T));
for (int i = 0; i < read_lens - last_col; i++) {
dst[last_col + i] = in_temp;
Expand Down Expand Up @@ -1169,6 +1176,7 @@ __device__ __inline__ void ReadDataBcCanNotCmp(
if (index_src >= index_base && index_src < index_base + cache_size) {
in_temp = src_temp[index_src - index_base];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在1040行对in_temp发生了scalar read,注意一下1042行,这里GM2LM之前需要mfence

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已经添加

} else {
mfence_local();
GM2LM(src + index_src, &in_temp, sizeof(T));
}
dst[nx] = in_temp;
Expand Down