Commit
cuda acceleration for gptj
LLukas22 authored and AmineDiro committed Aug 15, 2023
1 parent b77388e commit fa60f2f
Showing 1 changed file with 45 additions and 15 deletions.

crates/models/gptj/src/lib.rs
@@ -57,24 +57,49 @@ impl KnownModel for GptJ {
 
         // model-global weights
         let wte = tl.load("transformer.wte.weight")?;
-        let ln_f_g = tl.load("transformer.ln_f.weight")?;
-        let ln_f_b = tl.load("transformer.ln_f.bias")?;
-        let lmh_g = tl.load("lm_head.weight")?;
-        let lmh_b = tl.load("lm_head.bias")?;
+
+        let backend = params.backend(0);
+
+        let ln_f_g = tl.load("transformer.ln_f.weight")?.transfer_to(backend);
+        let ln_f_b = tl.load("transformer.ln_f.bias")?.transfer_to(backend);
+        let lmh_g = tl.load("lm_head.weight")?.transfer_to(backend);
+        let lmh_b = tl.load("lm_head.bias")?.transfer_to(backend);
 
         let mut layers = Vec::new();
         for i in 0..hyperparameters.n_layer {
+            let backend = params.backend(i);
+
             let layer = Layer {
-                ln_1_g: tl.load(&format!("transformer.h.{i}.ln_1.weight"))?,
-                ln_1_b: tl.load(&format!("transformer.h.{i}.ln_1.bias"))?,
-                c_attn_q_proj_w: tl.load(&format!("transformer.h.{i}.attn.q_proj.weight"))?,
-                c_attn_k_proj_w: tl.load(&format!("transformer.h.{i}.attn.k_proj.weight"))?,
-                c_attn_v_proj_w: tl.load(&format!("transformer.h.{i}.attn.v_proj.weight"))?,
-                c_attn_proj_w: tl.load(&format!("transformer.h.{i}.attn.out_proj.weight"))?,
-                c_mlp_fc_w: tl.load(&format!("transformer.h.{i}.mlp.fc_in.weight"))?,
-                c_mlp_fc_b: tl.load(&format!("transformer.h.{i}.mlp.fc_in.bias"))?,
-                c_mlp_proj_w: tl.load(&format!("transformer.h.{i}.mlp.fc_out.weight"))?,
-                c_mlp_proj_b: tl.load(&format!("transformer.h.{i}.mlp.fc_out.bias"))?,
+                ln_1_g: tl
+                    .load(&format!("transformer.h.{i}.ln_1.weight"))?
+                    .transfer_to(backend),
+                ln_1_b: tl
+                    .load(&format!("transformer.h.{i}.ln_1.bias"))?
+                    .transfer_to(backend),
+                c_attn_q_proj_w: tl
+                    .load(&format!("transformer.h.{i}.attn.q_proj.weight"))?
+                    .transfer_to(backend),
+                c_attn_k_proj_w: tl
+                    .load(&format!("transformer.h.{i}.attn.k_proj.weight"))?
+                    .transfer_to(backend),
+                c_attn_v_proj_w: tl
+                    .load(&format!("transformer.h.{i}.attn.v_proj.weight"))?
+                    .transfer_to(backend),
+                c_attn_proj_w: tl
+                    .load(&format!("transformer.h.{i}.attn.out_proj.weight"))?
+                    .transfer_to(backend),
+                c_mlp_fc_w: tl
+                    .load(&format!("transformer.h.{i}.mlp.fc_in.weight"))?
+                    .transfer_to(backend),
+                c_mlp_fc_b: tl
+                    .load(&format!("transformer.h.{i}.mlp.fc_in.bias"))?
+                    .transfer_to(backend),
+                c_mlp_proj_w: tl
+                    .load(&format!("transformer.h.{i}.mlp.fc_out.weight"))?
+                    .transfer_to(backend),
+                c_mlp_proj_b: tl
+                    .load(&format!("transformer.h.{i}.mlp.fc_out.bias"))?
+                    .transfer_to(backend),
             };
 
             layers.push(layer);
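The loading hunk above turns on acceleration through two calls: `transfer_to(backend)` moves a freshly loaded tensor to the chosen backend, and `params.backend(i)` picks that backend per layer (index 0 for the model-global weights). A minimal sketch of how such a per-layer lookup could work, assuming an illustrative `Params` struct with `gpu` and `gpu_layers` fields and a stand-in `Backend` enum; none of these names are the crate's actual API:

// A self-contained sketch, not the crate's real types.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Backend {
    Cpu,
    Gpu,
}

struct Params {
    /// Was the model loaded with acceleration enabled? (illustrative field)
    gpu: bool,
    /// How many layers to offload; `None` means all. (illustrative field)
    gpu_layers: Option<usize>,
}

impl Params {
    /// Should this layer run on the accelerator?
    fn should_offload(&self, layer: usize) -> bool {
        self.gpu && self.gpu_layers.map_or(true, |n| layer < n)
    }

    /// Pick the backend a layer's weights should be transferred to.
    fn backend(&self, layer: usize) -> Backend {
        if self.should_offload(layer) {
            Backend::Gpu
        } else {
            Backend::Cpu
        }
    }
}

fn main() {
    let params = Params { gpu: true, gpu_layers: Some(20) };
    // Layers below the cutoff land on the GPU, the rest stay on the CPU.
    assert_eq!(params.backend(0), Backend::Gpu);
    assert_eq!(params.backend(27), Backend::Cpu);
}

Keeping the decision in one place means the loading code never needs to know how many layers fit on the GPU.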
@@ -126,7 +151,7 @@ impl KnownModel for GptJ {
         } = self.hyperparameters;
 
         let outputs = session.compute(self.context.clone(), input_tokens, |builder| {
-            let ctx0 = builder.ctx0.borrow();
+            let mut ctx0 = builder.ctx0.borrow_mut();
             let (memory_k_size, memory_v_size) = (
                 builder.memory_k.element_size(),
                 builder.memory_v.element_size(),
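The `borrow()` to `borrow_mut()` change is required because the context now gets mutated: the `set_offloading` call added below needs `&mut self`, and a shared `RefCell` borrow only permits `&self` methods. A small illustration with `std::cell::RefCell` and a stand-in `Ctx` type (assuming `builder.ctx0` is a `RefCell`, as the `borrow_mut()` call suggests):

use std::cell::RefCell;

struct Ctx {
    offloading: bool,
}

impl Ctx {
    // A `&mut self` method: only reachable through a mutable borrow.
    fn set_offloading(&mut self, on: bool) {
        self.offloading = on;
    }
}

fn main() {
    let cell = RefCell::new(Ctx { offloading: false });

    // `cell.borrow()` would hand back a shared `Ref<Ctx>`, and calling
    // `set_offloading` on it would fail to compile.
    let mut ctx0 = cell.borrow_mut(); // `RefMut<Ctx>`, derefs to `&mut Ctx`
    ctx0.set_offloading(true);
    assert!(ctx0.offloading);
}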
@@ -137,6 +162,8 @@
 
             let mut gf = ggml::ComputationGraph::new();
             for il in 0..n_layer {
+                ctx0.set_offloading(self.params.should_offload(il));
+
                 // norm
                 let mut current = ctx0.op_norm(&input_layer);
                 current = ctx0.op_add(
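Calling `ctx0.set_offloading(...)` at the top of the layer loop is what makes partial offloading work: ops created while the flag is on are scheduled on the accelerated backend, and ops created after it is cleared fall back to the CPU. A self-contained sketch of the pattern, where `Ctx`, `op`, and the scheduling strings are illustrative stand-ins rather than ggml's API:

// Stand-ins for ggml's context and ops; only the toggle pattern is real.
struct Ctx {
    offloading: bool,
}

impl Ctx {
    fn set_offloading(&mut self, on: bool) {
        self.offloading = on;
    }

    /// Record where an op created under the current flag would run.
    fn op(&mut self, name: &str) {
        let target = if self.offloading { "GPU" } else { "CPU" };
        println!("{name}: {target}");
    }
}

fn main() {
    let n_layer = 4;
    let gpu_layers = 2; // offload only the first two layers
    let mut ctx0 = Ctx { offloading: false };

    for il in 0..n_layer {
        // Toggle once per layer, before building that layer's ops.
        ctx0.set_offloading(il < gpu_layers);
        ctx0.op(&format!("layer {il} norm"));
        ctx0.op(&format!("layer {il} attn"));
    }

    // Final ops (the logits) stay on the CPU.
    ctx0.set_offloading(false);
    ctx0.op("lm_head add");
}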
@@ -263,6 +290,9 @@
 
             // lm_head
             input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer);
+
+            ctx0.set_offloading(false);
+
             input_layer = ctx0.op_add(&input_layer, &self.lmh_b);
 
             (
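Note the `set_offloading(false)` before the final bias add: presumably this pins the last graph ops, and with them the output logits, to the CPU, where the sampling code reads them back.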
