Skip to content

Commit

Permalink
remove repeat from mpt
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 authored and AmineDiro committed Aug 15, 2023
1 parent 519705b commit 43dade0
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions crates/models/mpt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,7 @@ impl KnownModel for Mpt {
ctx0.use_scratch(builder.get_scratch(0));

let mut current = ctx0.op_norm(&input_layer);
- current = ctx0.op_mul(
- &ctx0.op_repeat(&self.layers[il].norm_1_weight, &current),
- &current,
- );
+ current = ctx0.op_mul(&current, &self.layers[il].norm_1_weight);

current = ctx0.op_mul_mat(&self.layers[il].c_attn_wqkv_weight, &current);

Expand Down Expand Up @@ -222,10 +219,7 @@ impl KnownModel for Mpt {
ctx0.use_scratch(builder.get_scratch(1));

current = ctx0.op_norm(&input_layer);
- current = ctx0.op_mul(
- &ctx0.op_repeat(&self.layers[il].norm_2_weight, &current),
- &current,
- );
+ current = ctx0.op_mul(&current, &self.layers[il].norm_2_weight);

current = ctx0.op_mul_mat(&self.layers[il].ffn_up_proj, &current);

Expand All @@ -242,7 +236,7 @@ impl KnownModel for Mpt {

// norm
input_layer = ctx0.op_norm(&input_layer);
- input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
+ input_layer = ctx0.op_mul(&input_layer, &self.norm);

let embeddings_tensor: ggml::Tensor = input_layer.share();

Expand Down

0 comments on commit 43dade0

Please sign in to comment.