
Commit a69f9c2

gguf : add GPT-J model architecture
1 parent fd5db67 commit a69f9c2


gguf-py/gguf/gguf.py

Lines changed: 22 additions & 4 deletions
@@ -213,6 +213,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN:    "blk.{bid}.ffn_down",
         MODEL_TENSOR.FFN_UP:      "blk.{bid}.ffn_up",
     },
+    MODEL_ARCH.GPTJ: {
+        MODEL_TENSOR.TOKEN_EMBD:  "token_embd",
+        MODEL_TENSOR.OUTPUT_NORM: "output_norm",
+        MODEL_TENSOR.OUTPUT:      "output",
+        MODEL_TENSOR.ATTN_NORM:   "blk.{bid}.attn_norm",
+        MODEL_TENSOR.ATTN_Q:      "blk.{bid}.attn_q",
+        MODEL_TENSOR.ATTN_K:      "blk.{bid}.attn_k",
+        MODEL_TENSOR.ATTN_V:      "blk.{bid}.attn_v",
+        MODEL_TENSOR.ATTN_OUT:    "blk.{bid}.attn_output",
+        MODEL_TENSOR.FFN_DOWN:    "blk.{bid}.ffn_down",
+        MODEL_TENSOR.FFN_UP:      "blk.{bid}.ffn_up",
+    },
     MODEL_ARCH.GPT2: {
         # TODO
     },
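
For orientation: the table added above only fixes the GGUF-side names. Per-block entries carry a "{bid}" placeholder that is filled in with the block index when a model is written out. Below is a minimal sketch of that expansion; GPTJ_TENSOR_NAMES and build_gptj_names are hypothetical stand-ins for illustration, not part of gguf.py.

# Sketch only: expanding the per-block "{bid}" templates into concrete GGUF tensor names.
# GPTJ_TENSOR_NAMES is a hypothetical copy of the MODEL_ARCH.GPTJ entry shown above.
GPTJ_TENSOR_NAMES = {
    "token_embd":  "token_embd",
    "output_norm": "output_norm",
    "output":      "output",
    "attn_norm":   "blk.{bid}.attn_norm",
    "attn_q":      "blk.{bid}.attn_q",
    "attn_k":      "blk.{bid}.attn_k",
    "attn_v":      "blk.{bid}.attn_v",
    "attn_output": "blk.{bid}.attn_output",
    "ffn_down":    "blk.{bid}.ffn_down",
    "ffn_up":      "blk.{bid}.ffn_up",
}

def build_gptj_names(n_blocks: int) -> list[str]:
    """Expand the {bid} placeholder for every transformer block."""
    names = []
    for template in GPTJ_TENSOR_NAMES.values():
        if "{bid}" in template:
            names.extend(template.format(bid=bid) for bid in range(n_blocks))
        else:
            names.append(template)
    return names

print(build_gptj_names(2)[:5])
# ['token_embd', 'output_norm', 'output', 'blk.0.attn_norm', 'blk.1.attn_norm']
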
@@ -237,7 +249,7 @@ class TensorNameMap:
         # Token embeddings
         MODEL_TENSOR.TOKEN_EMBD: (
             "gpt_neox.embed_in",           # gptneox
-            "transformer.wte",             # gpt2 mpt
+            "transformer.wte",             # gpt2 gpt-j mpt
             "transformer.word_embeddings", # falcon
             "model.embed_tokens",          # llama-hf
             "tok_embeddings",              # llama-pth
@@ -258,14 +270,14 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out", # gptneox
-            "lm_head",   # gpt2 mpt falcon llama-hf baichuan
+            "lm_head",   # gpt2 gpt-j mpt falcon llama-hf baichuan
             "output",    # llama-pth
         ),

         # Output norm
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm", # gptneox
-            "transformer.ln_f",          # gpt2 falcon
+            "transformer.ln_f",          # gpt2 gpt-j falcon
             "model.norm",                # llama-hf baichuan
             "norm",                      # llama-pth
             "embeddings.LayerNorm",      # bert
@@ -282,7 +294,7 @@ class TensorNameMap:
         # Attention norm
         MODEL_TENSOR.ATTN_NORM: (
             "gpt_neox.layers.{bid}.input_layernorm", # gptneox
-            "transformer.h.{bid}.ln_1",              # gpt2
+            "transformer.h.{bid}.ln_1",              # gpt2 gpt-j
             "transformer.blocks.{bid}.norm_1",       # mpt
             "transformer.h.{bid}.input_layernorm",   # falcon7b
             "transformer.h.{bid}.ln_mlp",            # falcon40b
@@ -309,20 +321,23 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_proj",      # llama-hf
             "layers.{bid}.attention.wq",                # llama-pth
             "encoder.layer.{bid}.attention.self.query", # bert
+            "transformer.h.{bid}.attn.q_proj",          # gpt-j
         ),

         # Attention key
         MODEL_TENSOR.ATTN_K: (
             "model.layers.{bid}.self_attn.k_proj",    # llama-hf
             "layers.{bid}.attention.wk",              # llama-pth
             "encoder.layer.{bid}.attention.self.key", # bert
+            "transformer.h.{bid}.attn.k_proj",        # gpt-j
         ),

         # Attention value
         MODEL_TENSOR.ATTN_V: (
             "model.layers.{bid}.self_attn.v_proj",      # llama-hf
             "layers.{bid}.attention.wv",                # llama-pth
             "encoder.layer.{bid}.attention.self.value", # bert
+            "transformer.h.{bid}.attn.v_proj",          # gpt-j
         ),

         # Attention output
@@ -334,6 +349,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.o_proj",        # llama-hf
             "layers.{bid}.attention.wo",                  # llama-pth
             "encoder.layer.{bid}.attention.output.dense", # bert
+            "transformer.h.{bid}.attn.out_proj",          # gpt-j
         ),

         # Rotary embeddings
@@ -361,6 +377,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.up_proj",         # llama-hf
             "layers.{bid}.feed_forward.w3",           # llama-pth
             "encoder.layer.{bid}.intermediate.dense", # bert
+            "transformer.h.{bid}.mlp.fc_in",          # gpt-j
         ),

         # Feed-forward gate
@@ -378,6 +395,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.down_proj", # llama-hf
             "layers.{bid}.feed_forward.w2",     # llama-pth
             "encoder.layer.{bid}.output.dense", # bert
+            "transformer.h.{bid}.mlp.fc_out",   # gpt-j
         ),
     }
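
Taken together, the TensorNameMap additions let a converter rename Hugging Face GPT-J checkpoint tensors to the GGUF names defined above. The sketch below illustrates that correspondence with regular expressions; GPTJ_NAME_MAP and map_gptj_name are hypothetical helpers for illustration only (the real TensorNameMap builds explicit per-block lookup dicts rather than matching patterns), and the weight/bias suffix handling is an assumption.

import re
from typing import Optional

# Sketch only: Hugging Face GPT-J tensor name -> GGUF tensor name.
# The (pattern, replacement) pairs mirror the gpt-j entries added in this commit;
# map_gptj_name is a hypothetical helper, not the gguf.py API.
GPTJ_NAME_MAP = [
    (r"transformer\.wte\.weight",                          "token_embd.weight"),
    (r"transformer\.ln_f\.(weight|bias)",                  r"output_norm.\1"),
    (r"lm_head\.(weight|bias)",                            r"output.\1"),
    (r"transformer\.h\.(\d+)\.ln_1\.(weight|bias)",        r"blk.\1.attn_norm.\2"),
    (r"transformer\.h\.(\d+)\.attn\.q_proj\.weight",       r"blk.\1.attn_q.weight"),
    (r"transformer\.h\.(\d+)\.attn\.k_proj\.weight",       r"blk.\1.attn_k.weight"),
    (r"transformer\.h\.(\d+)\.attn\.v_proj\.weight",       r"blk.\1.attn_v.weight"),
    (r"transformer\.h\.(\d+)\.attn\.out_proj\.weight",     r"blk.\1.attn_output.weight"),
    (r"transformer\.h\.(\d+)\.mlp\.fc_in\.(weight|bias)",  r"blk.\1.ffn_up.\2"),
    (r"transformer\.h\.(\d+)\.mlp\.fc_out\.(weight|bias)", r"blk.\1.ffn_down.\2"),
]

def map_gptj_name(hf_name: str) -> Optional[str]:
    """Return the GGUF name for an HF GPT-J tensor name, or None if unmapped."""
    for pattern, replacement in GPTJ_NAME_MAP:
        if re.fullmatch(pattern, hf_name):
            return re.sub(pattern, replacement, hf_name)
    return None

print(map_gptj_name("transformer.h.3.attn.out_proj.weight"))  # blk.3.attn_output.weight
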
