Commit
Move optimizer to create_optimizer (modelscope#2851)
tastelikefeet authored Jan 3, 2025
1 parent ff9fa00 commit 07f10d2
Showing 11 changed files with 29 additions and 121 deletions.
39 changes: 1 addition & 38 deletions swift/llm/argument/merge_args.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Optional

from swift.utils import get_logger, is_merge_kit_available
from swift.utils import get_logger

logger = get_logger()

@@ -16,44 +16,7 @@ class MergeArguments:
merge_lora (bool): Flag to indicate if LoRA merging is enabled. Default is False.
safe_serialization(bool): Use safetensors or not, default `True`.
max_shard_size(str): The max size of single shard file.
use_merge_kit (bool): Flag to indicate merge with `mergekit`. Default is False.
instruct_model (Optional[str]): Path or ID of the instruct model. Use when `use_merge_kit` is True.
instruct_model_revision (Optional[str]): Revision of the instruct model. Use when `use_merge_kit` is True.
"""
merge_lora: bool = False
safe_serialization: bool = True
max_shard_size: str = '5GB'

use_merge_kit: bool = False
instruct_model: Optional[str] = None
instruct_model_revision: Optional[str] = None

def __post_init__(self):
if self.use_merge_kit:
assert is_merge_kit_available(), ('please install mergekit by pip install '
'git+https://github.com/arcee-ai/mergekit.git')
logger.info('Important: You are using mergekit, please remember '
'the LoRA should be trained against the base model,'
'and pass its instruct model by --instruct_model xxx when merging')
assert self.instruct_model, 'Please pass in the instruct model'

self.merge_yaml = """
models:
- model: {merged_model}
parameters:
weight: 1
density: 1
- model: {instruct_model}
parameters:
weight: 1
density: 1
merge_method: ties
base_model: {base_model}
parameters:
weight: 1
density: 1
normalize: true
int8_mask: true
tokenizer_source: {merged_model}
dtype: bfloat16
"""
7 changes: 7 additions & 0 deletions swift/llm/argument/train_args.py
@@ -136,6 +136,13 @@ def __post_init__(self) -> None:
TunerArguments.__post_init__(self)
TorchAccArguments.__post_init__(self)

if self.lorap_lr_ratio:
self.optimizer = 'lorap'
elif self.use_galore:
self.optimizer = 'galore'
elif self.optimizer is None:
self.optimizer = 'default'

if len(self.dataset) == 0:
raise ValueError(f'self.dataset: {self.dataset}, Please input the training dataset.')

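The block added above collapses the earlier flag-based branching into a single string key: `lorap_lr_ratio` and `use_galore` still take precedence, otherwise whatever was passed as `optimizer` is kept, with `'default'` as the fallback. Downstream (see swift/trainers/mixin.py later in this diff) that key is looked up in `swift.plugin.optimizers_map`, so a custom optimizer can in principle be plugged in by registering a callback under a new name. A minimal sketch, assuming dict-style registration and using the hypothetical name `my_adamw` (neither is part of this commit):

```python
# Sketch of registering a custom optimizer for the plugin lookup used downstream.
# `optimizers_map` and the callback signature come from this diff; dict-style
# registration and the name 'my_adamw' are assumptions for illustration.
from typing import Any, Optional, Tuple

import torch

from swift.plugin import optimizers_map


def create_my_adamw_optimizers(args, model, train_dataset) -> Tuple[Optional[Any], Optional[Any]]:
    # Same (args, model, train_dataset) -> (optimizer, lr_scheduler) contract as
    # create_galore_optimizers / create_lorap_optimizers in swift/plugin/optimizer.py.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params, lr=args.learning_rate, weight_decay=args.weight_decay)
    return optimizer, None  # None -> the trainer falls back to its default lr scheduler


optimizers_map['my_adamw'] = create_my_adamw_optimizers
```

With that registered, setting `optimizer` to `'my_adamw'` (presumably `--optimizer my_adamw` on the command line, since it is an ordinary dataclass field) would route optimizer construction through the callback.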
24 changes: 0 additions & 24 deletions swift/llm/export/merge_lora.py
@@ -25,14 +25,6 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False)
origin_device_map = args.device_map
args.device_map = device_map or args.device_map
logger.info(f'merge_device_map: {device_map}')
if args.use_merge_kit:
base_model = args.model
if not os.path.exists(base_model):
base_model = args.hub.download_model(base_model, revision=args.model_revision)
if not os.path.exists(args.instruct_model):
args.instruct_model = args.hub.download_model(
args.instruct_model, revision=args.instruct_model_revision)
args.model = args.instruct_model
model, template = prepare_model_template(args)
logger.info('Merge LoRA...')
Swift.merge_and_unload(model)
@@ -52,19 +44,3 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False)

args.model = output_dir
args.adapters = []
if args.use_merge_kit:
tempdir = tempfile.gettempdir()
mergekit_path = os.path.join(output_dir, 'mergekit')
merge_yaml = args.merge_yaml.replace('{merged_model}', output_dir).replace('{instruct_model}',
args.instruct_model).replace(
'{base_model}', base_model)
try:
yamlfile = os.path.join(tempdir, 'mergekit.yaml')
with open(yamlfile, 'w', encoding='utf-8') as f:
f.write(merge_yaml)
logger.info(f'Merging with config: {merge_yaml}')
os.system(f'mergekit-yaml {yamlfile} {mergekit_path}')
logger.info(f'Merge complete with path: {mergekit_path}')
finally:
if tempdir:
shutil.rmtree(tempdir, ignore_errors=True)
15 changes: 0 additions & 15 deletions swift/llm/train/sft.py
@@ -122,8 +122,6 @@ def run(self):
self.train_msg['model_parameter_info'] = model_parameter_info
logger.info(f'model_parameter_info: {model_parameter_info}')

optimizers = self._get_optimizers(train_dataset)

trainer_cls = TrainerFactory.get_trainer_cls(args)
trainer = trainer_cls(
model=self.model,
@@ -132,7 +130,6 @@
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks=self.callbacks,
optimizers=optimizers,
template=self.template,
**self._get_trainer_kwargs(),
)
@@ -192,18 +189,6 @@ def train(self, trainer):

return self._save_trainer_state(trainer)

def _get_optimizers(self, train_dataset):
args = self.args
if args.lorap_lr_ratio:
optimizer_callback = optimizers_map['lorap']
elif args.use_galore:
optimizer_callback = optimizers_map['galore']
elif args.optimizer is not None:
optimizer_callback = optimizers_map[args.optimizer]
else:
optimizer_callback = optimizers_map['default']
return optimizer_callback(args, self.model, train_dataset)

def _prepare_callbacks(self):
from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback
args = self.args
1 change: 1 addition & 0 deletions swift/llm/train/tuner.py
@@ -403,6 +403,7 @@ def prepare_model(
gamma_proj=args.galore_gamma_proj,
queue_size=args.galore_queue_size,
)
args.training_args.galore_config = args.galore_config

if args.sequence_parallel_size > 1:
from swift.trainers.xtuner import dispatch_module_xtuner
14 changes: 5 additions & 9 deletions swift/plugin/optimizer.py
@@ -11,7 +11,6 @@ def calculate_max_steps(args: 'TrainArguments', dataset) -> int:
if args.max_steps and args.max_steps > 0:
max_steps = args.max_steps
else:
assert not args.streaming
len_dataset = len(dataset)
_, _, world_size, _ = get_dist_setting()
total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
@@ -23,17 +22,14 @@

def create_galore_optimizers(args, model, dataset):
training_steps = calculate_max_steps(args, dataset)
return create_optimizer_and_scheduler(
model,
args.training_args,
args.galore_config,
training_steps,
lr=args.learning_rate,
weight_decay=args.weight_decay)
optimizer, lr_scheduler = create_optimizer_and_scheduler(
model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
# trainer cannot serialize galore_config
args.galore_config = None
return optimizer, lr_scheduler


def create_lorap_optimizers(args, model, dataset):
args = args.training_args
optimizer_grouped_parameters = None
if hasattr(model, 'create_optimizer_param_groups'):
# Lora+ parameter groups
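`create_galore_optimizers` builds the optimizer and lr scheduler in one call, so it needs the total number of optimization steps up front; when `max_steps` is not set, `calculate_max_steps` derives it from the dataset length and the effective batch size. A quick sanity check of the batch-size arithmetic visible above (only `total_train_batch_size` is taken from the diff; the epoch-based step formula at the end of the function is not shown in this hunk and is an assumption here):

```python
import math

# Worked example of the arithmetic in calculate_max_steps.
per_device_train_batch_size = 2
gradient_accumulation_steps = 16
world_size = 4
total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size  # 128

# Assumed tail of the function: steps ~= ceil(epochs * dataset_len / effective batch size).
len_dataset = 10_000
num_train_epochs = 1
assumed_max_steps = math.ceil(num_train_epochs * len_dataset / total_train_batch_size)
print(total_train_batch_size, assumed_max_steps)  # 128 79
```

The other changes in this file fit the same theme: the callbacks now receive the trainer's `TrainingArguments` directly (decorated with `galore_config` in swift/llm/train/tuner.py above), which is presumably why the `args.training_args` indirection and the streaming assert disappear here, and `create_galore_optimizers` clears `galore_config` afterwards because the trainer cannot serialize it.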
5 changes: 4 additions & 1 deletion swift/trainers/arguments.py
@@ -10,6 +10,7 @@
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments

from swift.utils import use_torchacc
from .optimizers.galore import GaLoreConfig


@dataclass
@@ -28,8 +29,10 @@ class SwiftArgumentsMixin:
fsdp_num: int = 1
acc_steps: int = 1

# Value copied from TrainArguments, Used for external tuners.
# Value copied from TrainArguments
train_type: Optional[str] = None
optimizer: Optional[str] = None
galore_config: Optional[GaLoreConfig] = None

def _fix_gradient_checkpointing(self):
# fix use_reentrant
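`SwiftArgumentsMixin` now carries the resolved optimizer name and, for GaLore, the config object, so the trainer no longer reaches back into `TrainArguments`. The `galore_config` field is also the reason for the reset noted in swift/plugin/optimizer.py above (`# trainer cannot serialize galore_config`): the trainer serializes its arguments (e.g. `TrainingArguments.to_json_string()` when logging or saving), and an arbitrary nested object breaks that. A toy illustration of the failure mode, using a stand-in dataclass rather than the real `GaLoreConfig` fields:

```python
import json
from dataclasses import dataclass


@dataclass
class _ToyConfig:  # stand-in for GaLoreConfig; real fields are not reproduced here
    rank: int = 128


args_dict = {'optimizer': 'galore', 'galore_config': _ToyConfig()}
try:
    json.dumps(args_dict)  # TypeError: _ToyConfig is not JSON serializable
except TypeError as err:
    print('serialization error:', err)

args_dict['galore_config'] = None  # what create_galore_optimizers does once the optimizer exists
print(json.dumps(args_dict))
```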
34 changes: 9 additions & 25 deletions swift/trainers/mixin.py
@@ -21,7 +21,7 @@
from transformers.data.data_collator import DataCollator
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import unwrap_model
from transformers.trainer import Trainer, TrainerCallback
from transformers.trainer import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_torch_npu_available

@@ -32,7 +32,6 @@
from swift.utils import get_logger, is_mp_ddp, use_torchacc
from swift.utils.torchacc_utils import ta_trim_graph
from .arguments import TrainingArguments
from .optimizers.galore import create_optimizer_and_scheduler
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model

try:
@@ -316,32 +315,17 @@ def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

def create_optimizer_and_scheduler(self, num_training_steps: int):
if hasattr(self.args, 'galore_config'):
optimizer, lr_scheduler = create_optimizer_and_scheduler(
self.model,
self.args,
self.args.galore_config,
num_training_steps,
lr=self.args.learning_rate,
weight_decay=self.args.weight_decay)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
if self.args.optimizer is not None:
from swift.plugin import optimizers_map
optimizer_callback = optimizers_map[self.args.optimizer]
self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
if self.optimizer is None:
self.create_optimizer()
if self.lr_scheduler is None:
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
else:
super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)

def create_optimizer(self):

if self.optimizer is None and hasattr(self.model, 'create_optimizer_param_groups'):
# Lora+ parameter groups
optimizer_grouped_parameters = self.model.create_optimizer_param_groups(
lr=self.args.learning_rate, weight_decay=self.args.weight_decay)
if optimizer_grouped_parameters is not None:
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
return self.optimizer

return super().create_optimizer()

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.train_sampler_random:
return super()._get_train_sampler()
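Reading this hunk against the counts in its header (`-316,32 +315,17`): the `hasattr(self.args, 'galore_config')` branch and the whole Lora+ `create_optimizer` override are the removed code, while the `optimizers_map` lookup is what is added, with the `else: super().create_optimizer_and_scheduler(...)` fallback unchanged. A condensed sketch of the retained logic (the rest of the class is elided; `getattr` is used only so the sketch also tolerates a plain HF `TrainingArguments` without the new `optimizer` field):

```python
from transformers import Trainer

from swift.plugin import optimizers_map


class PluginOptimizerTrainer(Trainer):
    """Illustrative stand-in for SwiftMixin; only the optimizer wiring is shown."""

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        if getattr(self.args, 'optimizer', None) is not None:
            # Plugin path: 'default', 'galore', 'lorap', or any other registered key.
            optimizer_callback = optimizers_map[self.args.optimizer]
            self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
            # A callback may return None for either piece to fall back to the
            # stock construction below.
            if self.optimizer is None:
                self.create_optimizer()
            if self.lr_scheduler is None:
                self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
        else:
            super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)
```

The Lora+ parameter-group handling removed from `create_optimizer` here is the logic that now lives in `create_lorap_optimizers` in swift/plugin/optimizer.py.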
3 changes: 0 additions & 3 deletions swift/ui/llm_infer/llm_infer.py
@@ -29,8 +29,6 @@ class LLMInfer(BaseUI):

is_multimodal = True

deployed = False

sub_ui = [Model, Runtime]

locale_dict = {
@@ -279,7 +277,6 @@ def deploy_model(cls, *args):
os.system(run_command)
gr.Info(cls.locale('load_alert', cls.lang)['value'])
time.sleep(2)
cls.deployed = True
running_task = Runtime.refresh_tasks(log_file)
return gr.update(open=True), running_task

4 changes: 2 additions & 2 deletions swift/utils/__init__.py
@@ -2,8 +2,8 @@

from .env import (get_dist_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_dist_ta, is_local_master,
is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph, use_hf_hub, use_torchacc)
from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_merge_kit_available,
is_unsloth_available, is_vllm_available, is_xtuner_available)
from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_unsloth_available,
is_vllm_available, is_xtuner_available)
from .io_utils import (JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, open_jsonl_writer,
read_from_jsonl, write_to_jsonl)
from .logger import get_logger
4 changes: 0 additions & 4 deletions swift/utils/import_utils.py
@@ -16,10 +16,6 @@ def is_vllm_available():
return importlib.util.find_spec('vllm') is not None


def is_merge_kit_available():
return importlib.util.find_spec('mergekit') is not None


def is_lmdeploy_available():
return importlib.util.find_spec('lmdeploy') is not None

