GPT-J inference support (#1670)
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
RezaYazdaniAminabadi and jeffra authored Jan 8, 2022
1 parent 7e857aa commit 289c3f9
Showing 10 changed files with 587 additions and 107 deletions.
129 changes: 129 additions & 0 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -0,0 +1,129 @@
#include "custom_cuda_layers.h"

#include <cuda_profiler_api.h>

namespace cg = cooperative_groups;

__global__ void apply_rotary_pos_emb(float* mixed_query,
                                     float* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int id = threadIdx.x;
    int gid = id >> 5;     // warp index within the block
    int lane = id & 0x1f;  // lane index within the warp

    // One warp handles one (batch, seq, head) triple.
    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;

    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;

    if (head_id < total_count) {
        while (lane < rotary_dim) {
            // Each adjacent lane pair (2i, 2i+1) rotates one (q, k) pair by
            // the angle seq_id * 10000^(-2i / rotary_dim).
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = mixed_query[offset + lane];
            float k = key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            // Exchange with the neighboring lane to fetch the paired element.
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

            mixed_query[offset + lane] = q;
            key_layer[offset + lane] = k;

            lane += WARP_SIZE;
        }
    }
}
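
For reference, the rotation the kernel applies to each adjacent lane pair (2i, 2i+1) is the interleaved rotary position embedding used by GPT-J. Writing p = seq_id and theta_i = 10000^(-2i/rotary_dim), the update is

$$
\begin{pmatrix} q'_{2i} \\ q'_{2i+1} \end{pmatrix}
=
\begin{pmatrix} \cos(p\theta_i) & -\sin(p\theta_i) \\ \sin(p\theta_i) & \cos(p\theta_i) \end{pmatrix}
\begin{pmatrix} q_{2i} \\ q_{2i+1} \end{pmatrix},
$$

and likewise for k. The shfl_xor with mask 1 hands each lane its partner's value, and the sign flip on odd lanes is what produces the -sin term in the even-lane result.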

__global__ void apply_rotary_pos_emb(__half* mixed_query,
                                     __half* key_layer,
                                     unsigned rotary_dim,
                                     unsigned seq_len,
                                     unsigned seq_offset,
                                     unsigned num_heads,
                                     unsigned head_size,
                                     unsigned total_count)
{
#if __CUDA_ARCH__ >= 700
    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);

    int id = threadIdx.x;
    int gid = id >> 5;
    int lane = id & 0x1f;

    // Same warp-per-head scheme as the float kernel above (and as assumed by
    // the launcher below); values are promoted to float for the trigonometry
    // and stored back as __half.
    unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid;
    unsigned offset = head_id * head_size;

    unsigned seq_id = (head_id / num_heads) % seq_len + seq_offset;

    if (head_id < total_count) {
        while (lane < rotary_dim) {
            float inv_freq = (float)((lane / 2) * 2) / (float)rotary_dim;
            inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
            float q = (float)mixed_query[offset + lane];
            float k = (float)key_layer[offset + lane];
            float rotary_sign = (lane % 2 == 1 ? -1.0 : 1.0);
            float q_rot = (q * rotary_sign);
            float k_rot = (k * rotary_sign);
            q_rot = g.shfl_xor(q_rot, 1);
            k_rot = g.shfl_xor(k_rot, 1);
            q = q * cosf(inv_freq) + q_rot * sinf(inv_freq);
            k = k * cosf(inv_freq) + k_rot * sinf(inv_freq);

            mixed_query[offset + lane] = (__half)q;
            key_layer[offset + lane] = (__half)k;

            lane += WARP_SIZE;
        }
    }
#endif
}

template <typename T>
void launch_apply_rotary_pos_emb(T* mixed_query,
                                 T* key_layer,
                                 unsigned head_size,
                                 unsigned seq_len,
                                 unsigned rotary_dim,
                                 unsigned offset,
                                 unsigned num_heads,
                                 unsigned batch,
                                 cudaStream_t stream)
{
    // One warp per (batch, seq, head) triple; each 1024-thread block
    // carries MAX_WARP_NUM warps.
    int total_count = batch * num_heads * seq_len;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / MAX_WARP_NUM + 1);

    apply_rotary_pos_emb<<<grid_dims, block_dims, 0, stream>>>(
        mixed_query, key_layer, rotary_dim, seq_len, offset, num_heads, head_size, total_count);
}

template void launch_apply_rotary_pos_emb<float>(float*,
                                                 float*,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 unsigned,
                                                 cudaStream_t);
template void launch_apply_rotary_pos_emb<__half>(__half*,
                                                  __half*,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  unsigned,
                                                  cudaStream_t);
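
As a usage sketch (not part of this commit): from seq_id = (head_id / num_heads) % seq_len, consecutive head_ids step through heads first and positions second, so the launcher expects Q/K in a [batch, seq_len, num_heads, head_size] layout. A minimal host-side call with illustrative GPT-J-6B-like dimensions might look like this; the buffer handling and sizes are assumptions for the example only:

// Hypothetical driver for launch_apply_rotary_pos_emb; dimensions and
// buffer names are illustrative, not part of this commit.
#include <cuda_runtime.h>

void run_rotary_example(cudaStream_t stream)
{
    const unsigned batch = 1, seq_len = 128, num_heads = 16;
    const unsigned head_size = 256, rotary_dim = 64, seq_offset = 0;
    const size_t elems = (size_t)batch * seq_len * num_heads * head_size;

    float *query = nullptr, *key = nullptr;
    cudaMalloc(&query, elems * sizeof(float));
    cudaMalloc(&key, elems * sizeof(float));
    // ... copy in the projected Q/K activations, laid out as
    //     [batch, seq_len, num_heads, head_size] ...

    launch_apply_rotary_pos_emb<float>(
        query, key, head_size, seq_len, rotary_dim, seq_offset, num_heads, batch, stream);

    cudaStreamSynchronize(stream);
    cudaFree(query);
    cudaFree(key);
}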
104 changes: 104 additions & 0 deletions csrc/transformer/inference/csrc/gelu.cu
100755 → 100644
@@ -264,3 +264,107 @@ template void launch_bias_residual<__half>(__half*,
int,
int,
cudaStream_t);

__global__ void gptj_residual_add(float* input,
                                  float* output,
                                  float* attn,
                                  float* bias,
                                  int total_count,
                                  int intermediate_size)
{
    // Vectorized as float4: each thread sums one 4-element chunk of the
    // three residual streams plus the (broadcast) bias.
    float4* input_cast = reinterpret_cast<float4*>(input);
    float4* output_cast = reinterpret_cast<float4*>(output);
    float4* attn_cast = reinterpret_cast<float4*>(attn);
    float4* bias_cast = reinterpret_cast<float4*>(bias);
    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 data = input_cast[offset];
        float4 out = output_cast[offset];
        float4 res_vec = attn_cast[offset];
        float4 bias_data = bias_cast[offset % intermediate_size];

        data.x += (out.x + res_vec.x + bias_data.x);
        data.y += (out.y + res_vec.y + bias_data.y);
        data.z += (out.z + res_vec.z + bias_data.z);
        data.w += (out.w + res_vec.w + bias_data.w);

        output_cast[offset] = data;
    }
}

__global__ void gptj_residual_add(__half* input,
                                  __half* output,
                                  __half* attn,
                                  __half* bias,
                                  int total_count,
                                  int intermediate_size)
{
#if __CUDA_ARCH__ >= 700
    // Vectorized as float2 (= four __half values per thread); arithmetic is
    // done in float and rounded back to half on the way out.
    float2* input_cast = reinterpret_cast<float2*>(input);
    float2* output_cast = reinterpret_cast<float2*>(output);
    float2* attn_cast = reinterpret_cast<float2*>(attn);
    float2* bias_cast = reinterpret_cast<float2*>(bias);

    int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 vals_vec = input_cast[offset];
        float2 out_vec = output_cast[offset];
        float2 res_vec = attn_cast[offset];
        float2 bias_vec = bias_cast[offset % intermediate_size];

        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
        __half2* out_half = reinterpret_cast<__half2*>(&out_vec);
        __half2* res_half = reinterpret_cast<__half2*>(&res_vec);
        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);

        float2 low_data = __half22float2(vals_half[0]);
        float2 high_data = __half22float2(vals_half[1]);

        float2 low_out = __half22float2(out_half[0]);
        float2 high_out = __half22float2(out_half[1]);

        float2 low_res = __half22float2(res_half[0]);
        float2 high_res = __half22float2(res_half[1]);

        float2 low_bias = __half22float2(bias_half[0]);
        float2 high_bias = __half22float2(bias_half[1]);

        low_data.x += (low_out.x + low_res.x + low_bias.x);
        low_data.y += (low_out.y + low_res.y + low_bias.y);
        high_data.x += (high_out.x + high_res.x + high_bias.x);
        high_data.y += (high_out.y + high_res.y + high_bias.y);

        vals_half[0] = __float22half2_rn(low_data);
        vals_half[1] = __float22half2_rn(high_data);

        output_cast[offset] = vals_vec;
    }
#endif
}

template <typename T>
void launch_gptj_residual_add(T* input,
                              T* output,
                              T* attn,
                              T* bias,
                              int hidden_dim,
                              int batch,
                              cudaStream_t stream)
{
    // Both kernels process four elements per thread, so the element count
    // and the bias period are divided by 4.
    int total_count = batch * hidden_dim / 4;
    dim3 block_dims(1024);
    dim3 grid_dims((total_count - 1) / 1024 + 1);

    gptj_residual_add<<<grid_dims, block_dims, 0, stream>>>(
        input, output, attn, bias, total_count, hidden_dim / 4);
}

template void
launch_gptj_residual_add<float>(float*, float*, float*, float*, int, int, cudaStream_t);
template void
launch_gptj_residual_add<__half>(__half*, __half*, __half*, __half*, int, int, cudaStream_t);
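
Element-wise, both kernels compute output[i] = input[i] + output[i] + attn[i] + bias[i % hidden_dim], consistent with GPT-J's parallel residual, where the attention and MLP branches are summed back onto the same hidden stream. A plain scalar reference, useful for checking the vectorized kernels against (the function name is illustrative, not part of the commit):

// Scalar CPU reference for gptj_residual_add, for verification only.
// Computes: output[i] = input[i] + output[i] + attn[i] + bias[i % hidden_dim]
void gptj_residual_add_ref(const float* input,
                           float* output,
                           const float* attn,
                           const float* bias,
                           int hidden_dim,
                           int batch)
{
    for (int b = 0; b < batch; ++b) {
        for (int h = 0; h < hidden_dim; ++h) {
            const int i = b * hidden_dim + h;
            output[i] = input[i] + output[i] + attn[i] + bias[h];
        }
    }
}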