[Bug fix] Add rope_theta for llama config (microsoft#4480)
* Add rope_theta for llama config

* Add rope_theta to bias_add_transform_0213

* Fix CI problems

* Add rope_theta to linear layer

---------

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Co-authored-by: Lev Kurilenko <113481193+lekurile@users.noreply.github.com>
3 people authored and amaurya committed Feb 17, 2024
1 parent 978112c commit 680eb52
Showing 8 changed files with 61 additions and 27 deletions.
20 changes: 16 additions & 4 deletions csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -32,6 +32,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
unsigned num_heads,
unsigned head_size,
unsigned total_count,
float rope_theta,
int max_out_tokens)
{
constexpr int T_per_thread = granularity / sizeof(T);
@@ -61,7 +62,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
const int neuron_idx = base_neuron_idx + i;
if (neuron_idx < rotary_dim) {
float inv_freq = (float)((neuron_idx % half_dim) * 2) / (float)rotary_dim;
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_idx;
inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_idx;

float rotary_sign = (neuron_idx > (half_dim - 1) ? -1.0 : 1.0);
float q_rot = conversion::to<float>(q[i]) * rotary_sign;
@@ -95,6 +96,7 @@ __global__ void apply_rotary_pos_half(T* mixed_query,
num_heads, \
head_size, \
total_count, \
rope_theta, \
max_out_tokens);

#ifdef __HIP_PLATFORM_HCC__
@@ -136,6 +138,7 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned offset,
unsigned num_heads,
unsigned batch,
float rope_theta,
cudaStream_t stream,
int max_out_tokens)
{
@@ -176,9 +179,18 @@
}
}

#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \
template void launch_apply_rotary_pos_emb<T>( \
T*, T*, unsigned, unsigned, unsigned, unsigned, unsigned, unsigned, cudaStream_t, int);
#define INSTANTIATE_LAUNCH_ROTARY_POS_EMB(T) \
template void launch_apply_rotary_pos_emb<T>(T*, \
T*, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
unsigned, \
float, \
cudaStream_t, \
int);

INSTANTIATE_LAUNCH_ROTARY_POS_EMB(float);
#ifdef BF16_AVAILABLE
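
For reference, the rotate-half path this kernel implements reduces to the standard RoPE formulas with a configurable base. A minimal PyTorch sketch of the same math (illustrative only, not part of the commit; the tensor layout and function name are assumptions):

import torch

def rotate_half_rope(q, k, rotary_dim, offset, rope_theta=10000.0):
    # q, k: [seq, heads, head_size]; only the first rotary_dim channels are rotated.
    seq_len, half = q.shape[0], rotary_dim // 2
    # Mirrors the kernel: exponent = (pair_index * 2) / rotary_dim, inv_freq = theta**-exponent
    exponent = torch.arange(half, dtype=torch.float32) * 2.0 / rotary_dim
    inv_freq = 1.0 / (rope_theta ** exponent)                       # [half]
    positions = torch.arange(offset, offset + seq_len, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]                 # [seq, half]
    cos, sin = angles.cos()[:, None, :], angles.sin()[:, None, :]   # broadcast over heads

    def rotate(x):
        x1, x2 = x[..., :half], x[..., half:rotary_dim]
        rotated = torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
        return torch.cat((rotated, x[..., rotary_dim:]), dim=-1)

    return rotate(q), rotate(k)

Before this commit the base was fixed at 10000.0 inside the kernel; the caller now supplies rope_theta.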
21 changes: 15 additions & 6 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -445,7 +445,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
bool no_masking,
unsigned layer_id,
unsigned num_layers,
at::Tensor& alibi)
at::Tensor& alibi,
float rope_theta)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
@@ -493,7 +494,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
rotate_every_two,
InferenceContext::Instance().GetCurrentStream(),
3,
InferenceContext::Instance().GetMaxTokenLength());
InferenceContext::Instance().GetMaxTokenLength(),
rope_theta);
if (rotary_dim > 0 && rotate_half)
launch_apply_rotary_pos_emb(query_cont,
kv_cache,
@@ -503,6 +505,7 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
(is_prompt ? 0 : soft_len - 1),
heads,
bsz,
rope_theta,
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLength());

@@ -1100,7 +1103,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
bool add_bias,
bool do_flash_attn,
int num_heads,
bool transposed_mode)
bool transposed_mode,
float rope_theta)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
@@ -1174,7 +1178,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
false,
InferenceContext::Instance().GetCurrentStream(),
3,
input.size(1));
input.size(1),
rope_theta);
return at::from_blob(final_output,
{3, input.size(0), num_heads, input.size(1), padded_head_size},
options);
@@ -1200,7 +1205,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
false,
InferenceContext::Instance().GetCurrentStream(),
3,
input.size(1));
input.size(1),
rope_theta);
return at::from_blob(
final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options);
// return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads,
@@ -1847,7 +1853,8 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
unsigned rotary_dim,
unsigned offset,
unsigned num_heads,
bool rotate_half)
bool rotate_half,
float rope_theta)
{
auto query_cont = mixed_query.contiguous();
auto key_cont = key_layer.contiguous();
@@ -1865,6 +1872,7 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
offset,
num_heads,
bsz,
rope_theta,
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLength());
else
@@ -1876,6 +1884,7 @@ std::vector<at::Tensor> apply_rotary_pos_emb(at::Tensor& mixed_query,
offset,
num_heads,
bsz,
rope_theta,
InferenceContext::Instance().GetCurrentStream(),
InferenceContext::Instance().GetMaxTokenLength());
return {query_cont, key_cont};
25 changes: 16 additions & 9 deletions csrc/transformer/inference/csrc/transform.cu
@@ -32,7 +32,8 @@ __global__ void bias_add_transform_0213(float* output,
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
int d0_stride = hidden_dim * seq_length;
int d1_stride = hidden_dim;
@@ -70,7 +71,7 @@
#pragma unroll
for (int o = 0; o < 2; o++) {
float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id;
q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_f[o].x * cosf(inv_freq));
q_f[o].y = (q_f[o].x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq));
}
@@ -100,7 +101,8 @@ __global__ void bias_add_transform_0213(T* output, // q
bool rotate_half,
bool rotate_every_two,
int head_ext,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
using T2 =
typename std::conditional<std::is_same<T, __half>::value, __half2, __nv_bfloat162>::type;
@@ -147,7 +149,7 @@ __global__ void bias_add_transform_0213(T* output, // q
#pragma unroll
for (int o = 0; o < 4; o++) {
float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3);
inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id;
inv_freq = 1.0 / powf(rope_theta, inv_freq) * (float)seq_id;
float q_data[2];
q_data[0] = conversion::to<float>(q_h[o].x);
q_data[1] = conversion::to<float>(q_h[o].y);
@@ -181,7 +183,8 @@ void launch_bias_add_transform_0213<float>(float* output,
bool rotate_every_two,
cudaStream_t stream,
int trans_count,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
hidden_dim >>= 2;
int head_ext = (hidden_dim - 1) / MAX_THREADS + 1;
@@ -204,7 +207,8 @@ void launch_bias_add_transform_0213<float>(float* output,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
max_out_tokens,
rope_theta);
}

template <typename T>
@@ -225,7 +229,8 @@ void launch_bias_add_transform_0213(T* output,
bool rotate_every_two,
cudaStream_t stream,
int trans_count,
int max_out_tokens)
int max_out_tokens,
float rope_theta)
{
hidden_dim >>= 3;
int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1;
@@ -247,7 +252,8 @@ void launch_bias_add_transform_0213(T* output,
rotate_half,
rotate_every_two,
head_ext,
max_out_tokens);
max_out_tokens,
rope_theta);
}

#define INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(T) \
@@ -268,7 +274,8 @@ void launch_bias_add_transform_0213(T* output,
bool, \
cudaStream_t, \
int, \
int)
int, \
float)

#ifdef BF16_AVAILABLE
INSTANTIATE_LAUNCH_BIAS_ADD_TRANSFORM_0213(__nv_bfloat16);
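
The same substitution is made in the fused bias-add/transform kernels, for both the interleaved (rotate_every_two) and rotate-half variants. To see why a hardcoded base breaks models trained with a different one (CodeLlama, for instance, ships a much larger rope_theta in its config), compare the per-channel frequencies; a quick illustration with assumed values:

# rotary_dim = 128 is assumed here; read the real value from the model config.
rotary_dim = 128
for theta in (10_000.0, 1_000_000.0):
    freqs = [1.0 / theta ** (2 * i / rotary_dim) for i in (0, 16, 32, 63)]
    print(theta, [f"{f:.3e}" for f in freqs])
# The larger base makes high-index channels rotate far more slowly, so decoding a
# checkpoint trained with theta=1e6 using a hardcoded 10000 scrambles its positional signal.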
4 changes: 3 additions & 1 deletion csrc/transformer/inference/includes/inference_cuda_layers.h
@@ -168,6 +168,7 @@ void launch_apply_rotary_pos_emb(T* mixed_query,
unsigned offset,
unsigned num_heads,
unsigned batch,
float rope_theta,
cudaStream_t stream,
int max_out_tokens);

@@ -207,7 +208,8 @@ void launch_bias_add_transform_0213(T* outputs,
bool rotate_every_two,
cudaStream_t stream,
int trans_count,
int max_out_tokens);
int max_out_tokens,
float rope_theta);
template <typename T>
void pad_data(T* padded_output,
T* output,
1 change: 1 addition & 0 deletions deepspeed/module_inject/containers/llama.py
@@ -34,6 +34,7 @@ def create_module(self, config=None):
_config.rotate_half = True
_config.rotate_every_two = False
_config.rotary_dim = self.hidden_size // self.num_attention_heads
_config.rope_theta = self.policy.client_module.self_attn.rope_theta
self.module = DeepSpeedGPTInference(_config, mp_group=self.mp_group)

return self.module
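
The container pulls the value straight off the injected attention module, which in Hugging Face transformers mirrors LlamaConfig.rope_theta (default 10000.0). A sketch of where the number originates (checkpoint name is illustrative; requires a transformers version whose Llama config exposes rope_theta):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("codellama/CodeLlama-7b-hf")
print(cfg.rope_theta)  # the container copies this (via self_attn.rope_theta) into _config.rope_theta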
4 changes: 3 additions & 1 deletion deepspeed/ops/transformer/inference/config.py
@@ -79,7 +79,8 @@ def __init__(self,
transposed_mode=False,
use_triton=False,
triton_autotune=False,
num_kv=-1):
num_kv=-1,
rope_theta=10000):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@@ -114,6 +115,7 @@ def __init__(self,
self.use_triton = use_triton
self.triton_autotune = triton_autotune
self.num_kv = num_kv
self.rope_theta = rope_theta

@classmethod
def from_dict(cls, json_object):
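
The new field defaults to 10000, so models that never set rope_theta behave exactly as before. A minimal sketch of setting it explicitly (only arguments visible in this diff are used; everything else keeps its default):

from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig

config = DeepSpeedInferenceConfig(hidden_size=4096,
                                  heads=32,
                                  num_hidden_layers=32,
                                  rope_theta=1_000_000)  # match the checkpoint's base
assert config.rope_theta == 1_000_000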
4 changes: 2 additions & 2 deletions deepspeed/ops/transformer/inference/op_binding/linear.py
@@ -31,7 +31,7 @@ def __init__(self, config: DeepSpeedInferenceConfig):
except AttributeError:
self.linear_func = self.linear_fallback

def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose):
def linear_fallback(self, input, weight, bias, add_bias, do_flash_attn, num_heads, transpose, rope_theta):
raise NotImplementedError

def forward(self,
@@ -44,7 +44,7 @@ def forward(self,
external_cache: bool = None,
num_layers: int = None):
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
self.config.transposed_mode)
self.config.transposed_mode, self.config.rope_theta)
return qkv_out

@staticmethod
9 changes: 5 additions & 4 deletions deepspeed/ops/transformer/inference/op_binding/softmax_context.py
@@ -23,9 +23,9 @@ def __init__(self, config: DeepSpeedInferenceConfig):
except AttributeError:
self.softmax_context_func = self.softmax_context_fallback

def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, roteate_every_two, heads,
norm_factor, triangular_masking, local_attention, window_size, no_masking, layer_id,
num_layers, alibi):
def softmax_context_fallback(self, query_key_value, attn_mask, rotary_dim, rotate_half, rotate_every_two, heads,
num_kv, norm_factor, triangular_masking, local_attention, window_size, no_masking,
layer_id, num_layers, alibi, rope_theta):
raise NotImplementedError

def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads: int, num_kv: int,
Expand All @@ -41,6 +41,7 @@ def forward(self, query_key_value: torch.Tensor, attn_mask: torch.Tensor, heads:
output = self.softmax_context_func(query_key_value, attn_mask, self.config.rotary_dim, self.config.rotate_half,
self.config.rotate_every_two, heads, num_kv, norm_factor,
self.config.triangular_masking, self.config.local_attention,
self.config.window_size, no_masking, layer_id, num_layers, alibi)
self.config.window_size, no_masking, layer_id, num_layers, alibi,
self.config.rope_theta)

return output
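
With the op bindings updated, the value flows from the checkpoint config through the llama container into the fused kernels without user intervention. An end-to-end sketch of the scenario this commit fixes (checkpoint name is illustrative; exact init_inference arguments depend on the DeepSpeed version, and a GPU build of the inference kernels is required):

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf",
                                             torch_dtype=torch.float16)
# Kernel injection swaps the attention blocks for DeepSpeedGPTInference; the llama
# container now copies self_attn.rope_theta into the inference config, so the RoPE
# kernels use the checkpoint's base instead of a hardcoded 10000.
engine = deepspeed.init_inference(model, dtype=torch.float16, replace_with_kernel_inject=True)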
