Add Phi-4-mini-instruct support #12099

Closed
wants to merge 3 commits into from
Conversation

ns3284
@ns3284 ns3284 commented Feb 27, 2025

Added new vocab type: gpt-4o
Added Phi3 support for partial_rotary_factor
Added Phi3 support for tie_word_embeddings

@github-actions github-actions bot added the python python script changes label Feb 27, 2025
@ngxson
Collaborator

ngxson commented Feb 27, 2025

A few things to note (I'll push a commit tomorrow when I'm back at work):

  • In the update.py script, it's better to point to the repo Xenova/gpt-4o
  • We need to add a dedicated KV metadata key for partial_rotary_factor to make it more explicit, just to be a bit more future-proof

@ns3284
Author

ns3284 commented Feb 27, 2025

Still unsure about the KV metadata part, but I pushed updates for Xenova/gpt-4o.

Something like this?

        rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
        rotary_factor = rotary_factor if rotary_factor is not None else 1.0
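For context on the snippet above, find_hparam returns the first matching hyperparameter key. A simplified stand-in re-implementation (the real helper lives on the model class in convert_hf_to_gguf.py) showing the fallback semantics:

```python
def find_hparam(hparams: dict, keys: list[str], optional: bool = False):
    # Return the value of the first key present in hparams;
    # with optional=True, missing keys yield None instead of an error.
    for key in keys:
        if key in hparams:
            return hparams[key]
    if optional:
        return None
    raise KeyError(f"none of {keys} found")

# Fall back to 1.0 (full rotation) when neither key is present
hparams = {}  # hypothetical model config with no partial rotary setting
rotary_factor = find_hparam(hparams, ["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
assert rotary_factor == 1.0
```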

@Mungert69

Thanks! I have tested a few different GGUF models created with your branch and they seem to be working OK. I'm posting them to Hugging Face: https://huggingface.co/Mungert/Phi-4-mini-instruct.gguf. Many thanks for getting Phi-4-mini-instruct working.

@ngxson
Collaborator

ngxson commented Feb 28, 2025

@Mungert69 please don't post GGUFs on HF before the PR is merged, as there may be more changes and your GGUFs may break once this is finished.

Collaborator

@ngxson ngxson left a comment


This PR is also missing tokenizer .inp/.out files.

Since I cannot push to this PR (because you created it from your master branch), I will make another PR to replace it. I will keep your commits there so you remain a co-author.

@@ -2223,8 +2228,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);

layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {

I think this check works but it's not actually correct, since scaling_type is used to calculate attn_factor.


Also, the else branch is redundant: we know from the conversion script that if rot_pct is not set, then n_rot = n_embd / n_head.

Other archs like LLM_ARCH_LLAMA do the same thing.

@@ -109,6 +109,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "megrez", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Infinigence/Megrez-3B-Instruct"},
{"name": "deepseek-v3", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-V3"},
{"name": "deepseek-r1-qwen", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"},
{"name": "gpt-4o", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Xenova/gpt-4o", },
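For context, convert_hf_to_gguf_update.py walks this list, downloads each repo's tokenizer, and records a hash of the token ids produced for a fixed test string to identify the pre-tokenizer. A minimal sketch of that hashing step (simplified; names are illustrative, not the script's exact code):

```python
from hashlib import sha256

def chkhsh_for(token_ids: list[int]) -> str:
    # Identify a pre-tokenizer by hashing the token ids it produces
    # for a fixed test string (mirrors the update script's approach).
    return sha256(str(token_ids).encode()).hexdigest()

# Tokenizers that split the test string identically get the same hash
assert chkhsh_for([100, 200, 300]) == chkhsh_for([100, 200, 300])
assert len(chkhsh_for([1])) == 64  # hex digest of SHA-256
```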

This does not actually work since Xenova/gpt-4o is missing config.json; I had to make an exception for it.


def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
if self.hparams.get("partial_rotary_factor") is not None:

This can be written more concisely: self.hparams.get("partial_rotary_factor", 1.0)
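The suggested shorthand relies on dict.get returning the default when the key is absent; a quick illustration (equivalent as long as the key is never explicitly set to None):

```python
hparams = {"hidden_size": 3072}  # hypothetical config without the key

# Verbose form: explicit None check
factor = hparams.get("partial_rotary_factor")
if factor is None:
    factor = 1.0

# Shorter equivalent suggested above
assert hparams.get("partial_rotary_factor", 1.0) == factor == 1.0
```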

@ngxson
Collaborator

ngxson commented Feb 28, 2025

  • We need to add a dedicated KV metadata key for partial_rotary_factor to make it more explicit, just to be a bit more future-proof

Small correction: this is not needed. Other archs like Phi2Model just scale n_rot accordingly.
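As a sketch of the Phi2-style alternative: instead of storing the factor under a dedicated KV key, the converter folds it into the rope dimension count it already writes (hyperparameter names follow common HF configs and are assumptions here):

```python
def scaled_n_rot(hparams: dict) -> int:
    # Phi2-style: fold partial_rotary_factor into n_rot at conversion
    # time, so no dedicated GGUF metadata key is needed.
    head_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
    return int(hparams.get("partial_rotary_factor", 1.0) * head_dim)

# Hypothetical values: factor applied vs. unset (full rotation)
assert scaled_n_rot({"hidden_size": 3072, "num_attention_heads": 24,
                     "partial_rotary_factor": 0.75}) == 96
assert scaled_n_rot({"hidden_size": 3072, "num_attention_heads": 24}) == 128
```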

@ngxson
Collaborator

ngxson commented Feb 28, 2025

Closed and superseded by #12108.

@ngxson ngxson closed this Feb 28, 2025
ngxson added a commit that referenced this pull request Feb 28, 2025
* Added Phi-4-mini-instruct support

* Update regex per ngxson

* Change the vocab base to Xenova/gpt-4o

* fix conversion update script

* no need to check longrope

* minor style fix

* fix python style

---------

Co-authored-by: Nicholas Sparks <nisparks@microsoft.com>
mglambda pushed a commit to mglambda/llama.cpp that referenced this pull request Mar 8, 2025
arthw pushed a commit to arthw/llama.cpp that referenced this pull request Mar 19, 2025
mostlyuseful pushed a commit to mostlyuseful/llama.cpp that referenced this pull request May 12, 2025
Labels: python (python script changes)
4 participants