diff --git a/swift/llm/argument/merge_args.py b/swift/llm/argument/merge_args.py
index 8732141d85..fbf21709bc 100644
--- a/swift/llm/argument/merge_args.py
+++ b/swift/llm/argument/merge_args.py
@@ -2,7 +2,7 @@
 from dataclasses import dataclass
 from typing import Optional

-from swift.utils import get_logger, is_merge_kit_available
+from swift.utils import get_logger

 logger = get_logger()

@@ -16,44 +16,7 @@ class MergeArguments:
         merge_lora (bool): Flag to indicate if LoRA merging is enabled. Default is False.
         safe_serialization(bool): Use safetensors or not, default `True`.
         max_shard_size(str): The max size of single shard file.
-        use_merge_kit (bool): Flag to indicate merge with `mergekit`. Default is False.
-        instruct_model (Optional[str]): Path or ID of the instruct model. Use when `use_merge_kit` is True.
-        instruct_model_revision (Optional[str]): Revision of the instruct model. Use when `use_merge_kit` is True.
     """
     merge_lora: bool = False
     safe_serialization: bool = True
     max_shard_size: str = '5GB'
-
-    use_merge_kit: bool = False
-    instruct_model: Optional[str] = None
-    instruct_model_revision: Optional[str] = None
-
-    def __post_init__(self):
-        if self.use_merge_kit:
-            assert is_merge_kit_available(), ('please install mergekit by pip install '
-                                              'git+https://github.com/arcee-ai/mergekit.git')
-            logger.info('Important: You are using mergekit, please remember '
-                        'the LoRA should be trained against the base model,'
-                        'and pass its instruct model by --instruct_model xxx when merging')
-            assert self.instruct_model, 'Please pass in the instruct model'
-
-            self.merge_yaml = """
-models:
-  - model: {merged_model}
-    parameters:
-      weight: 1
-      density: 1
-  - model: {instruct_model}
-    parameters:
-      weight: 1
-      density: 1
-merge_method: ties
-base_model: {base_model}
-parameters:
-  weight: 1
-  density: 1
-  normalize: true
-  int8_mask: true
-tokenizer_source: {merged_model}
-dtype: bfloat16
-"""
diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py
index a7c3be22ac..7a0127e3f9 100644
--- a/swift/llm/argument/train_args.py
+++ b/swift/llm/argument/train_args.py
@@ -136,6 +136,13 @@ def __post_init__(self) -> None:
         TunerArguments.__post_init__(self)
         TorchAccArguments.__post_init__(self)

+        if self.lorap_lr_ratio:
+            self.optimizer = 'lorap'
+        elif self.use_galore:
+            self.optimizer = 'galore'
+        elif self.optimizer is None:
+            self.optimizer = 'default'
+
         if len(self.dataset) == 0:
             raise ValueError(f'self.dataset: {self.dataset}, Please input the training dataset.')

diff --git a/swift/llm/export/merge_lora.py b/swift/llm/export/merge_lora.py
index a00ef26f61..29f470da05 100644
--- a/swift/llm/export/merge_lora.py
+++ b/swift/llm/export/merge_lora.py
@@ -25,14 +25,6 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False)
     origin_device_map = args.device_map
     args.device_map = device_map or args.device_map
     logger.info(f'merge_device_map: {device_map}')
-    if args.use_merge_kit:
-        base_model = args.model
-        if not os.path.exists(base_model):
-            base_model = args.hub.download_model(base_model, revision=args.model_revision)
-        if not os.path.exists(args.instruct_model):
-            args.instruct_model = args.hub.download_model(
-                args.instruct_model, revision=args.instruct_model_revision)
-        args.model = args.instruct_model
     model, template = prepare_model_template(args)
     logger.info('Merge LoRA...')
     Swift.merge_and_unload(model)
@@ -52,19 +44,3 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False)

     args.model = output_dir
     args.adapters = []
-    if args.use_merge_kit:
-        tempdir = tempfile.gettempdir()
-        mergekit_path = os.path.join(output_dir, 'mergekit')
-        merge_yaml = args.merge_yaml.replace('{merged_model}', output_dir).replace('{instruct_model}',
-                                                                                   args.instruct_model).replace(
-                                                                                       '{base_model}', base_model)
-        try:
-            yamlfile = os.path.join(tempdir, 'mergekit.yaml')
-            with open(yamlfile, 'w', encoding='utf-8') as f:
-                f.write(merge_yaml)
-            logger.info(f'Merging with config: {merge_yaml}')
-            os.system(f'mergekit-yaml {yamlfile} {mergekit_path}')
-            logger.info(f'Merge complete with path: {mergekit_path}')
-        finally:
-            if tempdir:
-                shutil.rmtree(tempdir, ignore_errors=True)
diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py
index 19b90a7610..1159a2baef 100644
--- a/swift/llm/train/sft.py
+++ b/swift/llm/train/sft.py
@@ -122,8 +122,6 @@ def run(self):
         self.train_msg['model_parameter_info'] = model_parameter_info
         logger.info(f'model_parameter_info: {model_parameter_info}')

-        optimizers = self._get_optimizers(train_dataset)
-
         trainer_cls = TrainerFactory.get_trainer_cls(args)
         trainer = trainer_cls(
             model=self.model,
@@ -132,7 +130,6 @@
             train_dataset=train_dataset,
             eval_dataset=val_dataset,
             callbacks=self.callbacks,
-            optimizers=optimizers,
             template=self.template,
             **self._get_trainer_kwargs(),
         )
@@ -192,18 +189,6 @@ def train(self, trainer):

         return self._save_trainer_state(trainer)

-    def _get_optimizers(self, train_dataset):
-        args = self.args
-        if args.lorap_lr_ratio:
-            optimizer_callback = optimizers_map['lorap']
-        elif args.use_galore:
-            optimizer_callback = optimizers_map['galore']
-        elif args.optimizer is not None:
-            optimizer_callback = optimizers_map[args.optimizer]
-        else:
-            optimizer_callback = optimizers_map['default']
-        return optimizer_callback(args, self.model, train_dataset)
-
     def _prepare_callbacks(self):
         from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback
         args = self.args
diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py
index 94a62b61a4..4584bdccae 100644
--- a/swift/llm/train/tuner.py
+++ b/swift/llm/train/tuner.py
@@ -403,6 +403,7 @@ def prepare_model(
             gamma_proj=args.galore_gamma_proj,
             queue_size=args.galore_queue_size,
         )
+        args.training_args.galore_config = args.galore_config

     if args.sequence_parallel_size > 1:
         from swift.trainers.xtuner import dispatch_module_xtuner
diff --git a/swift/plugin/optimizer.py b/swift/plugin/optimizer.py
index 2eff2354dd..589d78dae1 100644
--- a/swift/plugin/optimizer.py
+++ b/swift/plugin/optimizer.py
@@ -11,7 +11,6 @@ def calculate_max_steps(args: 'TrainArguments', dataset) -> int:
     if args.max_steps and args.max_steps > 0:
         max_steps = args.max_steps
     else:
-        assert not args.streaming
        len_dataset = len(dataset)
         _, _, world_size, _ = get_dist_setting()
         total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
@@ -23,17 +22,14 @@

 def create_galore_optimizers(args, model, dataset):
     training_steps = calculate_max_steps(args, dataset)
-    return create_optimizer_and_scheduler(
-        model,
-        args.training_args,
-        args.galore_config,
-        training_steps,
-        lr=args.learning_rate,
-        weight_decay=args.weight_decay)
+    optimizer, lr_scheduler = create_optimizer_and_scheduler(
+        model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
+    # trainer cannot serialize galore_config
+    args.galore_config = None
+    return optimizer, lr_scheduler


 def create_lorap_optimizers(args, model, dataset):
-    args = args.training_args
     optimizer_grouped_parameters = None
     if hasattr(model, 'create_optimizer_param_groups'):
         # Lora+ parameter groups
diff --git a/swift/trainers/arguments.py b/swift/trainers/arguments.py
index 34ac7c5fe1..a0b78b947f 100644
--- a/swift/trainers/arguments.py
+++ b/swift/trainers/arguments.py
@@ -10,6 +10,7 @@
 from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments

 from swift.utils import use_torchacc
+from .optimizers.galore import GaLoreConfig


 @dataclass
@@ -28,8 +29,10 @@ class SwiftArgumentsMixin:
     fsdp_num: int = 1
     acc_steps: int = 1

-    # Value copied from TrainArguments, Used for external tuners.
+    # Value copied from TrainArguments
     train_type: Optional[str] = None
+    optimizer: Optional[str] = None
+    galore_config: Optional[GaLoreConfig] = None

     def _fix_gradient_checkpointing(self):
         # fix use_reentrant
diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index 495bc92354..d0281bcdfc 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -21,7 +21,7 @@
 from transformers.data.data_collator import DataCollator
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import unwrap_model
-from transformers.trainer import Trainer, TrainerCallback
+from transformers.trainer import TrainerCallback
 from transformers.trainer_utils import EvalPrediction
 from transformers.utils import is_torch_npu_available

@@ -32,7 +32,6 @@
 from swift.utils import get_logger, is_mp_ddp, use_torchacc
 from swift.utils.torchacc_utils import ta_trim_graph
 from .arguments import TrainingArguments
-from .optimizers.galore import create_optimizer_and_scheduler
 from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model

 try:
@@ -316,32 +315,17 @@ def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
         super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

     def create_optimizer_and_scheduler(self, num_training_steps: int):
-        if hasattr(self.args, 'galore_config'):
-            optimizer, lr_scheduler = create_optimizer_and_scheduler(
-                self.model,
-                self.args,
-                self.args.galore_config,
-                num_training_steps,
-                lr=self.args.learning_rate,
-                weight_decay=self.args.weight_decay)
-            self.optimizer = optimizer
-            self.lr_scheduler = lr_scheduler
+        if self.args.optimizer is not None:
+            from swift.plugin import optimizers_map
+            optimizer_callback = optimizers_map[self.args.optimizer]
+            self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
+            if self.optimizer is None:
+                self.create_optimizer()
+            if self.lr_scheduler is None:
+                self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
         else:
             super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)

-    def create_optimizer(self):
-
-        if self.optimizer is None and hasattr(self.model, 'create_optimizer_param_groups'):
-            # Lora+ parameter groups
-            optimizer_grouped_parameters = self.model.create_optimizer_param_groups(
-                lr=self.args.learning_rate, weight_decay=self.args.weight_decay)
-            if optimizer_grouped_parameters is not None:
-                optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
-                self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
-                return self.optimizer
-
-        return super().create_optimizer()
-
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.args.train_sampler_random:
             return super()._get_train_sampler()
diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py
index 4e5452b2e9..1a5e9548ca 100644
--- a/swift/ui/llm_infer/llm_infer.py
+++ b/swift/ui/llm_infer/llm_infer.py
@@ -29,8 +29,6 @@ class LLMInfer(BaseUI):

     is_multimodal = True

-    deployed = False
-
     sub_ui = [Model, Runtime]

     locale_dict = {
@@ -279,7 +277,6 @@ def deploy_model(cls, *args):
         os.system(run_command)
         gr.Info(cls.locale('load_alert', cls.lang)['value'])
         time.sleep(2)
-        cls.deployed = True
         running_task = Runtime.refresh_tasks(log_file)
         return gr.update(open=True), running_task

diff --git a/swift/utils/__init__.py b/swift/utils/__init__.py
index ec0fee03c0..d7ca019b60 100644
--- a/swift/utils/__init__.py
+++ b/swift/utils/__init__.py
@@ -2,8 +2,8 @@
 from .env import (get_dist_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_dist_ta,
                   is_local_master, is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph,
                   use_hf_hub, use_torchacc)
-from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_merge_kit_available,
-                           is_unsloth_available, is_vllm_available, is_xtuner_available)
+from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_unsloth_available,
+                           is_vllm_available, is_xtuner_available)
 from .io_utils import (JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, open_jsonl_writer,
                        read_from_jsonl, write_to_jsonl)
 from .logger import get_logger
diff --git a/swift/utils/import_utils.py b/swift/utils/import_utils.py
index 444480f716..15eef546e1 100644
--- a/swift/utils/import_utils.py
+++ b/swift/utils/import_utils.py
@@ -16,10 +16,6 @@ def is_vllm_available():
     return importlib.util.find_spec('vllm') is not None


-def is_merge_kit_available():
-    return importlib.util.find_spec('mergekit') is not None
-
-
 def is_lmdeploy_available():
     return importlib.util.find_spec('lmdeploy') is not None