Skip to content

Commit 386aaa7

Browse files
committed
Add support for generating config.json in HF trainer.
1 parent 09ab5ab commit 386aaa7

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/oumi/core/trainers/hf_trainer.py

+2
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def save_model(self, config: TrainingConfig, final: bool = True) -> None:
5353
if is_world_process_zero():
5454
output_dir = config.training.output_dir
5555
self._hf_trainer.save_model(output_dir)
56+
self._hf_trainer.model.config.to_json_file(f"{output_dir}/config.json")
5657
logger.info(f"Model has been saved at {output_dir}.")
5758

5859
def _save_fsdp_model(self, config: TrainingConfig, final: bool = True) -> None:
@@ -74,4 +75,5 @@ def _save_fsdp_model(self, config: TrainingConfig, final: bool = True) -> None:
7475

7576
output_dir = config.training.output_dir
7677
self._hf_trainer.save_model(output_dir)
78+
self._hf_trainer.model.config.to_json_file(f"{output_dir}/config.json")
7779
logger.info(f"Model has been saved at {output_dir}.")

0 commit comments

Comments
 (0)