Support peft0.14 #2587

Merged: 5 commits, Dec 8, 2024
2 changes: 1 addition & 1 deletion docs/source/Instruction/命令行参数.md
@@ -127,6 +127,7 @@
- 🔥freeze_llm: Freeze the LLM. Default is False. Applicable to full-parameter training and LoRA.
- 🔥target_modules: Specify the LoRA modules. Default is `all-linear`, which automatically finds all linear layers except lm_head and attaches the tuner. This parameter is not limited to LoRA.
- 🔥target_regex: Specify a regex expression for the LoRA modules. Default is `None`; if provided, target_modules does not take effect. This parameter is not limited to LoRA.
- 🔥init_weights: Method for initializing weights. For LoRA the accepted values are `true`, `false`, `gaussian`, `pissa`, `pissa_niter_[number of iters]`; for Bone they are `true`, `false`, `bat`. Default is `true`.
- modules_to_save: Modules of the original model that participate in training and are saved after the tuner is attached. Default is `[]`. This parameter is not limited to LoRA.

#### Full Parameters
@@ -138,7 +139,6 @@
- 🔥lora_rank: Default is `8`.
- 🔥lora_alpha: Default is `32`.
- lora_dropout: Default is `0.05`.
- 🔥init_lora_weights: Method for initializing LoRA weights. Can be `true`, `false`, `gaussian`, `pissa`, `pissa_niter_[number of iters]`. Default is `true`.
- lora_bias: Default is `'none'`. Available values: 'none', 'all'. To make all biases trainable, set it to `'all'`.
- lora_dtype: Specify the dtype of the LoRA modules. Supports 'float16', 'bfloat16', 'float32'; if unset, it follows the original model's dtype.
- 🔥use_dora: Default is `False`; whether to use `DoRA`.
2 changes: 1 addition & 1 deletion docs/source_en/Instruction/Command-line-parameters.md
@@ -130,6 +130,7 @@ Other important parameters:
- 🔥freeze_llm: Freeze LLM. Default is False. Applicable for full parameters and LoRA.
- 🔥target_modules: Specify the LoRA module, default is `all-linear`, automatically finds linear layers except for lm_head and attaches the tuner. This parameter is not limited to LoRA.
- 🔥target_regex: Specify a regex expression for the LoRA module. Default is `None`, if this value is provided, target_modules does not take effect. This parameter is not limited to LoRA.
- 🔥init_weights: Method for initializing tuner weights. For LoRA the accepted values are `true`, `false`, `gaussian`, `pissa`, `pissa_niter_[number of iters]`; for Bone they are `true`, `false`, `bat`. Default is `true` (see the illustrative sketch after this file's diff).
- modules_to_save: After the tuner is attached, the original model's modules used during training and storage, default is `[]`. This parameter is not limited to LoRA.

#### Full Arguments
@@ -143,7 +144,6 @@
- 🔥lora_rank: Default is `8`.
- 🔥lora_alpha: Default is `32`.
- lora_dropout: Default is `0.05`.
- 🔥init_lora_weights: Method to initialize LoRA weights, can be specified as `true`, `false`, `gaussian`, `pissa`, `pissa_niter_[number of iters]`, default is `true`.
- lora_bias: Default is `'none'`, selectable values are: 'none', 'all'. If you want to set all biases as trainable, you can set it to `'all'`.
- lora_dtype: Specify the dtype of the LoRA module. Supports 'float16', 'bfloat16', 'float32', defaults to the original model type.
- 🔥use_dora: Default is `False`, whether to use `DoRA`.
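As a hedged illustration of the unified `init_weights` argument (not part of this PR's diff): swift forwards it to peft as `init_lora_weights` for LoRA and as `init_weights` for Bone, as the `prepare_adapter` change later in this diff shows. The target module list is an arbitrary example; `BoneConfig` assumes `peft>=0.14.0`.

```python
# Minimal sketch of what the documented `init_weights` values map to in peft.
# Assumes peft >= 0.14.0 (BoneConfig); the target_modules list is illustrative only.
from peft import BoneConfig, LoraConfig

# LoRA: string initializers such as 'gaussian' or 'pissa' (or booleans true/false).
lora_cfg = LoraConfig(r=8, lora_alpha=32,
                      target_modules=['q_proj', 'v_proj'],
                      init_lora_weights='pissa')

# Bone: booleans or 'bat' (block-affine transformation).
bone_cfg = BoneConfig(r=8,
                      target_modules=['q_proj', 'v_proj'],
                      init_weights='bat')
```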
17 changes: 17 additions & 0 deletions examples/train/tuners/bone/train.sh
@@ -0,0 +1,17 @@
# 17.3GiB
CUDA_VISIBLE_DEVICES=0 \
swift sft \
--model Qwen/Qwen2.5-7B-Instruct \
--train_type bone \
--label_names labels \
--dataset swift/self-cognition#1000 \
--num_train_epochs 1 \
--per_device_train_batch_size 1 \
--learning_rate 1e-4 \
--gradient_accumulation_steps 16 \
--eval_steps 100 \
--save_steps 100 \
--save_total_limit 2 \
--logging_steps 5 \
--model_author swift \
--model_name swift-robot
2 changes: 1 addition & 1 deletion requirements/framework.txt
@@ -14,7 +14,7 @@ nltk
numpy<2.0
oss2
pandas
peft>=0.11.0,<0.14.0
peft>=0.11.0,<0.15.0
pillow
requests
rouge
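Since the lower bound stays at `peft>=0.11.0` while Bone only exists in newer peft releases, a runtime guard is one way to fail early. This is a hedged sketch, not code from the PR (the PR instead imports `BoneModel`/`BoneConfig` lazily):

```python
# Hypothetical runtime guard consistent with the loosened constraint peft>=0.11.0,<0.15.0.
import peft
from packaging import version

if version.parse(peft.__version__) < version.parse('0.14.0'):
    raise RuntimeError('The bone tuner requires peft>=0.14.0; please upgrade peft.')
```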
4 changes: 2 additions & 2 deletions swift/llm/argument/base_args/base_args.py
@@ -22,8 +22,8 @@


def get_supported_tuners():
return {'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft'} | set(
extra_tuners.keys())
return {'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter', 'vera', 'boft', 'fourierft', 'reft', 'bone'
} | set(extra_tuners.keys())


@dataclass
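A small hedged sketch of how the extended tuner set might be consumed; `validate_train_type` is a hypothetical helper, not the repository's actual validation code:

```python
# Illustrative only: reject unknown tuner names using the supported-tuner set above.
def validate_train_type(train_type: str, supported: set) -> None:
    if train_type not in supported:
        raise ValueError(f"Unsupported train_type: {train_type!r}; supported: {sorted(supported)}")

supported = {'lora', 'full', 'longlora', 'adalora', 'llamapro', 'adapter',
             'vera', 'boft', 'fourierft', 'reft', 'bone'}
validate_train_type('bone', supported)  # passes now that 'bone' is registered
```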
13 changes: 7 additions & 6 deletions swift/llm/argument/tuner_args.py
@@ -28,8 +28,7 @@ class TunerArguments:
lorap_lr_ratio (float): Learning rate ratio for LoRA. Default is None.
use_rslora (bool): Flag to indicate if RSLora is used. Default is False.
use_dora (bool): Flag to indicate if Dora is used. Default is False.
init_lora_weights (str): Initialization method for LoRA weights. Default is 'true'.
Allowed values are 'gaussian', 'pissa', 'pissa_niter_[number of iters]', 'olora', 'loftq', 'true', 'false'.
init_weights (str): Initialization method for weights of supported tuners. Default is 'true'.

fourier_n_frequency (int): Number of frequencies for FourierFT. Default is 2000.
fourier_scaling (float): Scaling factor for FourierFT. Default is 300.0.
@@ -110,8 +109,10 @@ class TunerArguments:
lorap_lr_ratio: Optional[float] = None
use_rslora: bool = False
use_dora: bool = False
# Literal['gaussian', 'pissa', 'pissa_niter_[number of iters]', 'olora', 'loftq', 'true', 'false']
init_lora_weights: str = 'true'
# Lora: Literal['gaussian', 'pissa', 'pissa_niter_[number of iters]', 'olora', 'loftq', 'true', 'false']

# Bone: Literal['bat', 'true', 'false']
init_weights: str = 'true'

# fourierft
fourier_n_frequency: int = 2000
@@ -181,8 +182,8 @@ class TunerArguments:
use_liger: bool = False

def __post_init__(self):
if isinstance(self.init_lora_weights, str) and self.init_lora_weights.lower() in {'true', 'false'}:
self.init_lora_weights = bool(strtobool(self.init_lora_weights))
if isinstance(self.init_weights, str) and self.init_weights.lower() in {'true', 'false'}:
self.init_weights = bool(strtobool(self.init_weights))
self._init_multimodal_full()
if self.target_regex:
self.target_modules = self.target_regex
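The `__post_init__` change above normalizes the string flag. A standalone sketch of the same rule (hypothetical helper name, same semantics as the `strtobool` call):

```python
# 'true'/'false' (any casing) become booleans; tuner-specific strings such as
# 'pissa' or 'bat' are passed through unchanged.
def normalize_init_weights(value):
    if isinstance(value, str) and value.lower() in {'true', 'false'}:
        return value.lower() == 'true'
    return value

assert normalize_init_weights('False') is False
assert normalize_init_weights('bat') == 'bat'
```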
3 changes: 3 additions & 0 deletions swift/llm/infer/utils.py
@@ -147,6 +147,9 @@ def _prepare_pt_engine(args: InferArguments, pt_engine):
pt_engine.processor = processor
else:
pt_engine.model = Swift.from_pretrained(pt_engine.model, args.ckpt_dir, inference_mode=True)
if args.train_type == 'bone':
# In `peft==0.14.0`, Bone has an issue with float32 matmul against bfloat16 weights
pt_engine.model.to(pt_engine.model.dtype)


def prepare_pt_engine_template(args: InferArguments, load_model: bool = True, **kwargs) -> Tuple[PtEngine, Template]:
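A self-contained sketch (an assumption, not the PR's code) of why the extra `.to()` call helps: with `peft==0.14.0`, Bone adapter weights can end up in float32 while the base model runs in bfloat16, and casting the wrapped model to the base dtype realigns them.

```python
# Toy stand-in for a bfloat16 base model plus a float32 adapter module.
import torch
import torch.nn as nn

base = nn.Linear(16, 16, dtype=torch.bfloat16)
adapter = nn.Linear(16, 16)  # float32 by default
wrapped = nn.Sequential(base, adapter)

# Mirrors `pt_engine.model.to(pt_engine.model.dtype)`: cast everything to the base dtype.
wrapped.to(base.weight.dtype)
assert all(p.dtype == torch.bfloat16 for p in wrapped.parameters())
```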
33 changes: 26 additions & 7 deletions swift/llm/train/tuner.py
@@ -18,17 +18,26 @@
def apply_liger(model_type: str):
from liger_kernel.transformers import (apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral, apply_liger_kernel_to_gemma,
apply_liger_kernel_to_qwen2)
if 'llama3' in model_type:
apply_liger_kernel_to_qwen2, apply_liger_kernel_to_qwen2_vl,
apply_liger_kernel_to_gemma2, apply_liger_kernel_to_phi3,
apply_liger_kernel_to_mllama)
from swift.llm import ModelType
if model_type in (ModelType.llama, ModelType.llama3, ModelType.llama3_1, ModelType.llama3_2):
apply_liger_kernel_to_llama()
elif 'mistral' in model_type:
elif model_type in (ModelType.mistral):
apply_liger_kernel_to_mistral()
elif 'mixtral' in model_type:
elif model_type in (ModelType.mixtral):
apply_liger_kernel_to_mixtral()
elif 'gemma' in model_type:
elif model_type in (ModelType.gemma):
apply_liger_kernel_to_gemma()
elif 'qwen2' in model_type:
elif model_type in (ModelType.gemma2):
apply_liger_kernel_to_qwen2()
elif model_type in (ModelType.phi3):
apply_liger_kernel_to_phi3()
elif model_type in (ModelType.llama3_2_vision):
apply_liger_kernel_to_mllama()
elif model_type in (ModelType.qwen2_vl):
apply_liger_kernel_to_qwen2_vl()
else:
raise ValueError(f'Unsupported liger model_type: {model_type}')

@@ -111,7 +120,7 @@ def prepare_adapter(args: TrainArguments, model):
'use_rslora': args.use_rslora,
'use_dora': args.use_dora,
'lorap_lr_ratio': args.lorap_lr_ratio,
'init_lora_weights': args.init_lora_weights,
'init_lora_weights': args.init_weights,
}

if args.train_type in ('lora', 'longlora'):
@@ -224,6 +233,16 @@ def prepare_adapter(args: TrainArguments, model):
)
logger.info(f'reft config: {reft_config}')
model = Swift.prepare_model(model, {'reft': reft_config})
elif args.train_type == 'bone':
# Lazy import so that older peft versions without Bone still work (Bone needs peft>=0.14.0)
from peft import BoneConfig
bone_config = BoneConfig(
target_modules=target_modules,
r=args.reft_rank,
init_weights=args.init_weights,
)
logger.info(f'bone config: {bone_config}')
model = Swift.prepare_model(model, bone_config)
return model


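For context, this is a hedged, minimal way to attach Bone directly with peft, mirroring what the new `bone` branch does through `Swift.prepare_model`; the model name and target modules are illustrative assumptions.

```python
# Assumes peft >= 0.14.0 and a small Hugging Face causal LM with q_proj/v_proj layers.
from peft import BoneConfig, get_peft_model
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-0.5B-Instruct')
config = BoneConfig(r=8, target_modules=['q_proj', 'v_proj'], init_weights='bat')
model = get_peft_model(model, config)
model.print_trainable_parameters()  # only the Bone parameters should be trainable
```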
11 changes: 8 additions & 3 deletions swift/tuners/lora_layers.py
@@ -134,7 +134,6 @@ def dispatch_bnb_8bit(target: torch.nn.Module, adapter_name: str, module_key: st
eightbit_kwargs = kwargs.copy()
eightbit_kwargs.update({
'has_fp16_weights': target.state.has_fp16_weights,
'memory_efficient_backward': target.state.memory_efficient_backward,
'threshold': target.state.threshold,
'index': target.index,
})
@@ -590,7 +589,11 @@ def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
else:
raise NotImplementedError(f'Requested bias: {bias}, is not implemented.')

def inject_adapter(self, model: nn.Module, adapter_name: str):
def inject_adapter(self,
model: nn.Module,
adapter_name: str,
autocast_adapter_dtype: bool = True,
low_cpu_mem_usage: bool = False):
r"""
Override code:
1. ModulesToSaveWrapper construction method: add module_key=key argument to offload to cpu
@@ -789,13 +792,15 @@ def _replace_module(self, parent, child_name, new_module, child):
new_module.state = child.state
new_module.to(child.weight.device)

meta = torch.device('meta')
# dispatch to correct device
for name, module in new_module.named_modules():
if (self.prefix in name) or ('ranknum' in name):
weight = (
child.qweight if hasattr(child, 'qweight') else child.W_q if hasattr(child, 'W_q') else
child.weight if hasattr(child, 'weight') else next(child.parameters()))
module.to(weight.device)
if not any(p.device == meta for p in module.parameters()):
module.to(weight.device)

@staticmethod
def _create_new_module(lora_config, adapter_name, target, **kwargs):
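The new guard above skips `.to()` for modules whose parameters still live on the meta device, since they have no real storage to move. A standalone sketch of the same check, assuming PyTorch 2.x for the `torch.device('meta')` context manager:

```python
import torch
import torch.nn as nn

meta = torch.device('meta')
with torch.device('meta'):
    lazy = nn.Linear(8, 8)   # parameters allocated on the meta device (no storage)
real = nn.Linear(8, 8)       # ordinary CPU parameters

for module in (lazy, real):
    # Moving a meta-device module would fail, so it is skipped, mirroring _replace_module.
    if not any(p.device == meta for p in module.parameters()):
        module.to('cpu')
```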
15 changes: 10 additions & 5 deletions swift/tuners/peft.py
@@ -12,8 +12,8 @@
import torch.nn
import transformers
from modelscope import snapshot_download
from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, IA3Config, IA3Model, LoftQConfig, LoHaConfig, LoKrConfig,
LoraModel, OFTConfig, PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
from peft import (AdaLoraConfig, BOFTConfig, BOFTModel, LoftQConfig, LoHaConfig, LoKrConfig, LoraModel, OFTConfig,
PeftConfig, PeftModel, PeftModelForCausalLM, PeftModelForSeq2SeqLM,
PeftModelForSequenceClassification, PeftModelForTokenClassification, PrefixTuningConfig,
PromptEncoderConfig, PromptLearningConfig, PromptTuningConfig, VeraConfig, VeraModel, get_peft_config,
get_peft_model, get_peft_model_state_dict)
@@ -28,6 +28,11 @@
except ImportError:
FourierFTModel = None

try:
from peft import BoneModel
except ImportError:
BoneModel = None

logger = get_logger()
dispatchers = []

@@ -280,11 +285,12 @@ def hot_patch_peft_module():
VeraModel._create_and_replace = _create_and_replace_hook
BOFTModel._create_and_replace_origin = BOFTModel._create_and_replace
BOFTModel._create_and_replace = _create_and_replace_hook
IA3Model._create_and_replace_origin = IA3Model._create_and_replace
IA3Model._create_and_replace = _create_and_replace_hook
if FourierFTModel is not None:
FourierFTModel._create_and_replace_origin = FourierFTModel._create_and_replace
FourierFTModel._create_and_replace = _create_and_replace_hook
if BoneModel is not None:
BoneModel._create_and_replace_origin = BoneModel._create_and_replace
BoneModel._create_and_replace = _create_and_replace_hook

# Support type conversion
def __new_init__(self, model: torch.nn.Module, config: Dict[str, LoraConfig], adapter_name: str):
@@ -367,7 +373,6 @@ def wrap_module(module):
PromptLearningConfig = wrap_module(PromptLearningConfig)
LoraConfig = wrap_module(LoraConfig)
AdaLoraConfig = wrap_module(AdaLoraConfig)
IA3Config = wrap_module(IA3Config)
LoHaConfig = wrap_module(LoHaConfig)
LoKrConfig = wrap_module(LoKrConfig)
LoftQConfig = wrap_module(LoftQConfig)
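The hot-patch above follows a simple pattern: keep a reference to the original peft method, then install a wrapper that can delegate back to it; `BoneModel` is patched only when the installed peft version provides it. A generic, self-contained sketch of that pattern (toy class, not peft's API):

```python
class ToyModel:
    def _create_and_replace(self, *args, **kwargs):
        return 'original behaviour'

def _create_and_replace_hook(self, *args, **kwargs):
    # custom pre/post-processing would go here
    return self._create_and_replace_origin(*args, **kwargs)

# Same move as hot_patch_peft_module(): save the original, then swap in the hook.
ToyModel._create_and_replace_origin = ToyModel._create_and_replace
ToyModel._create_and_replace = _create_and_replace_hook

assert ToyModel()._create_and_replace() == 'original behaviour'
```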
4 changes: 2 additions & 2 deletions swift/ui/llm_train/lora.py
@@ -63,7 +63,7 @@ class LoRA(BaseUI):
'en': 'The dtype of lora parameters'
}
},
'init_lora_weights': {
'init_weights': {
'label': {
'zh': 'lora初始化方法',
'en': 'init lora weights'
@@ -99,4 +99,4 @@ def do_build_ui(cls, base_tab: Type['BaseUI']):
gr.Textbox(elem_id='lorap_lr_ratio', scale=2)
gr.Checkbox(elem_id='use_rslora', scale=2)
gr.Checkbox(elem_id='use_dora', scale=2)
gr.Textbox(elem_id='init_lora_weights', scale=4)
gr.Textbox(elem_id='init_weights', scale=4)