remove repeat from gptneox
LLukas22 authored and AmineDiro committed Aug 15, 2023
1 parent cbe1e51 commit 519705b
Showing 1 changed file with 8 additions and 20 deletions.
28 changes: 8 additions & 20 deletions crates/models/gptneox/src/lib.rs
@@ -159,16 +159,13 @@ impl KnownModel for GptNeoX {
             // self-attention
             let mut current = ctx0.op_norm(&input_layer);
             current = ctx0.op_add(
-                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
-                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
+                &ctx0.op_mul(&current, &self.layers[il].ln_1_g),
+                &self.layers[il].ln_1_b,
             );
 
             // self-attention compute QKV
             current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, &current);
-            current = ctx0.op_add(
-                &ctx0.op_repeat(&self.layers[il].c_attn_attn_b, &current),
-                &current,
-            );
+            current = ctx0.op_add(&current, &self.layers[il].c_attn_attn_b);
 
             let nb = current.get_nb()[1];
             let f32_size = std::mem::size_of::<f32>();
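Every change in this diff follows the same pattern: it relies on ggml's implicit broadcasting, where op_add and op_mul repeat the second operand over the rows of the first, so a per-channel gain or bias no longer needs to be materialized at the activation's shape with op_repeat. Note that the smaller, broadcastable tensor moves into the second argument position. A minimal sketch of the old and new forms, assuming the ggml crate's Context, Type, and new_tensor_1d/new_tensor_2d constructors (names and sizes here are illustrative, not part of the diff):

    use ggml::{Context, Type};

    fn bias_add_sketch(ctx0: &Context) {
        let (n_embd, n_tokens) = (8, 4);
        // A [n_embd, n_tokens] activation and a per-channel [n_embd] bias.
        let current = ctx0.new_tensor_2d(Type::F32, n_embd, n_tokens);
        let bias = ctx0.new_tensor_1d(Type::F32, n_embd);

        // Old form: expand the bias to the activation's shape, then add.
        let _old = ctx0.op_add(&ctx0.op_repeat(&bias, &current), &current);

        // New form: op_add broadcasts the bias across every row implicitly.
        let _new = ctx0.op_add(&current, &bias);
    }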
@@ -269,10 +266,7 @@ impl KnownModel for GptNeoX {
 
             // self-attention projection
             current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);
-            current = ctx0.op_add(
-                &ctx0.op_repeat(&self.layers[il].c_attn_proj_b, &current),
-                &current,
-            );
+            current = ctx0.op_add(&current, &self.layers[il].c_attn_proj_b);
 
             // use the second scratch for the feed forward
             ctx0.use_scratch(builder.get_scratch(1));
@@ -305,10 +299,7 @@ impl KnownModel for GptNeoX {
         // normalize the output
         input_layer = ctx0.op_norm(&input_layer);
         // inpL = ln_f_g*inpL + ln_f_b
-        input_layer = ctx0.op_add(
-            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
-            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
-        );
+        input_layer = ctx0.op_add(&ctx0.op_mul(&input_layer, &self.ln_f_g), &self.ln_f_b);
 
         let embeddings_tensor: ggml::Tensor = input_layer.share();
 
@@ -470,16 +461,13 @@ fn feed_forward_network(context: &ggml::Context, layer: &Layer, input: &Tensor)
     let mut current = context.op_norm(input);
 
     //gain and bias
-    current = context.op_add(
-        &context.op_mul(&context.op_repeat(&layer.ln_2_g, &current), &current),
-        &context.op_repeat(&layer.ln_2_b, &current),
-    );
+    current = context.op_add(&context.op_mul(&current, &layer.ln_2_g), &layer.ln_2_b);
 
     // apply weights
     current = context.op_mul_mat(&layer.c_mlp_fc_w, &current);
 
     // apply bias
-    current = context.op_add(&context.op_repeat(&layer.c_mlp_fc_b, &current), &current);
+    current = context.op_add(&current, &layer.c_mlp_fc_b);
 
     // GELU activation
     current = context.op_gelu(&current);
@@ -488,7 +476,7 @@ fn feed_forward_network(context: &ggml::Context, layer: &Layer, input: &Tensor)
     // cur = proj_w*cur + proj_b
     current = context.op_mul_mat(&layer.c_mlp_proj_w, &current);
 
-    current = context.op_add(&context.op_repeat(&layer.c_mlp_proj_b, &current), &current);
+    current = context.op_add(&current, &layer.c_mlp_proj_b);
 
     current
 }
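The layernorm affine steps benefit the same way: with broadcasting, out = norm(x) * g + b needs no op_repeat for either the gain or the bias. A sketch of the resulting pattern, under the same assumed API (the helper function name is illustrative, not from the diff):

    use ggml::{Context, Tensor};

    // out = norm(x) * g + b, with the per-channel gain g and bias b
    // broadcast over every row of x by op_mul and op_add.
    fn layer_norm_sketch(ctx: &Context, x: &Tensor, g: &Tensor, b: &Tensor) -> Tensor {
        let normed = ctx.op_norm(x); // per-row normalization to zero mean, unit variance
        ctx.op_add(&ctx.op_mul(&normed, g), b)
    }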
