System Info

pip install git+https://github.com/huggingface/transformers.git
pip install tokenizers==0.20.0
pip install accelerate==0.34.2
pip install git+https://github.com/huggingface/trl.git
pip install datasets==3.0.1
pip install huggingface_hub==0.25.1
pip install peft==0.13.0
pip install databricks-cli==0.18.0
pip install bitsandbytes==0.44.1
pip install flash-attn==2.6.3 --no-build-isolation
Tasks

An officially supported task in the examples folder

Reproduction
import os


def qlora_gkd_train():
    import json

    import datasets
    import torch
    import transformers
    from trl import (
        GKDConfig,
        GKDTrainer,
        LogCompletionsCallback,
    )
    from peft import LoraConfig, TaskType, prepare_model_for_kbit_training

    with open('/local_disk0/training_config.json') as f:
        training_config = json.load(f)

    # testing memory usage for batch size
    training_config['max_steps'] = 10
    # training_config['per_device_train_batch_size'] = 32
    print(json.dumps(training_config, indent=4))

    print("loading tokenizer")
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        training_config['teacher_model_name_or_path'],
        padding_side="left",
        truncation_side="left",
    )
    tokenizer.pad_token = tokenizer.eos_token

    print("loading dataset")
    train_dataset = datasets.load_from_disk('/local_disk0/train')

    # Model
    torch_dtype = torch.bfloat16
    quant_storage_dtype = torch.bfloat16

    quantization_config = transformers.BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=quant_storage_dtype,
    )

    print("loading teacher model")
    teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
        training_config['student_model_name_or_path'],
        quantization_config=quantization_config,
        attn_implementation="flash_attention_2",  # use sdpa, alternatively use "flash_attention_2"
        torch_dtype=quant_storage_dtype,
        device_map="auto",
    )
    teacher_model = prepare_model_for_kbit_training(teacher_model)

    print("create student config")
    student_model_kwargs = dict(
        trust_remote_code=True,
        attn_implementation="flash_attention_2",  # use sdpa, alternatively use "flash_attention_2"
        torch_dtype=quant_storage_dtype,
        use_cache=training_config['gradient_checkpointing'],
        device_map="auto",
        # quantization_config=quantization_config,
    )

    lora_config = LoraConfig(
        r=training_config['lora_r'],
        # target_modules="all-linear",
        target_modules=["q_proj", "k_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=training_config['lora_alpha'],
        lora_dropout=0.05,
    )

    training_arguments = GKDConfig(
        model_init_kwargs=student_model_kwargs,
        save_strategy='epoch',
        report_to='mlflow',
        # save_steps=training_config['save_steps'],
        ddp_find_unused_parameters=False,
        gradient_checkpointing=training_config['gradient_checkpointing'],
        per_device_train_batch_size=training_config['per_device_train_batch_size'],
        gradient_accumulation_steps=training_config['gradient_accumulation_steps'],
        num_train_epochs=training_config['num_train_epochs'],
        learning_rate=training_config['learning_rate'],
        warmup_ratio=training_config['warmup_ratio'],
        lr_scheduler_type="cosine",
        bf16=True,
        max_steps=training_config['max_steps'],
        logging_steps=training_config['logging_steps'],
        output_dir=training_config['output_dir'],
        gradient_checkpointing_kwargs={'use_reentrant': False},
        max_seq_length=training_config['max_seq_len'],
        use_liger=training_config['use_liger'],
        # optim="paged_adamw_8bit",
        dataset_text_field='prompt',
        packing=False,
        # gkd params
        temperature=0.9,
        max_new_tokens=1024,
    )

    print("start training")
    trainer = GKDTrainer(
        model=training_config['student_model_name_or_path'],
        teacher_model=teacher_model,
        args=training_arguments,
        train_dataset=train_dataset,
        processing_class=tokenizer,
        peft_config=lora_config,
    )

    if training_config['resume']:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()


os.environ['ACCELERATE_BYPASS_DEVICE_MAP'] = "true"
qlora_gkd_train()

Contents of /local_disk0/training_config.json:
{ "teacher_model_name_or_path": "/local_disk0/meta-llama/Llama-3.1-70B-Instruct", "student_model_name_or_path": "/local_disk0/meta-llama/Llama-3.2-3B-Instruct", "learning_rate": 1e-05, "per_device_train_batch_size": 4, "gradient_accumulation_steps": 1, "logging_steps": 1, "num_train_epochs": 15, "gradient_checkpointing": true, "use_peft": true, "lora_r": 64, "lora_alpha": 16, "max_seq_len": 1382, "use_liger": false, "warmup_ratio": 0.1, "resume": false, "max_steps": -1 }
I'm seeing 0 loss and 0 grad norm. Is this expected?

Expected behavior

The loss should not be 0.
Follow-up from #2215.
@nivibilla what are the keys in your dataset? Currently the data collator also checks whether there is a prompt key in order to extract just the prompts: https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L265
I just have one column, named prompt.
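A quick sanity check on the saved dataset (just a sketch, assuming the same /local_disk0/train path as in the repro) to confirm the column names and that the prompts are non-empty strings:

import datasets

# Load the exact dataset the trainer sees and inspect it.
train_dataset = datasets.load_from_disk("/local_disk0/train")
print(train_dataset.column_names)              # should print ['prompt']
print(repr(train_dataset[0]["prompt"])[:500])  # first prompt, truncated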
Any update on this @nivibilla? I suspect it's a data issue.