remove repeat from gptj
LLukas22 authored and AmineDiro committed Aug 15, 2023
1 parent 3cb5061 commit cbe1e51
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions crates/models/gptj/src/lib.rs
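The change is the same in every hunk below: a small gain or bias tensor (ln_1_g, ln_1_b, the MLP biases, ln_f_g/ln_f_b, lmh_b) used to be tiled to the activation's shape with ctx0.op_repeat before being combined via op_mul/op_add, and is now passed to those ops directly. That only works if the element-wise ops broadcast the smaller tensor across the larger one; that broadcasting behaviour is an assumption about the ggml bindings at this revision, not something the diff states. A minimal plain-Rust sketch (hypothetical helper names, no ggml) of why the two formulations agree:

// Hypothetical illustration, not the ggml API: adding a bias after explicitly
// repeating it row-by-row equals broadcasting the bias over each row.

/// Tile `bias` into `n_rows` contiguous copies (an explicit repeat), then add.
fn repeat_then_add(x: &[f32], bias: &[f32], n_rows: usize) -> Vec<f32> {
    let repeated: Vec<f32> = (0..n_rows).flat_map(|_| bias.iter().copied()).collect();
    x.iter().zip(repeated.iter()).map(|(a, b)| a + b).collect()
}

/// Broadcast `bias` over each row of `x` directly, no intermediate tensor.
fn broadcast_add(x: &[f32], bias: &[f32]) -> Vec<f32> {
    x.chunks(bias.len())
        .flat_map(|row| row.iter().zip(bias).map(|(a, b)| a + b))
        .collect()
}

fn main() {
    // 2 rows x 3 columns, row-major.
    let x = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    let bias = [0.5_f32, -0.5, 1.0];
    assert_eq!(repeat_then_add(&x, &bias, 2), broadcast_add(&x, &bias));
    println!("repeat-then-add == broadcast-add");
}

If the broadcasting assumption holds, the rewrite also avoids materialising an activation-sized intermediate tensor for every gain and bias in the graph.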
@@ -140,8 +140,8 @@ impl KnownModel for GptJ {
                 // norm
                 let mut current = ctx0.op_norm(&input_layer);
                 current = ctx0.op_add(
-                    &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
-                    &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
+                    &ctx0.op_mul(&current, &self.layers[il].ln_1_g),
+                    &self.layers[il].ln_1_b,
                 );
 
                 let input_sa = current.share();
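This hunk is the per-layer norm whose result is shared into input_sa: normalize input_layer, scale by ln_1_g, shift by ln_1_b. A rough per-row sketch of what the expression computes (my own helper; it assumes op_norm is a plain zero-mean/unit-variance normalization, which the diff does not spell out):

// Hypothetical per-row sketch of `norm(x) * g + b`; the graph builds the same
// expression with ctx0.op_norm / op_mul / op_add on whole tensors.
fn layer_norm_row(x: &[f32], g: &[f32], b: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter()
        .zip(g.iter().zip(b))
        .map(|(v, (gi, bi))| (v - mean) / (var + eps).sqrt() * gi + bi)
        .collect()
}

With broadcasting, ln_1_g and ln_1_b can stay single-row tensors and the op_mul/op_add pair applies them across all tokens without the intermediate op_repeat result.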
@@ -241,19 +241,13 @@ impl KnownModel for GptJ {
                 let ff_in = current.share();
 
                 current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &input_sa);
-                current = ctx0.op_add(
-                    &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
-                    &current,
-                );
+                current = ctx0.op_add(&current, &self.layers[il].c_mlp_fc_b);
 
                 current = ctx0.op_gelu(&current);
 
                 // feed-forward projection
                 current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
-                current = ctx0.op_add(
-                    &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
-                    &current,
-                );
+                current = ctx0.op_add(&current, &self.layers[il].c_mlp_proj_b);
 
                 current = ctx0.op_add(&current, &ff_in);
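This is the GPT-J-style feed-forward branch, which runs off input_sa: fc mat-mul plus bias, GELU, projection mat-mul plus bias, and finally the add with the value captured in ff_in just above. A rough single-vector sketch in plain Rust (naive mat-vec, hypothetical helper names, and a common tanh-style GELU approximation rather than whatever ggml uses internally):

// Hypothetical sketch of fc -> bias -> gelu -> proj -> bias for one vector;
// the biases are added directly, mirroring the broadcasted op_add above.
fn matvec(w: &[Vec<f32>], x: &[f32]) -> Vec<f32> {
    w.iter().map(|row| row.iter().zip(x).map(|(a, b)| a * b).sum()).collect()
}

fn gelu(v: f32) -> f32 {
    // one common tanh approximation of GELU (an assumption, not ggml's exact kernel)
    0.5 * v * (1.0 + (0.7978845608 * (v + 0.044715 * v * v * v)).tanh())
}

fn feed_forward(
    x: &[f32],
    fc_w: &[Vec<f32>], fc_b: &[f32],
    proj_w: &[Vec<f32>], proj_b: &[f32],
) -> Vec<f32> {
    let h: Vec<f32> = matvec(fc_w, x)
        .iter().zip(fc_b).map(|(a, b)| gelu(a + b)).collect();
    matvec(proj_w, &h).iter().zip(proj_b).map(|(a, b)| a + b).collect()
    // the graph then adds this result to ff_in
}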
@@ -263,16 +257,13 @@ impl KnownModel for GptJ {
 
             // norm
             input_layer = ctx0.op_norm(&input_layer);
-            input_layer = ctx0.op_add(
-                &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
-                &ctx0.op_repeat(&self.ln_f_b, &input_layer),
-            );
+            input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b);
 
             let embeddings_tensor: ggml::Tensor = input_layer.share();
 
             // lm_head
             input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer);
-            input_layer = ctx0.op_add(&ctx0.op_repeat(&self.lmh_b, &input_layer), &input_layer);
+            input_layer = ctx0.op_add(&input_layer, &self.lmh_b);
 
             (
                 gf,
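The last hunk covers the final norm and the language-model head: one more norm(x) * ln_f_g + ln_f_b affine (the same pattern as inside the layers, now written on a single line), then a mat-mul with lmh_g and the lmh_b bias added directly; input_layer is shared into embeddings_tensor first, so the pre-logits hidden state remains available. A rough sketch of that last projection for a single hidden vector (hypothetical helper, plain Rust, not the ggml call sequence):

// Hypothetical sketch of the head after the final norm: project the hidden
// state to vocabulary logits and add the bias without an explicit repeat.
fn lm_head(hidden: &[f32], lmh_g: &[Vec<f32>], lmh_b: &[f32]) -> Vec<f32> {
    lmh_g
        .iter()
        .zip(lmh_b)
        .map(|(row, b)| row.iter().zip(hidden).map(|(w, h)| w * h).sum::<f32>() + b)
        .collect()
}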
