Skip to content

Commit

Permalink
add model_init_kwargs to training_args (#1787)
Browse files: browse the repository at this point in the history
  • Loading branch information
kashif authored Jul 3, 2024
1 parent cd85b14 commit b6af2ed
Showing 1 changed file with 3 additions and 8 deletions.
11 changes: 3 additions & 8 deletions examples/scripts/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,19 @@
 console = Console()

 ################
-# Model & Tokenizer
+# Model init kwargs & Tokenizer
 ################
-torch_dtype = (
-model_config.torch_dtype
-if model_config.torch_dtype in ["auto", None]
-else getattr(torch, model_config.torch_dtype)
-)
 quantization_config = get_quantization_config(model_config)
 model_kwargs = dict(
 revision=model_config.model_revision,
 trust_remote_code=model_config.trust_remote_code,
 attn_implementation=model_config.attn_implementation,
-torch_dtype=torch_dtype,
+torch_dtype=model_config.torch_dtype,
 use_cache=False if training_args.gradient_checkpointing else True,
 device_map=get_kbit_device_map() if quantization_config is not None else None,
 quantization_config=quantization_config,
 )
+training_args.model_init_kwargs = model_kwargs
 tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
 tokenizer.pad_token = tokenizer.eos_token

Expand Down Expand Up @@ -138,7 +134,6 @@
 with init_context:
 trainer = SFTTrainer(
 model=model_config.model_name_or_path,
-model_init_kwargs=model_kwargs,
 args=training_args,
 train_dataset=train_dataset,
 eval_dataset=eval_dataset,
Expand Down

0 comments on commit b6af2ed

Please sign in to comment.