diff --git a/examples/scripts/sft.py b/examples/scripts/sft.py
index 1a60411e80..38166264d6 100644
--- a/examples/scripts/sft.py
+++ b/examples/scripts/sft.py
@@ -94,23 +94,19 @@
     console = Console()
 
     ################
-    # Model & Tokenizer
+    # Model init kwargs & Tokenizer
     ################
-    torch_dtype = (
-        model_config.torch_dtype
-        if model_config.torch_dtype in ["auto", None]
-        else getattr(torch, model_config.torch_dtype)
-    )
     quantization_config = get_quantization_config(model_config)
     model_kwargs = dict(
         revision=model_config.model_revision,
         trust_remote_code=model_config.trust_remote_code,
         attn_implementation=model_config.attn_implementation,
-        torch_dtype=torch_dtype,
+        torch_dtype=model_config.torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
+    training_args.model_init_kwargs = model_kwargs
     tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
     tokenizer.pad_token = tokenizer.eos_token
 
@@ -138,7 +134,6 @@
     with init_context:
         trainer = SFTTrainer(
             model=model_config.model_name_or_path,
-            model_init_kwargs=model_kwargs,
             args=training_args,
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
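
For reviewers, a minimal sketch of the call pattern this diff adopts: model-loading kwargs now travel on the config object (`SFTConfig` / `training_args.model_init_kwargs`) instead of being a separate `SFTTrainer` argument. The model path, dataset, and kwarg values below are placeholders for illustration, not part of this PR.

```python
# Sketch of the post-change usage, assuming SFTConfig exposes a
# model_init_kwargs field as this diff relies on. Names are placeholders.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="sft-output",
    dataset_text_field="text",
    # torch_dtype can stay a string ("auto", "bfloat16", ...): the trainer
    # resolves it when it instantiates the model from the path below.
    model_init_kwargs={"torch_dtype": "bfloat16", "use_cache": False},
)

trainer = SFTTrainer(
    model="facebook/opt-350m",  # plain string: the trainer loads the model itself
    args=training_args,         # carries model_init_kwargs to that load call
    train_dataset=load_dataset("imdb", split="train"),
)
trainer.train()
```

This also explains why the script-side `getattr(torch, model_config.torch_dtype)` conversion is dropped: once the trainer owns model instantiation, it can resolve the dtype string itself, so the example passes `model_config.torch_dtype` through unchanged.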