refactor model_dtype, fix PPO trainer
hiyouga committed Oct 11, 2023
1 parent 5310e4d commit 2818af0
Showing 10 changed files with 103 additions and 118 deletions.
28 changes: 28 additions & 0 deletions src/llmtuner/extras/misc.py
@@ -3,6 +3,19 @@
from typing import TYPE_CHECKING, Tuple
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

try:
from transformers.utils import (
is_torch_bf16_cpu_available,
is_torch_bf16_gpu_available,
is_torch_cuda_available,
is_torch_npu_available
)
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
    _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available()
except ImportError:
_is_fp16_available = torch.cuda.is_available()
_is_bf16_available = torch.cuda.is_bf16_supported()

if TYPE_CHECKING:
from transformers.modeling_utils import PreTrainedModel

@@ -49,7 +62,22 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
return trainable_params, all_param


def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
r"""
Infers the optimal dtype according to the model_dtype and device compatibility.
"""
if _is_bf16_available and model_dtype == torch.bfloat16:
return torch.bfloat16
elif _is_fp16_available:
return torch.float16
else:
return torch.float32


def get_logits_processor() -> LogitsProcessorList:
r"""
Gets logits processor that removes NaN and Inf logits.
"""
logits_processor = LogitsProcessorList()
logits_processor.append(InfNanRemoveLogitsProcessor())
return logits_processor
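
As a quick illustration of how the new helper resolves dtypes (the result depends on the device the checks run on):

import torch

from llmtuner.extras.misc import infer_optim_dtype

# A bf16 checkpoint stays in bf16 when the device supports it; otherwise the
# helper falls back to fp16 (CUDA/NPU) or fp32 (plain CPU).
print(infer_optim_dtype(model_dtype=torch.bfloat16))  # e.g. torch.bfloat16 on Ampere+
print(infer_optim_dtype(model_dtype=torch.float32))   # e.g. torch.float16 on CUDA
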
8 changes: 4 additions & 4 deletions src/llmtuner/extras/patches/llama_patch.py
@@ -138,11 +138,11 @@ def forward(
input_dtype = query_states.dtype
if input_dtype == torch.float32:
logger.warning_once("The input hidden states seems to be silently casted in float32.")
query_states = query_states.to(torch.float16)
key_states = key_states.to(torch.float16)
value_states = value_states.to(torch.float16)
query_states = query_states.to(self.config.torch_dtype)
key_states = key_states.to(self.config.torch_dtype)
value_states = value_states.to(self.config.torch_dtype)

if getattr(self, "num_key_value_groups"):
if getattr(self, "num_key_value_groups", None):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

6 changes: 3 additions & 3 deletions src/llmtuner/hparams/model_args.py
@@ -67,9 +67,9 @@ class ModelArguments:
default=None,
metadata={"help": "Auth token to log in with Hugging Face Hub."}
)
layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field(
default="auto",
metadata={"help": "Data type of the layer norm weights."}
upcast_layernorm: Optional[bool] = field(
default=False,
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
)

def __post_init__(self):
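
A minimal sketch of the renamed option in use; the checkpoint path is a placeholder and every other field keeps its default:

from llmtuner.hparams import ModelArguments

model_args = ModelArguments(
    model_name_or_path="path/to/llama-2-7b",  # hypothetical local path
    upcast_layernorm=True,  # replaces the old layernorm_dtype="fp32" setting
)
print(model_args.upcast_layernorm)  # True
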
21 changes: 12 additions & 9 deletions src/llmtuner/tuner/core/loader.py
@@ -24,7 +24,7 @@
from transformers.deepspeed import is_deepspeed_zero3_enabled

from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
from llmtuner.extras.patches import llama_patch as LlamaPatches
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
@@ -86,11 +86,17 @@ def load_model_and_tokenizer(
if getattr(config, "model_type", None) == "chatglm":
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)

# Set model dtype
if model_args.compute_dtype is not None:
setattr(config, "torch_dtype", model_args.compute_dtype)
else: # priority: bf16 > fp16 > fp32
optim_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
setattr(config, "torch_dtype", optim_dtype)

# Fix config (for Qwen)
if getattr(config, "model_type", None) == "qwen":
setattr(config, "fp16", model_args.compute_dtype == torch.float16)
setattr(config, "bf16", model_args.compute_dtype == torch.bfloat16)
setattr(config, "fp32", model_args.compute_dtype == torch.float32)
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)

# Set RoPE scaling
if model_args.rope_scaling is not None:
@@ -131,9 +137,7 @@
if model_args.flash_attn:
if getattr(config, "model_type", None) == "llama":
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
LlamaModule.LlamaModel._prepare_decoder_attention_mask = (
LlamaPatches._prepare_decoder_attention_mask
)
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
logger.info("Using FlashAttention-2 for faster training and inference.")
elif getattr(config, "model_type", None) == "qwen":
logger.info("Qwen models automatically enable FlashAttention if installed.")
@@ -180,7 +184,6 @@
model = AutoModelForCausalLM.from_pretrained(
model_to_load,
config=config,
torch_dtype=model_args.compute_dtype,
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
**config_kwargs
)
@@ -203,7 +206,7 @@

# Initialize adapters
if is_trainable:
model = prepare_model_for_training(model, model_args.layernorm_dtype, finetuning_args.finetuning_type)
model = prepare_model_for_training(model, model_args.upcast_layernorm, finetuning_args.finetuning_type)
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
model = model.train() if is_trainable else model.eval()
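
The net effect of the new dtype block, written out as standalone logic (a sketch only; in the loader the result is stored on the Hugging Face config object):

import torch

from llmtuner.extras.misc import infer_optim_dtype

def resolve_torch_dtype(compute_dtype, config_dtype):
    """An explicit --fp16/--bf16 choice wins; otherwise pick the best dtype
    the device supports given the checkpoint's own dtype."""
    if compute_dtype is not None:
        return compute_dtype
    return infer_optim_dtype(model_dtype=config_dtype)

# The Qwen-specific booleans are then derived from the resolved dtype:
torch_dtype = resolve_torch_dtype(None, torch.bfloat16)
qwen_flags = {name: torch_dtype == dtype
              for name, dtype in [("fp16", torch.float16),
                                  ("bf16", torch.bfloat16),
                                  ("fp32", torch.float32)]}
print(torch_dtype, qwen_flags)  # e.g. torch.bfloat16 with only "bf16" set on capable hardware
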

55 changes: 7 additions & 48 deletions src/llmtuner/tuner/core/parser.py
@@ -8,16 +8,6 @@
from transformers.utils.versions import require_version
from transformers.trainer_utils import get_last_checkpoint

try:
from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available, is_torch_cuda_available
is_fp16_available = is_torch_cuda_available()
is_bf16_available = is_torch_bf16_gpu_available()
is_npu_available = is_torch_npu_available()
except ImportError:
is_fp16_available = torch.cuda.is_available()
is_bf16_available = torch.cuda.is_bf16_supported()
is_npu_available = False

from llmtuner.extras.logging import get_logger
from llmtuner.hparams import (
ModelArguments,
@@ -31,17 +21,6 @@
logger = get_logger(__name__)


def _infer_dtype() -> torch.dtype:
if is_npu_available:
return torch.float16
elif is_bf16_available:
return torch.bfloat16
elif is_fp16_available:
return torch.float16
else:
return torch.float32


def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
if args is not None:
return parser.parse_dict(args)
@@ -178,12 +157,15 @@ def get_train_args(
if not finetuning_args.resume_lora_training:
raise ValueError("Quantized model cannot create new LoRA weight. Merge them first.")

if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")

if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
logger.warning("We recommend enable mixed precision training.")

if (not training_args.do_train) and model_args.quantization_bit is not None:
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")

# postprocess data_args
if data_args.max_samples is not None and data_args.streaming:
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
@@ -206,10 +188,9 @@
and os.path.isdir(training_args.output_dir)
and not training_args.overwrite_output_dir
):
require_version("transformers>=4.31.0", "Resuming training requires transformers>=4.31.0.")
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")

if last_checkpoint is not None:
training_args_dict = training_args.to_dict()
Expand All @@ -220,26 +201,7 @@ def get_train_args(
)

# postprocess model_args
if training_args.bf16:
if not is_bf16_available:
raise ValueError("Current device does not support bf16 training.")
model_args.compute_dtype = torch.bfloat16
elif training_args.fp16:
model_args.compute_dtype = torch.float16
else:
model_args.compute_dtype = _infer_dtype()

if model_args.layernorm_dtype == "bf16":
if not is_bf16_available:
raise ValueError("Current device does not support bf16 type.")
model_args.layernorm_dtype = torch.bfloat16
elif model_args.layernorm_dtype == "fp16":
model_args.layernorm_dtype = torch.float16
elif model_args.layernorm_dtype == "fp32":
model_args.layernorm_dtype = torch.float32
else:
model_args.layernorm_dtype = model_args.compute_dtype

model_args.compute_dtype = torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
model_args.model_max_length = data_args.cutoff_len

# Log on each process the small summary:
@@ -278,7 +240,4 @@ def get_infer_args(
if model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint. Merge them first.")

# auto-detect cuda capability
model_args.compute_dtype = _infer_dtype()

return model_args, data_args, finetuning_args, generating_args
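
Back in get_train_args, the removed _infer_dtype machinery collapses into one expression; spelled out as a sketch mirroring the new line:

import torch

def resolve_compute_dtype(bf16: bool, fp16: bool):
    # --bf16 takes precedence over --fp16; with neither flag the value stays
    # None and the loader later calls infer_optim_dtype on the model config.
    if bf16:
        return torch.bfloat16
    if fp16:
        return torch.float16
    return None

print(resolve_compute_dtype(bf16=False, fp16=True))   # torch.float16
print(resolve_compute_dtype(bf16=False, fp16=False))  # None
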
11 changes: 6 additions & 5 deletions src/llmtuner/tuner/core/utils.py
@@ -31,11 +31,11 @@ def find_all_linear_modules(

def prepare_model_for_training(
model: "PreTrainedModel",
layernorm_dtype: torch.dtype,
upcast_layernorm: bool,
finetuning_type: str,
output_layer_name: Optional[str] = "lm_head",
use_gradient_checkpointing: Optional[bool] = True,
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
) -> "PreTrainedModel":
r"""
Includes:
@@ -44,9 +44,10 @@ def prepare_model_for_training(
(3) upcast the lm_head to fp32
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
"""
for name, param in model.named_parameters():
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
param.data = param.data.to(layernorm_dtype)
if upcast_layernorm:
for name, param in model.named_parameters():
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
param.data = param.data.to(torch.float32)

if use_gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"):
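
In isolation, the upcast loop promotes 1-D parameters whose names look like layer norms to fp32, the usual stability trick for quantized or mixed-precision fine-tuning. A rough standalone equivalent (the name list here is an illustrative subset of the real LAYERNORM_NAMES constant):

import torch
import torch.nn as nn

LAYERNORM_NAMES = ["norm", "ln_f"]  # illustrative subset

model = nn.ModuleDict({
    "proj": nn.Linear(8, 8),
    "input_layernorm": nn.LayerNorm(8),
}).to(torch.float16)

for name, param in model.named_parameters():
    if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
        param.data = param.data.to(torch.float32)

print({name: param.dtype for name, param in model.named_parameters()})
# proj.* stay float16; input_layernorm.* are upcast to float32
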
50 changes: 23 additions & 27 deletions src/llmtuner/tuner/ppo/trainer.py
@@ -10,14 +10,15 @@
from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits

from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model

if TYPE_CHECKING:
from transformers import Seq2SeqTrainingArguments, TrainerCallback
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.hparams import GeneratingArguments
from llmtuner.hparams import ModelArguments, GeneratingArguments


logger = get_logger(__name__)
@@ -30,22 +31,27 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

def __init__(
self,
model_args: "ModelArguments",
training_args: "Seq2SeqTrainingArguments",
generating_args: "GeneratingArguments",
callbacks: List["TrainerCallback"],
compute_dtype: torch.dtype,
**kwargs
):
PPOTrainer.__init__(self, **kwargs)
if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
raise ValueError("PPOTrainer is incompatible with DeepSpeed.")

self.args = training_args
self.generating_args = generating_args
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
self.compute_dtype = compute_dtype
self.model_args = model_args
self.generation_config = GenerationConfig(
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
**generating_args.to_dict()
)
self.state = TrainerState()
self.control = TrainerControl()
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)

def ppo_train(self) -> None:
r"""
@@ -74,13 +80,6 @@ def ppo_train(self) -> None:
logger.info(f" Total optimization steps = {max_steps}")
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")

# Keyword arguments for `model.generate`
generating_args = self.generating_args.to_dict()
generating_args.update(dict(
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
pad_token_id=self.tokenizer.pad_token_id
))

unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
dataiter = iter(self.dataloader)
steps_trained = 0
@@ -98,7 +97,7 @@
self.model.eval()

# Get inputs
queries, responses = self.get_inputs(batch, generating_args)
queries, responses = self.get_inputs(batch)
self.tokenizer.padding_side = "right" # change padding side
rewards = self.get_rewards(queries, responses, unwrapped_model)

@@ -159,27 +158,24 @@
)

@torch.no_grad()
def get_inputs(
self,
batch: Dict[str, torch.Tensor],
generating_args: Dict[str, Any]
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
r"""
Generates model's responses given queries.
"""
gen_kwargs = dict(
generation_config=GenerationConfig(**generating_args),
if self.model_args.upcast_layernorm:
layernorm_params = dump_layernorm(self.model)

unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(
generation_config=self.generation_config,
logits_processor=get_logits_processor(),
**batch
)

input_ids = batch["input_ids"]
self.model, layer_norm_params = cast_layernorm_dtype(self.model, self.compute_dtype)
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
response: torch.Tensor = unwrapped_model.generate(**gen_kwargs)
self.model, _ = cast_layernorm_dtype(self.model, self.compute_dtype, layer_norm_params)
query, response = input_ids.detach().cpu(), response[:, input_ids.size(-1):].detach().cpu()
if self.model_args.upcast_layernorm:
restore_layernorm(self.model, layernorm_params)

query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
queries, responses = [], []
for i in range(len(query)):
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
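
dump_layernorm and restore_layernorm come from llmtuner.tuner.ppo.utils, whose diff is not shown in this view. From their call sites they presumably snapshot the upcast fp32 layer-norm weights, cast them down so generation runs entirely in the model's half-precision dtype, and then put the originals back. A hedged sketch of that round trip (behaviour inferred from the usage above, not taken from the actual source):

from typing import Dict

import torch
from transformers import PreTrainedModel

def dump_layernorm(model: PreTrainedModel) -> Dict[str, torch.Tensor]:
    # Assumed behaviour: save fp32 copies, then cast those params to the model dtype.
    saved = {}
    for name, param in model.named_parameters():
        if param.dtype == torch.float32:
            saved[name] = param.data.detach().clone()
            param.data = param.data.to(model.config.torch_dtype)
    return saved

def restore_layernorm(model: PreTrainedModel, saved: Dict[str, torch.Tensor]) -> None:
    # Assumed behaviour: restore the saved fp32 parameters after generation.
    for name, param in model.named_parameters():
        if name in saved:
            param.data = saved[name]
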