17 commits
4b9d824  fix: Remove instruct as it should always be true and correct num_experts (KyleMylonakisProtopia, Dec 16, 2025)
1cbae63  docs: improve docstring (KyleMylonakisProtopia, Dec 16, 2025)
9a8d05c  fix: remove unused argument (KyleMylonakisProtopia, Dec 16, 2025)
2544cd6  fix: Only add extra special tokens if they exist (KyleMylonakisProtopia, Dec 16, 2025)
28bb4da  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 16, 2025)
82ef0a7  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 16, 2025)
87e4b34  fix: prefer .jinja extension for chat template (KyleMylonakisProtopia, Dec 16, 2025)
c808092  Merge remote-tracking branch 'refs/remotes/origin/fix_conversion_scri… (KyleMylonakisProtopia, Dec 16, 2025)
527681d  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 16, 2025)
f536461  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 17, 2025)
67dc5f2  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 17, 2025)
3446ab9  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 17, 2025)
ffbb71a  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 18, 2025)
98f4ff0  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 18, 2025)
1e01438  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 18, 2025)
f85b5d3  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 18, 2025)
e48a612  Merge branch 'main' into fix_conversion_script (KyleMylonakisProtopia, Dec 18, 2025)
7 changes: 4 additions & 3 deletions src/transformers/convert_slow_tokenizer.py
@@ -1897,9 +1897,10 @@ def converted(self) -> Tokenizer:
)
tokenizer.decoder = decoders.ByteLevel()

tokenizer.add_special_tokens(
[AddedToken(token, normalized=False, special=True) for token in self.extra_special_tokens]
)
if self.extra_special_tokens is not None:
tokenizer.add_special_tokens(
[AddedToken(token, normalized=False, special=True) for token in self.extra_special_tokens]
)
Comment on lines -1900 to +1903

Member: I'm not sure why we need changes in the core code! cc @itazap @ArthurZucker before I can approve this

Contributor Author: It hard-crashes otherwise because self.extra_special_tokens can be None, and is in the conversion script.

Contributor Author: Bumping for @itazap and @ArthurZucker feedback.

Collaborator: LGTM 👍

Contributor Author: Great! How can we get this over the line? Would love to see this change in the Transformers 5.0.0 release.


tokenizer.post_processor = processors.ByteLevel(trim_offsets=False)

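To make the review thread above concrete, here is a minimal sketch of the failure mode the new guard avoids; the variable setup is illustrative rather than taken from the converter itself:

    from tokenizers import AddedToken

    extra_special_tokens = None  # the GPT-OSS conversion script never passes this, so the attribute stays None

    # Old behaviour: iterating None raises TypeError: 'NoneType' object is not iterable
    # tokens = [AddedToken(t, normalized=False, special=True) for t in extra_special_tokens]

    # Guarded behaviour from this diff: skip the call entirely when no extra tokens were provided
    if extra_special_tokens is not None:
        tokens = [AddedToken(t, normalized=False, special=True) for t in extra_special_tokens]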
63 changes: 24 additions & 39 deletions src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
@@ -146,16 +146,19 @@ def convert_moe_packed_tensors(
def write_model(
model_path,
input_base_path,
instruct=False,
mxfp4=False,
):
os.makedirs(model_path, exist_ok=True)
eos_token_id = 199999 if not instruct else 200002
eos_token_id = 200002
pad_token_id = 199999

original_config = json.loads((Path(input_base_path) / "config.json").read_text())

num_local_experts = original_config.pop("num_experts")
# GPT OSS models are distributed with either num_experts or num_local_experts, depending on whether
# the original subfolder or the root folder is used.
num_local_experts = original_config.get("num_experts") or original_config.get("num_local_experts")
if num_local_experts is None:
raise ValueError("num_local_experts or num_experts must be specified in the config.")

# Handle both old and new config formats for rope_parameters
if "rope_parameters" in original_config:
@@ -280,17 +283,16 @@ def write_model(
print("Model reloaded successfully.")

# generation config
if instruct:
print("Saving generation config...")
generation_config = GenerationConfig(
bos_token_id=199998, # <|startoftext|>
do_sample=True,
eos_token_id=[200002, 199999], # <|return|>, <|endoftext|>
pad_token_id=199999, # <|endoftext|>
temperature=1.0,
top_p=1.0,
)
generation_config.save_pretrained(model_path)
print("Saving generation config...")
generation_config = GenerationConfig(
bos_token_id=199998, # <|startoftext|>
do_sample=True,
eos_token_id=[200002, 199999], # <|return|>, <|endoftext|>
pad_token_id=199999, # <|endoftext|>
temperature=1.0,
top_p=1.0,
)
generation_config.save_pretrained(model_path)


def save_sharded_model(state_dict, model_path):
@@ -431,7 +433,7 @@ def __init__(
)


def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False):
def write_tokenizer(tokenizer_path: str, save_dir: str):
# Updated Harmony chat template
chat_template = """{#-
In addition to the normal inputs of `messages` and `tools`, this template also accepts the
@@ -768,41 +770,26 @@ def write_tokenizer(tokenizer_path: str, save_dir: str, instruct: bool = False):
converter = GptOssConverter(
vocab_file=tokenizer_path,
model_max_length=None,
chat_template=chat_template if instruct else None,
chat_template=chat_template,
)
tokenizer = converter.tokenizer
tokenizer.save_pretrained(save_dir)

if instruct:
print("Saving chat template...")
chat_template_path = os.path.join(save_dir, "chat_template.json")
with open(chat_template_path, "w") as f:
json.dump({"chat_template": chat_template}, f, indent=2)
print("Saving chat template...")
chat_template_path = os.path.join(save_dir, "chat_template.jinja")
with open(chat_template_path, "w") as f:
f.write(chat_template)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
default="/fsx/mohamed/oai-hf/tests/120b",
help="Location of LLaMA weights, which contains tokenizer.model and model folders",
help="Location of `./original` subfolder of the GPT OSS model repo.",
)
parser.add_argument(
"--output_dir",
default="/fsx/mohamed/oai-hf/tests/120b_converted_packed",
help="Location to write HF model and tokenizer",
)
parser.add_argument(
"--special_tokens",
default=None,
type=list[str],
help="The list of special tokens that should be added to the ",
)

parser.add_argument(
"--instruct",
action="store_true",
help="Whether the model is an instruct model",
help="Location to write the converted HF model and tokenizer",
)

# Only specify this if you want to use the model with mxfp4 quantization
@@ -820,14 +807,12 @@ def main():
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
instruct=args.instruct,
mxfp4=args.mxfp4,
)

write_tokenizer(
tokenizer_path="o200k_base",
save_dir=args.output_dir,
instruct=args.instruct,
)


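A short usage sketch of the converted output; the script invocation in the comment and the output path are hypothetical, and loading the chat template this way assumes a transformers version that reads chat_template.jinja files:

    # After running the conversion, e.g.:
    #   python convert_gpt_oss_weights_to_hf.py --input_dir gpt-oss-120b/original --output_dir gpt-oss-120b-converted
    from transformers import AutoTokenizer, GenerationConfig

    save_dir = "gpt-oss-120b-converted"  # hypothetical --output_dir

    tokenizer = AutoTokenizer.from_pretrained(save_dir)
    print(tokenizer.chat_template is not None)  # True: picked up from chat_template.jinja

    generation_config = GenerationConfig.from_pretrained(save_dir)
    print(generation_config.eos_token_id)  # [200002, 199999], matching the diff above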