@@ -83,6 +83,7 @@ const char * llm_type_name(llm_type type) {
8383 case LLM_TYPE_32B: return "32B";
8484 case LLM_TYPE_34B: return "34B";
8585 case LLM_TYPE_35B: return "35B";
86+ case LLM_TYPE_36B: return "36B";
8687 case LLM_TYPE_40B: return "40B";
8788 case LLM_TYPE_65B: return "65B";
8889 case LLM_TYPE_70B: return "70B";
@@ -1288,6 +1289,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
12881289 default: type = LLM_TYPE_UNKNOWN;
12891290 }
12901291 } break;
1292+ case LLM_ARCH_SEED_OSS:
1293+ {
1294+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1295+ switch (hparams.n_layer) {
1296+ case 64: type = LLM_TYPE_36B; break;
1297+ default: type = LLM_TYPE_UNKNOWN;
1298+ }
1299+ } break;
12911300 case LLM_ARCH_OLMOE:
12921301 {
12931302 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3967,6 +3976,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
39673976 layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
39683977 }
39693978 } break;
3979+ case LLM_ARCH_SEED_OSS:
3980+ {
3981+ const uint32_t head_dim = hparams.n_embd_head_k;
3982+ const int64_t n_qo_dim = n_head * head_dim;
3983+ const int64_t n_kv_dim = n_head_kv * head_dim;
3984+
3985+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
3986+
3987+ // output
3988+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3989+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
3990+ // if output is NULL, init from the input tok embed
3991+ if (output == NULL) {
3992+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
3993+ }
3994+
3995+ for (int i = 0; i < n_layer; ++i) {
3996+ auto & layer = layers[i];
3997+
3998+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0);
3999+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0);
4000+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0);
4001+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);
4002+
4003+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED);
4004+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
4005+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
4006+
4007+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4008+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
4009+
4010+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4011+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4012+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4013+ }
4014+ } break;
4015+
39704016 case LLM_ARCH_OLMOE:
39714017 {
39724018 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -17934,6 +17980,137 @@ struct llm_build_lfm2 : public llm_graph_context {
1793417980 }
1793517981};
1793617982
17983+ struct llm_build_seed_oss : public llm_graph_context {
17984+ llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
17985+ const int64_t n_embd_head = hparams.n_embd_head_v;
17986+
17987+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
17988+ GGML_ASSERT(n_embd_head == hparams.n_rot);
17989+
17990+ ggml_tensor * cur;
17991+ ggml_tensor * inpL;
17992+
17993+ inpL = build_inp_embd(model.tok_embd);
17994+
17995+ // inp_pos - contains the positions
17996+ ggml_tensor * inp_pos = build_inp_pos();
17997+
17998+ auto * inp_attn = build_attn_inp_kv();
17999+
18000+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
18001+
18002+ ggml_tensor * inp_out_ids = build_inp_out_ids();
18003+
18004+ for (int il = 0; il < n_layer; ++il) {
18005+ ggml_tensor * inpSA = inpL;
18006+
18007+ // norm
18008+ cur = build_norm(inpL,
18009+ model.layers[il].attn_norm, NULL,
18010+ LLM_NORM_RMS, il);
18011+ cb(cur, "attn_norm", il);
18012+
18013+ // self-attention
18014+ {
18015+ // compute Q and K and RoPE them
18016+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
18017+ cb(Qcur, "Qcur", il);
18018+ if (model.layers[il].bq) {
18019+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
18020+ cb(Qcur, "Qcur", il);
18021+ }
18022+
18023+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
18024+ cb(Kcur, "Kcur", il);
18025+ if (model.layers[il].bk) {
18026+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
18027+ cb(Kcur, "Kcur", il);
18028+ }
18029+
18030+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
18031+ cb(Vcur, "Vcur", il);
18032+ if (model.layers[il].bv) {
18033+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
18034+ cb(Vcur, "Vcur", il);
18035+ }
18036+
18037+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
18038+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
18039+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
18040+
18041+ Qcur = ggml_rope_ext(
18042+ ctx0, Qcur, inp_pos, nullptr,
18043+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
18044+ ext_factor, attn_factor, beta_fast, beta_slow
18045+ );
18046+
18047+ Kcur = ggml_rope_ext(
18048+ ctx0, Kcur, inp_pos, nullptr,
18049+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
18050+ ext_factor, attn_factor, beta_fast, beta_slow
18051+ );
18052+
18053+ cb(Qcur, "Qcur", il);
18054+ cb(Kcur, "Kcur", il);
18055+ cb(Vcur, "Vcur", il);
18056+
18057+ cur = build_attn(inp_attn,
18058+ model.layers[il].wo, model.layers[il].bo,
18059+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
18060+ cb(cur, "attn_out", il);
18061+ }
18062+
18063+ if (il == n_layer - 1 && inp_out_ids) {
18064+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
18065+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
18066+ }
18067+
18068+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
18069+ cb(ffn_inp, "ffn_inp", il);
18070+
18071+ // feed-forward network
18072+ cur = build_norm(ffn_inp,
18073+ model.layers[il].attn_post_norm, NULL,
18074+ LLM_NORM_RMS, il);
18075+ cb(cur, "attn_post_norm", il);
18076+
18077+ cur = build_ffn(cur,
18078+ model.layers[il].ffn_up, NULL, NULL,
18079+ model.layers[il].ffn_gate, NULL, NULL,
18080+ model.layers[il].ffn_down, NULL, NULL,
18081+ NULL,
18082+ LLM_FFN_SILU, LLM_FFN_PAR, il);
18083+ cb(cur, "ffn_out", il);
18084+
18085+ cur = ggml_add(ctx0, cur, ffn_inp);
18086+ cb(cur, "ffn_out", il);
18087+
18088+ cur = build_cvec(cur, il);
18089+ cb(cur, "l_out", il);
18090+
18091+ // input for next layer
18092+ inpL = cur;
18093+ }
18094+
18095+ cur = inpL;
18096+
18097+ cur = build_norm(cur,
18098+ model.output_norm, NULL,
18099+ LLM_NORM_RMS, -1);
18100+
18101+ cb(cur, "result_norm", -1);
18102+ res->t_embd = cur;
18103+
18104+ // lm_head
18105+ cur = build_lora_mm(model.output, cur);
18106+
18107+ cb(cur, "result_output", -1);
18108+ res->t_logits = cur;
18109+
18110+ ggml_build_forward_expand(gf, cur);
18111+ }
18112+ };
18113+
1793718114template <bool iswa>
1793818115struct llm_build_smallthinker : public llm_graph_context{
1793918116 llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
@@ -18472,6 +18649,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1847218649 {
1847318650 llm = std::make_unique<llm_build_bailingmoe>(*this, params);
1847418651 } break;
18652+ case LLM_ARCH_SEED_OSS:
18653+ {
18654+ llm = std::make_unique<llm_build_seed_oss>(*this, params);
18655+ } break;
1847518656 case LLM_ARCH_DOTS1:
1847618657 {
1847718658 llm = std::make_unique<llm_build_dots1>(*this, params);
@@ -18530,6 +18711,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1853018711 return llm->res->get_gf();
1853118712}
1853218713
18714+
1853318715//
1853418716// interface implementation
1853518717//
@@ -18724,6 +18906,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1872418906 case LLM_ARCH_LFM2:
1872518907 case LLM_ARCH_SMALLTHINKER:
1872618908 case LLM_ARCH_GLM4_MOE:
18909+ case LLM_ARCH_SEED_OSS:
1872718910 return LLAMA_ROPE_TYPE_NEOX;
1872818911
1872918912 case LLM_ARCH_QWEN2VL:
0 commit comments