Commit

fix llava config
hiyouga committed May 11, 2024
1 parent 5da097f · commit b033232
Showing 5 changed files with 15 additions and 15 deletions.

src/llmtuner/model/adapter.py (3 additions, 0 deletions)

@@ -46,6 +46,9 @@ def init_adapter(
         if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam):
             model = model.float()
 
+    if model_args.visual_inputs and hasattr(model, "vision_tower"):  # freeze vision model
+        model.vision_tower.requires_grad_(False)
+
     if finetuning_args.finetuning_type == "freeze" and is_trainable:
         logger.info("Fine-tuning method: Freeze")
         num_layers = (
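
The adapter.py hunk freezes the vision encoder whenever visual inputs are enabled. Below is a minimal, runnable sketch of that pattern, assuming a toy module layout (ToyMultimodalModel and its vision_tower / language_model attributes are illustrative names, not the project's actual classes): calling requires_grad_(False) on a submodule excludes all of its parameters from gradient computation, so only the language side keeps training.

# Toy illustration of the freezing pattern from the diff above (names are
# assumptions, not LLaMA-Factory code).
import torch.nn as nn

class ToyMultimodalModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Linear(16, 16)    # stands in for the vision encoder
        self.language_model = nn.Linear(16, 16)  # stands in for the LLM backbone

model = ToyMultimodalModel()
if hasattr(model, "vision_tower"):  # same guard as in the diff
    model.vision_tower.requires_grad_(False)

# Only the language side remains trainable.
print([name for name, p in model.named_parameters() if p.requires_grad])
# ['language_model.weight', 'language_model.bias']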

src/llmtuner/model/loader.py (1 addition, 1 deletion)

@@ -106,7 +106,7 @@ def load_model(
     """
     init_kwargs = _get_init_kwargs(model_args)
     config = load_config(model_args)
-    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable, add_valuehead)
+    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
 
     model = None
     lazy_load = False

src/llmtuner/model/patcher.py (3 additions, 6 deletions)

@@ -15,8 +15,8 @@
 from .utils.moe import add_z3_leaf_module, configure_moe
 from .utils.quantization import configure_quantization
 from .utils.rope import configure_rope
-from .utils.valuehead import configure_valuehead, prepare_valuehead_model
-from .utils.visual import autocast_projector_dtype
+from .utils.valuehead import prepare_valuehead_model
+from .utils.visual import autocast_projector_dtype, configure_hidden_size
 
 
 if TYPE_CHECKING:
@@ -40,7 +40,6 @@ def patch_config(
     model_args: "ModelArguments",
     init_kwargs: Dict[str, Any],
     is_trainable: bool,
-    add_valuehead: bool,
 ) -> None:
     if model_args.compute_dtype is None:  # priority: bf16 > fp16 > fp32
         model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
@@ -50,9 +49,7 @@
     configure_longlora(config, model_args, is_trainable)
     configure_quantization(config, tokenizer, model_args, init_kwargs)
     configure_moe(config, model_args, is_trainable)
-
-    if add_valuehead:
-        configure_valuehead(config)
+    configure_hidden_size(config)
 
     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)

src/llmtuner/model/utils/valuehead.py (1 addition, 6 deletions)

@@ -8,19 +8,14 @@
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel
+    from transformers import PreTrainedModel
 
     from ...hparams import ModelArguments
 
 
 logger = get_logger(__name__)
 
 
-def configure_valuehead(config: "PretrainedConfig") -> None:
-    if getattr(config, "model_type", None) == "llava":
-        setattr(config, "hidden_size", getattr(config.vision_config, "intermediate_size", None))
-
-
 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
     r"""
     Loads value head parameters from Hugging Face Hub or local disk.
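
The removed configure_valuehead is the source of the bug this commit fixes: for llava-type configs it copied vision_config.intermediate_size into config.hidden_size, whereas consumers of hidden_size (such as a value head sized to the language model's output) need the text backbone's width. Its replacement, configure_hidden_size in visual.py below, reads text_config.hidden_size instead and is now applied on every load rather than only when a value head is requested. A small illustration, assuming a transformers version that ships LlavaConfig; the sizes are placeholder values:

# Illustrative only: shows why text_config.hidden_size is the right source.
from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig

config = LlavaConfig(
    vision_config=CLIPVisionConfig(hidden_size=1024, intermediate_size=4096),
    text_config=LlamaConfig(hidden_size=2048),
)

# Old behavior (removed above): hidden_size taken from the vision tower,
# i.e. getattr(config.vision_config, "intermediate_size", None) -> 4096.
# New behavior (configure_hidden_size below): take the language model's width.
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
print(config.hidden_size)  # 2048, matching the hidden states a value head consumes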

src/llmtuner/model/utils/visual.py (7 additions, 2 deletions)

@@ -6,14 +6,19 @@
 
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedModel
+    from transformers import PretrainedConfig, PreTrainedModel
 
     from ...hparams import ModelArguments
 
 
 logger = get_logger(__name__)
 
 
+def configure_hidden_size(config: "PretrainedConfig") -> None:
+    if getattr(config, "model_type", None) == "llava":
+        setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
+
+
 def autocast_projector_dtype(
     model: "PreTrainedModel", model_args: "ModelArguments", mm_projector_name: str = "multi_modal_projector"
 ) -> None:
@@ -22,7 +27,7 @@ def _mm_projector_forward_post_hook(
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name):
+    if hasattr(model, mm_projector_name) and getattr(model.config, "quantization_method", None):
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
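
For context on the second visual.py hunk: autocast_projector_dtype now re-casts the projector's outputs only when the model reports a quantization method, and it does so through a standard PyTorch forward post-hook, where the hook receives (module, inputs, output) and a non-None return value replaces the module's output. A minimal sketch with toy names (the Linear stand-in and cast_output_hook are assumptions for illustration, not the project's code):

import torch
import torch.nn as nn

target_dtype = torch.bfloat16  # stands in for model_args.compute_dtype

def cast_output_hook(module: nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
    # Returning a tensor from a forward hook replaces the module's output.
    return output.to(target_dtype)

projector = nn.Linear(8, 8)  # stands in for the multimodal projector (float32 weights)
projector.register_forward_hook(cast_output_hook)

out = projector(torch.randn(2, 8))
print(out.dtype)  # torch.bfloat16: the hook re-cast the float32 output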
