Skip to content

Commit

Permalink
finish interval grad kernel.
Browse files Browse the repository at this point in the history
  • Loading branch information
cccddd77 committed May 11, 2024
1 parent 4179b3d commit 52874a9
Showing 1 changed file with 77 additions and 5 deletions.
82 changes: 77 additions & 5 deletions oneflow/user/kernels/fused_attention_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,79 @@ template<typename T, typename PositionType, typename IndexType, size_t PackSize,
size_t rotary_emb_dim>
__global__ void IntervalGradKernel(
FusedApplyRotaryEmbParam<T, PositionType, IndexType, num_dims, rotary_emb_dim> param) {
printf("IntervalGradKernel TODO!\n");
// printf("IntervalGradKernel TODO!\n");
for (IndexType packed_offset = threadIdx.x + blockIdx.x * blockDim.x;
packed_offset < param.num_elements; packed_offset += blockDim.x * gridDim.x) {
using LoadPack = cuda::elementwise::Packed<T, PackSize>;
IndexType offset = packed_offset * PackSize;
IndexType index[num_dims]; // b, m, h, k
IndexType temp_offset = offset;
for (int i = 0; i < num_dims - 1; i++) {
IndexType ref_stride = param.ref_stride[i];
IndexType idx = temp_offset / ref_stride;
index[i] = idx;
temp_offset = temp_offset - idx * ref_stride;
}
index[num_dims - 1] = temp_offset;
IndexType x_offset = param.x_offset;
IndexType out_offset = 0;
#pragma unroll
for (int i = 0; i < num_dims; i++) {
x_offset = x_offset + param.x_stride[i] * index[i];
out_offset = out_offset + param.out_stride[i] * index[i];
}
const LoadPack x_vec = *reinterpret_cast<const LoadPack*>(param.x + x_offset);
const IndexType k_index = index[num_dims - 1];
if (k_index < param.rotary_size) {
const IndexType position_rotate_index = (k_index >= param.k0) ? 1 : 0;
const IndexType b_index = index[0], m_index = index[1];
const IndexType position_id_offset = b_index * param.position_b_stride
+ position_rotate_index * param.position_rotate_stride
+ m_index;
const PositionType position =
param.position_ids ? param.position_ids[position_id_offset] : m_index;
const IndexType actual_k_index = k_index % param.actual_rotary_size;
const IndexType sinuous_offset = position * param.sinuous_m_stride + actual_k_index;
LoadPack cos_vec, sin_vec, out_vec;
if (param.cos && param.sin) {
cos_vec = *reinterpret_cast<const LoadPack*>(param.cos + sinuous_offset);
sin_vec = *reinterpret_cast<const LoadPack*>(param.sin + sinuous_offset);
} else {
const IndexType actual_ndim = param.rotary_size / rotary_emb_dim;
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
T val = position
* expf(2.0f * static_cast<float>(((actual_k_index >> 1) + i))
* param.inv_actual_rotary_size * logf(param.theta));
T cos_val = cosf(val);
T sin_val = sinf(val);
cos_vec.elem[i * 2] = cos_val;
cos_vec.elem[i * 2 + 1] = cos_val;
sin_vec.elem[i * 2] = sin_val;
sin_vec.elem[i * 2 + 1] = sin_val;
}
}
#pragma unroll
for (int i = 0; i < PackSize / 2; i++) {
out_vec.elem[i * 2] =
x_vec.elem[i * 2] * cos_vec.elem[i * 2] + x_vec.elem[i * 2 + 1] * sin_vec.elem[i * 2 + 1];
out_vec.elem[i * 2 + 1] = x_vec.elem[i * 2 + 1] * cos_vec.elem[i * 2 + 1]
- x_vec.elem[i * 2] * sin_vec.elem[i * 2];
}
*(reinterpret_cast<LoadPack*>(param.out + out_offset)) = out_vec;
} else {
*(reinterpret_cast<LoadPack*>(param.out + out_offset)) = x_vec;
}
}
}
template<typename T, typename PositionType, typename IndexType, size_t num_dims,
Expand Down Expand Up @@ -1271,14 +1343,14 @@ __global__ void PlaneGradKernel(
if (k_index < param.k0) {
x_vec.elem[0] = *(param.x + x_offset);
x_vec.elem[1] = (param.k0 - k_index > param.rotate_stride)
? static_cast<T>(*(param.x + x_offset + param.rotate_stride))
: -*(param.x + x_offset - param.rotate_stride);
? *(param.x + x_offset + param.rotate_stride)
: static_cast<T>(-*(param.x + x_offset - param.rotate_stride));
out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1];
} else if (k_index < param.k1) {
x_vec.elem[0] = *(param.x + x_offset);
x_vec.elem[1] = (param.k1 - k_index > param.rotate_stride)
? static_cast<T>(*(param.x + x_offset + param.rotate_stride))
: -*(param.x + x_offset - param.rotate_stride);
? *(param.x + x_offset + param.rotate_stride)
: static_cast<T>(-*(param.x + x_offset - param.rotate_stride));
out_val = cos_val * x_vec.elem[0] + sin_val * x_vec.elem[1];
} else {
out_val = *(param.x + x_offset);
Expand Down

0 comments on commit 52874a9

Please sign in to comment.