@@ -3230,6 +3230,65 @@ static struct ggml_tensor * llm_build_ffn(
32303230 return cur;
32313231}
32323232
3233+ enum llm_rope_type {
3234+ LLM_ROPE, // normal (interleaved) RoPE
3235+ LLM_ROPE_NEOX, // GPT-NeoX style RoPE
3236+ LLM_ROPE_GLM, // ChatGLM style RoPE (no visible caller in this diff)
3237+ }; // mapped to the integer ggml rope mode in the switch below
3238+
3239+ // Persimmon: n_rot = n_embd_head/2
3240+ // Other: n_rot = n_embd_head
3241+ static void llm_build_k_shift ( // shared helper: adds graph ops that rotate the first n_rot dims of every layer's K cache by K_shift
3242+ const llama_context & lctx,
3243+ struct ggml_context * ctx,
3244+ struct ggml_cgraph * graph, // ops are appended here via ggml_build_forward_expand
3245+ int64_t n_rot, // number of rotated dimensions per head; must divide n_embd_head
3246+ llm_rope_type type, // which rope variant the calling model uses
3247+ const llm_build_cb & cb) { // naming/offload callback applied to each created tensor
3248+ const auto & model = lctx.model ;
3249+ const auto & kv_self = lctx.kv_self ;
3250+ const auto & cparams = lctx.cparams ;
3251+
3252+ const auto & hparams = model.hparams ;
3253+
3254+ const int64_t n_head = hparams.n_head ;
3255+ const int64_t n_layer = hparams.n_layer ;
3256+ const int64_t n_embd_gqa = hparams.n_embd_gqa ();
3257+ const int64_t n_embd_head = hparams.n_embd_head ();
3258+
3259+ const int64_t n_ctx = lctx.cparams .n_ctx ; // could use the cparams alias above for consistency
3260+
3261+ const float freq_base = cparams.rope_freq_base ;
3262+ const float freq_scale = cparams.rope_freq_scale ;
3263+
3264+ GGML_ASSERT (n_embd_head % n_rot == 0 ); // e.g. persimmon passes n_embd_head/2 (see comment above)
3265+
3266+ struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx, GGML_TYPE_I32, n_ctx); // per-position shift values; only allocated here, presumably filled at eval time — confirm with caller
3267+ cb (K_shift, " K_shift" , -1 );
3268+
3269+ int rope_type = 0 ;
3270+
3271+ switch (type) {
3272+ case LLM_ROPE: rope_type = 0 ; break ;
3273+ case LLM_ROPE_NEOX: rope_type = 2 ; break ;
3274+ case LLM_ROPE_GLM: rope_type = 4 ; break ; // 0/2/4 are the ggml rope mode flags
3275+ }; // note: stray ';' after switch (harmless)
3276+
3277+ for (int il = 0 ; il < n_layer; ++il) {
3278+ struct ggml_tensor * tmp =
3279+ // we rotate only the first n_rot dimensions
3280+ ggml_rope_custom_inplace (ctx,
3281+ ggml_view_3d (ctx, kv_self.k ,
3282+ n_rot, n_head, n_ctx, // NOTE(review): removed llama/falcon code viewed n_head_kv heads here; with GQA (n_head != n_head_kv) n_head looks wrong — verify
3283+ ggml_element_size (kv_self.k )*n_embd_head, // stride between heads
3284+ ggml_element_size (kv_self.k )*n_embd_gqa, // stride between positions
3285+ ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il), // offset of layer il in the K cache
3286+ K_shift, n_rot, rope_type, 0 , freq_base, freq_scale);
3287+ cb (tmp, " K_shifted" , il);
3288+ ggml_build_forward_expand (graph, tmp); // keep the in-place rope in the graph
3289+ }
3290+ }
3291+
32333292static struct ggml_cgraph * llm_build_llama (
32343293 llama_context & lctx,
32353294 const llama_batch & batch,
@@ -3308,21 +3367,7 @@ static struct ggml_cgraph * llm_build_llama(
33083367
33093368 // shift the entire K-cache if needed
33103369 if (do_rope_shift) {
3311- struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
3312- cb (K_shift, " K_shift" , -1 );
3313-
3314- for (int il = 0 ; il < n_layer; ++il) {
3315- struct ggml_tensor * tmp =
3316- ggml_rope_custom_inplace (ctx0,
3317- ggml_view_3d (ctx0, kv_self.k ,
3318- n_embd_head, n_head_kv, n_ctx,
3319- ggml_element_size (kv_self.k )*n_embd_head,
3320- ggml_element_size (kv_self.k )*n_embd_gqa,
3321- ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
3322- K_shift, n_embd_head, 0 , 0 , freq_base, freq_scale);
3323- cb (tmp, " K_shifted" , il);
3324- ggml_build_forward_expand (gf, tmp);
3325- }
3370+ llm_build_k_shift (lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
33263371 }
33273372
33283373 for (int il = 0 ; il < n_layer; ++il) {
@@ -3557,21 +3602,7 @@ static struct ggml_cgraph * llm_build_baichaun(
35573602
35583603 // shift the entire K-cache if needed
35593604 if (do_rope_shift) {
3560- struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
3561- cb (K_shift, " K_shift" , -1 );
3562-
3563- for (int il = 0 ; il < n_layer; ++il) {
3564- struct ggml_tensor * tmp =
3565- ggml_rope_custom_inplace (ctx0,
3566- ggml_view_3d (ctx0, kv_self.k ,
3567- n_embd_head, n_head_kv, n_ctx,
3568- ggml_element_size (kv_self.k )*n_embd_head,
3569- ggml_element_size (kv_self.k )*n_embd_gqa,
3570- ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
3571- K_shift, n_embd_head, 0 , 0 , freq_base, freq_scale);
3572- cb (tmp, " K_shifted" , il);
3573- ggml_build_forward_expand (gf, tmp);
3574- }
3605+ llm_build_k_shift (lctx, ctx0, gf, n_embd_head, LLM_ROPE, cb);
35753606 }
35763607
35773608 for (int il = 0 ; il < n_layer; ++il) {
@@ -3830,21 +3861,7 @@ static struct ggml_cgraph * llm_build_falcon(
38303861
38313862 // shift the entire K-cache if needed
38323863 if (do_rope_shift) {
3833- struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
3834- cb (K_shift, " K_shift" , -1 );
3835-
3836- for (int il = 0 ; il < n_layer; ++il) {
3837- struct ggml_tensor * tmp =
3838- ggml_rope_custom_inplace (ctx0,
3839- ggml_view_3d (ctx0, kv_self.k ,
3840- n_embd_head, n_head_kv, n_ctx,
3841- ggml_element_size (kv_self.k )*n_embd_head,
3842- ggml_element_size (kv_self.k )*n_embd_gqa,
3843- ggml_element_size (kv_self.k )*n_embd_gqa*n_ctx*il),
3844- K_shift, n_embd_head, 2 , 0 , freq_base, freq_scale);
3845- cb (tmp, " K_shifted" , il);
3846- ggml_build_forward_expand (gf, tmp);
3847- }
3864+ llm_build_k_shift (lctx, ctx0, gf, n_embd_head, LLM_ROPE_NEOX, cb);
38483865 }
38493866
38503867 for (int il = 0 ; il < n_layer; ++il) {
@@ -4243,14 +4260,15 @@ static struct ggml_cgraph * llm_build_persimmon(
42434260 GGML_ASSERT (!!kv_self.ctx );
42444261
42454262 const auto & cparams = lctx.cparams ;
4263+
42464264 const int64_t n_embd = hparams.n_embd ;
42474265 const int64_t n_layer = hparams.n_layer ;
42484266 const int64_t n_ctx = cparams.n_ctx ;
42494267 const int64_t n_head_kv = hparams.n_head_kv ;
42504268 const int64_t n_head = hparams.n_head ;
42514269 const int64_t n_embd_head = hparams.n_embd_head ();
42524270 const int64_t n_embd_gqa = hparams.n_embd_gqa ();
4253- const size_t n_rot = n_embd_head / 2 ;
4271+ const int64_t n_rot = n_embd_head / 2 ;
42544272
42554273 const float freq_base = cparams.rope_freq_base ;
42564274 const float freq_scale = cparams.rope_freq_scale ;
@@ -4297,23 +4315,7 @@ static struct ggml_cgraph * llm_build_persimmon(
42974315 cb (KQ_mask, " KQ_mask" , -1 );
42984316
42994317 if (do_rope_shift) {
4300- struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx0, GGML_TYPE_I32, n_ctx);
4301- cb (K_shift, " K_shift" , -1 );
4302-
4303- for (int il = 0 ; il < n_layer; ++il) {
4304- struct ggml_tensor * tmp =
4305- // we rotate only the first n_rot dimensions.
4306- ggml_rope_custom_inplace (ctx0,
4307- ggml_view_3d (ctx0, kv_self.k ,
4308- n_rot, n_head, n_ctx,
4309- ggml_element_size (kv_self.k )*n_embd_gqa,
4310- ggml_element_size (kv_self.k )*n_embd_head,
4311- ggml_element_size (kv_self.k )*(n_embd_head*n_ctx*il)
4312- ),
4313- K_shift, n_rot, 2 , 0 , freq_base, freq_scale);
4314- cb (tmp, " K_shifted" , il);
4315- ggml_build_forward_expand (gf, tmp);
4316- }
4318+ llm_build_k_shift (lctx, ctx0, gf, n_rot, LLM_ROPE_NEOX, cb);
43174319 }
43184320
43194321 for (int il = 0 ; il < n_layer; ++il) {
@@ -5534,7 +5536,7 @@ static struct ggml_cgraph * llama_build_graph(
55345536#ifdef GGML_USE_CUBLAS
55355537 const bool do_offload = true ;
55365538#else
5537- const bool do_offload = false ;
5539+ const bool do_offload = true ; // TODO: set to false after finishing refactoring
55385540#endif
55395541
55405542 if (!do_offload) {
0 commit comments