Move optimizer to create_optimizer #2851

Merged: 5 commits, merged Jan 3, 2025
39 changes: 1 addition & 38 deletions swift/llm/argument/merge_args.py
@@ -2,7 +2,7 @@
from dataclasses import dataclass
from typing import Optional

from swift.utils import get_logger, is_merge_kit_available
from swift.utils import get_logger

logger = get_logger()

@@ -16,44 +16,7 @@ class MergeArguments:
merge_lora (bool): Flag to indicate if LoRA merging is enabled. Default is False.
safe_serialization(bool): Use safetensors or not, default `True`.
max_shard_size(str): The max size of single shard file.
use_merge_kit (bool): Flag to indicate merge with `mergekit`. Default is False.
instruct_model (Optional[str]): Path or ID of the instruct model. Use when `use_merge_kit` is True.
instruct_model_revision (Optional[str]): Revision of the instruct model. Use when `use_merge_kit` is True.
"""
merge_lora: bool = False
safe_serialization: bool = True
max_shard_size: str = '5GB'

use_merge_kit: bool = False
instruct_model: Optional[str] = None
instruct_model_revision: Optional[str] = None

def __post_init__(self):
if self.use_merge_kit:
assert is_merge_kit_available(), ('please install mergekit by pip install '
'git+https://github.com/arcee-ai/mergekit.git')
logger.info('Important: You are using mergekit, please remember '
'the LoRA should be trained against the base model,'
'and pass its instruct model by --instruct_model xxx when merging')
assert self.instruct_model, 'Please pass in the instruct model'

self.merge_yaml = """
models:
- model: {merged_model}
parameters:
weight: 1
density: 1
- model: {instruct_model}
parameters:
weight: 1
density: 1
merge_method: ties
base_model: {base_model}
parameters:
weight: 1
density: 1
normalize: true
int8_mask: true
tokenizer_source: {merged_model}
dtype: bfloat16
"""
7 changes: 7 additions & 0 deletions swift/llm/argument/train_args.py
@@ -136,6 +136,13 @@ def __post_init__(self) -> None:
TunerArguments.__post_init__(self)
TorchAccArguments.__post_init__(self)

if self.lorap_lr_ratio:
self.optimizer = 'lorap'
elif self.use_galore:
self.optimizer = 'galore'
elif self.optimizer is None:
self.optimizer = 'default'

if len(self.dataset) == 0:
raise ValueError(f'self.dataset: {self.dataset}, Please input the training dataset.')

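The hunk above moves optimizer selection into `TrainArguments.__post_init__`. Purely as a readability aid, here is the same precedence as a standalone helper (this function does not exist in the repository; only the field names come from the diff):

from typing import Optional

def resolve_optimizer_name(lorap_lr_ratio: Optional[float], use_galore: bool,
                           optimizer: Optional[str]) -> str:
    # Mirrors the precedence above: LoRA+ first, then GaLore,
    # then an explicit optimizer name, falling back to 'default'.
    if lorap_lr_ratio:
        return 'lorap'
    if use_galore:
        return 'galore'
    return optimizer if optimizer is not None else 'default'

assert resolve_optimizer_name(16.0, True, None) == 'lorap'
assert resolve_optimizer_name(None, True, None) == 'galore'
assert resolve_optimizer_name(None, False, None) == 'default'
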
24 changes: 0 additions & 24 deletions swift/llm/export/merge_lora.py
@@ -25,14 +25,6 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False)
origin_device_map = args.device_map
args.device_map = device_map or args.device_map
logger.info(f'merge_device_map: {device_map}')
if args.use_merge_kit:
base_model = args.model
if not os.path.exists(base_model):
base_model = args.hub.download_model(base_model, revision=args.model_revision)
if not os.path.exists(args.instruct_model):
args.instruct_model = args.hub.download_model(
args.instruct_model, revision=args.instruct_model_revision)
args.model = args.instruct_model
model, template = prepare_model_template(args)
logger.info('Merge LoRA...')
Swift.merge_and_unload(model)
@@ -52,19 +44,3 @@ def merge_lora(args: ExportArguments, device_map=None, replace_if_exists=False)

args.model = output_dir
args.adapters = []
if args.use_merge_kit:
tempdir = tempfile.gettempdir()
mergekit_path = os.path.join(output_dir, 'mergekit')
merge_yaml = args.merge_yaml.replace('{merged_model}', output_dir).replace('{instruct_model}',
args.instruct_model).replace(
'{base_model}', base_model)
try:
yamlfile = os.path.join(tempdir, 'mergekit.yaml')
with open(yamlfile, 'w', encoding='utf-8') as f:
f.write(merge_yaml)
logger.info(f'Merging with config: {merge_yaml}')
os.system(f'mergekit-yaml {yamlfile} {mergekit_path}')
logger.info(f'Merge complete with path: {mergekit_path}')
finally:
if tempdir:
shutil.rmtree(tempdir, ignore_errors=True)
15 changes: 0 additions & 15 deletions swift/llm/train/sft.py
@@ -122,8 +122,6 @@ def run(self):
self.train_msg['model_parameter_info'] = model_parameter_info
logger.info(f'model_parameter_info: {model_parameter_info}')

optimizers = self._get_optimizers(train_dataset)

trainer_cls = TrainerFactory.get_trainer_cls(args)
trainer = trainer_cls(
model=self.model,
@@ -132,7 +130,6 @@ def run(self):
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks=self.callbacks,
optimizers=optimizers,
template=self.template,
**self._get_trainer_kwargs(),
)
@@ -192,18 +189,6 @@ def train(self, trainer):

return self._save_trainer_state(trainer)

def _get_optimizers(self, train_dataset):
args = self.args
if args.lorap_lr_ratio:
optimizer_callback = optimizers_map['lorap']
elif args.use_galore:
optimizer_callback = optimizers_map['galore']
elif args.optimizer is not None:
optimizer_callback = optimizers_map[args.optimizer]
else:
optimizer_callback = optimizers_map['default']
return optimizer_callback(args, self.model, train_dataset)

def _prepare_callbacks(self):
from .callback import DynamicLayerActivationCallback, TrainerAdapterCallback
args = self.args
1 change: 1 addition & 0 deletions swift/llm/train/tuner.py
@@ -403,6 +403,7 @@ def prepare_model(
gamma_proj=args.galore_gamma_proj,
queue_size=args.galore_queue_size,
)
args.training_args.galore_config = args.galore_config

if args.sequence_parallel_size > 1:
from swift.trainers.xtuner import dispatch_module_xtuner
14 changes: 5 additions & 9 deletions swift/plugin/optimizer.py
@@ -11,7 +11,6 @@ def calculate_max_steps(args: 'TrainArguments', dataset) -> int:
if args.max_steps and args.max_steps > 0:
max_steps = args.max_steps
else:
assert not args.streaming
len_dataset = len(dataset)
_, _, world_size, _ = get_dist_setting()
total_train_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * world_size
@@ -23,17 +22,14 @@ def calculate_max_steps(args: 'TrainArguments', dataset) -> int:

def create_galore_optimizers(args, model, dataset):
training_steps = calculate_max_steps(args, dataset)
return create_optimizer_and_scheduler(
model,
args.training_args,
args.galore_config,
training_steps,
lr=args.learning_rate,
weight_decay=args.weight_decay)
optimizer, lr_scheduler = create_optimizer_and_scheduler(
model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
# trainer cannot serialize galore_config
args.galore_config = None
return optimizer, lr_scheduler


def create_lorap_optimizers(args, model, dataset):
args = args.training_args
optimizer_grouped_parameters = None
if hasattr(model, 'create_optimizer_param_groups'):
# Lora+ parameter groups
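For orientation, the plugin callbacks in this file follow a single convention: each entry in `optimizers_map` receives `(args, model, dataset)` and returns an `(optimizer, lr_scheduler)` pair, where `None` for either element means the trainer builds that part itself (see the `swift/trainers/mixin.py` change below). A minimal custom entry sketched under that assumption; the `my_adamw` name and the plain AdamW choice are illustrative, not part of this PR:

import torch
from swift.plugin import optimizers_map

def create_my_adamw_optimizers(args, model, dataset):
    # Build only the optimizer; returning None for the scheduler lets the
    # trainer fall back to its own create_scheduler().
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params, lr=args.learning_rate, weight_decay=args.weight_decay)
    return optimizer, None

# Assuming optimizers_map is a plain dict, assigning the callback makes it
# selectable by name, e.g. optimizer='my_adamw' in the training arguments.
optimizers_map['my_adamw'] = create_my_adamw_optimizers
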
5 changes: 4 additions & 1 deletion swift/trainers/arguments.py
@@ -10,6 +10,7 @@
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments as HfSeq2SeqTrainingArguments

from swift.utils import use_torchacc
from .optimizers.galore import GaLoreConfig


@dataclass
Expand All @@ -28,8 +29,10 @@ class SwiftArgumentsMixin:
fsdp_num: int = 1
acc_steps: int = 1

# Value copied from TrainArguments, Used for external tuners.
# Value copied from TrainArguments
train_type: Optional[str] = None
optimizer: Optional[str] = None
galore_config: Optional[GaLoreConfig] = None

def _fix_gradient_checkpointing(self):
# fix use_reentrant
34 changes: 9 additions & 25 deletions swift/trainers/mixin.py
@@ -21,7 +21,7 @@
from transformers.data.data_collator import DataCollator
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import unwrap_model
from transformers.trainer import Trainer, TrainerCallback
from transformers.trainer import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_torch_npu_available

@@ -32,7 +32,6 @@
from swift.utils import get_logger, is_mp_ddp, use_torchacc
from swift.utils.torchacc_utils import ta_trim_graph
from .arguments import TrainingArguments
from .optimizers.galore import create_optimizer_and_scheduler
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model

try:
@@ -316,32 +315,17 @@ def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

def create_optimizer_and_scheduler(self, num_training_steps: int):
if hasattr(self.args, 'galore_config'):
optimizer, lr_scheduler = create_optimizer_and_scheduler(
self.model,
self.args,
self.args.galore_config,
num_training_steps,
lr=self.args.learning_rate,
weight_decay=self.args.weight_decay)
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
if self.args.optimizer is not None:
from swift.plugin import optimizers_map
optimizer_callback = optimizers_map[self.args.optimizer]
self.optimizer, self.lr_scheduler = optimizer_callback(self.args, self.model, self.train_dataset)
if self.optimizer is None:
self.create_optimizer()
if self.lr_scheduler is None:
self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)
else:
super().create_optimizer_and_scheduler(num_training_steps=num_training_steps)

def create_optimizer(self):

if self.optimizer is None and hasattr(self.model, 'create_optimizer_param_groups'):
# Lora+ parameter groups
optimizer_grouped_parameters = self.model.create_optimizer_param_groups(
lr=self.args.learning_rate, weight_decay=self.args.weight_decay)
if optimizer_grouped_parameters is not None:
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
return self.optimizer

return super().create_optimizer()

def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.train_sampler_random:
return super()._get_train_sampler()
3 changes: 0 additions & 3 deletions swift/ui/llm_infer/llm_infer.py
@@ -29,8 +29,6 @@ class LLMInfer(BaseUI):

is_multimodal = True

deployed = False

sub_ui = [Model, Runtime]

locale_dict = {
@@ -279,7 +277,6 @@ def deploy_model(cls, *args):
os.system(run_command)
gr.Info(cls.locale('load_alert', cls.lang)['value'])
time.sleep(2)
cls.deployed = True
running_task = Runtime.refresh_tasks(log_file)
return gr.update(open=True), running_task

4 changes: 2 additions & 2 deletions swift/utils/__init__.py
@@ -2,8 +2,8 @@

from .env import (get_dist_setting, get_pai_tensorboard_dir, is_deepspeed_enabled, is_dist, is_dist_ta, is_local_master,
is_master, is_mp, is_mp_ddp, is_pai_training_job, torchacc_trim_graph, use_hf_hub, use_torchacc)
from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_merge_kit_available,
is_unsloth_available, is_vllm_available, is_xtuner_available)
from .import_utils import (is_liger_available, is_lmdeploy_available, is_megatron_available, is_unsloth_available,
is_vllm_available, is_xtuner_available)
from .io_utils import (JsonlWriter, append_to_jsonl, download_ms_file, get_file_mm_type, open_jsonl_writer,
read_from_jsonl, write_to_jsonl)
from .logger import get_logger
4 changes: 0 additions & 4 deletions swift/utils/import_utils.py
@@ -16,10 +16,6 @@ def is_vllm_available():
return importlib.util.find_spec('vllm') is not None


def is_merge_kit_available():
return importlib.util.find_spec('mergekit') is not None


def is_lmdeploy_available():
return importlib.util.find_spec('lmdeploy') is not None
