From 52874a92fa9f3e4625ec08321e6a94c63f0e2798 Mon Sep 17 00:00:00 2001 From: chende Date: Sat, 11 May 2024 05:16:06 +0000 Subject: [PATCH] finish interval grad kernel. --- .../user/kernels/fused_attention_kernels.cu | 82 +++++++++++++++++-- 1 file changed, 77 insertions(+), 5 deletions(-) diff --git a/oneflow/user/kernels/fused_attention_kernels.cu b/oneflow/user/kernels/fused_attention_kernels.cu index 8ec00615d6d..f54964c52cf 100644 --- a/oneflow/user/kernels/fused_attention_kernels.cu +++ b/oneflow/user/kernels/fused_attention_kernels.cu @@ -1132,7 +1132,79 @@ template __global__ void IntervalGradKernel( FusedApplyRotaryEmbParam param) { - printf("IntervalGradKernel TODO!\n"); + // printf("IntervalGradKernel TODO!\n"); + for (IndexType packed_offset = threadIdx.x + blockIdx.x * blockDim.x; + packed_offset < param.num_elements; packed_offset += blockDim.x * gridDim.x) { + using LoadPack = cuda::elementwise::Packed; + IndexType offset = packed_offset * PackSize; + IndexType index[num_dims]; // b, m, h, k + + IndexType temp_offset = offset; + + for (int i = 0; i < num_dims - 1; i++) { + IndexType ref_stride = param.ref_stride[i]; + IndexType idx = temp_offset / ref_stride; + index[i] = idx; + temp_offset = temp_offset - idx * ref_stride; + } + index[num_dims - 1] = temp_offset; + + IndexType x_offset = param.x_offset; + IndexType out_offset = 0; +#pragma unroll + for (int i = 0; i < num_dims; i++) { + x_offset = x_offset + param.x_stride[i] * index[i]; + out_offset = out_offset + param.out_stride[i] * index[i]; + } + const LoadPack x_vec = *reinterpret_cast(param.x + x_offset); + + const IndexType k_index = index[num_dims - 1]; + if (k_index < param.rotary_size) { + const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0; + const IndexType b_index = index[0], m_index = index[1]; + const IndexType position_id_offset = b_index * param.position_b_stride + + position_rotate_index * param.position_rotate_stride + + m_index; + + const PositionType position = + param.position_ids ? param.position_ids[position_id_offset] : m_index; + const IndexType actual_k_index = k_index % param.actual_rotary_size; + const IndexType sinuous_offset = position * param.sinuous_m_stride + actual_k_index; + + LoadPack cos_vec, sin_vec, out_vec; + + if (param.cos && param.sin) { + cos_vec = *reinterpret_cast(param.cos + sinuous_offset); + sin_vec = *reinterpret_cast(param.sin + sinuous_offset); + } else { + const IndexType actual_ndim = param.rotary_size / rotary_emb_dim; +#pragma unroll + for (int i = 0; i < PackSize / 2; i++) { + T val = position + * expf(2.0f * static_cast(((actual_k_index >> 1) + i)) + * param.inv_actual_rotary_size * logf(param.theta)); + T cos_val = cosf(val); + T sin_val = sinf(val); + cos_vec.elem[i * 2] = cos_val; + cos_vec.elem[i * 2 + 1] = cos_val; + sin_vec.elem[i * 2] = sin_val; + sin_vec.elem[i * 2 + 1] = sin_val; + } + } + +#pragma unroll + for (int i = 0; i < PackSize / 2; i++) { + out_vec.elem[i * 2] = + x_vec.elem[i * 2] * cos_vec.elem[i * 2] + x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2 + 1]; + out_vec.elem[i * 2 + 1] = x_vec.elem[i * 2 + 1] * cos_vec.elem[i * 2 + 1] + - x_vec.elem[i * 2] * sin_vec.elem[i * 2]; + } + + *(reinterpret_cast(param.out + out_offset)) = out_vec; + } else { + *(reinterpret_cast(param.out + out_offset)) = x_vec; + } + } } template param.rotate_stride) - ? static_cast(*(param.x + x_offset + param.rotate_stride)) - : -*(param.x + x_offset - param.rotate_stride); + ? *(param.x + x_offset + param.rotate_stride) + : static_cast(-*(param.x + x_offset - param.rotate_stride)); out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; } else if (k_index < param.k1) { x_vec.elem[0] = *(param.x + x_offset); x_vec.elem[1] = (param.k1 - k_index > param.rotate_stride) - ? static_cast(*(param.x + x_offset + param.rotate_stride)) - : -*(param.x + x_offset - param.rotate_stride); + ? *(param.x + x_offset + param.rotate_stride) + : static_cast(-*(param.x + x_offset - param.rotate_stride)); out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1]; } else { out_val = *(param.x + x_offset);