Commit

Add default LoRA target modules for Mixtral and Mixtral instruct (lud…
arnavgarg1 authored Jan 3, 2024
1 parent 7e34450 commit 74a71e9
Showing 2 changed files with 24 additions and 0 deletions.
ludwig/schema/llms/base_model.py (3 additions, 0 deletions)
@@ -41,6 +41,9 @@
    # Mistral
    "mistral-7b": "mistralai/Mistral-7B-v0.1",
    "mistral-7b-instruct": "mistralai/Mistral-7B-Instruct-v0.1",
    # Mixtral
    "mixtral-8x7b": "mistralai/Mixtral-8x7B-v0.1",
    "mixtral-8x7b-instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    # OPT
    "opt-350m": "facebook/opt-350m",
    "opt-1.3b": "facebook/opt-1.3b",
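With these entries in place, a Ludwig config can set base_model to the short preset name and have it resolve to the full Hugging Face repo ID. Below is a minimal sketch of that lookup: the MODEL_PRESETS dict mirrors the entries above, but the dict and function names here are illustrative stand-ins, not Ludwig's actual API.

# Illustrative sketch only: shows how a short preset name maps to a full
# Hugging Face repo ID. MODEL_PRESETS / resolve_base_model are hypothetical names.
MODEL_PRESETS = {
    "mixtral-8x7b": "mistralai/Mixtral-8x7B-v0.1",
    "mixtral-8x7b-instruct": "mistralai/Mixtral-8x7B-Instruct-v0.1",
}


def resolve_base_model(name: str) -> str:
    # Short names resolve to their full repo ID; anything else passes through unchanged.
    return MODEL_PRESETS.get(name, name)


assert resolve_base_model("mixtral-8x7b") == "mistralai/Mixtral-8x7B-v0.1"
assert resolve_base_model("mistralai/Mixtral-8x7B-v0.1") == "mistralai/Mixtral-8x7B-v0.1"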
ludwig/schema/model_types/utils.py (21 additions, 0 deletions)
@@ -313,6 +313,10 @@ def set_llm_parameters(config: "ModelConfig") -> None:
    # Set max_new_tokens in generation config to the max sequence length of the output features
    _set_generation_max_new_tokens(config)

    # HACK(Arnav): Set Mixtral target modules when using LoRA
    # GitHub issue: https://github.com/ludwig-ai/ludwig/issues/3853
    _set_mixtral_target_modules(config)


def _set_llm_tokenizers(config: "ModelConfig") -> None:
    """Sets the tokenizers for the LLM model to the pretrained model name or path. This ensures that they use the
@@ -405,6 +409,23 @@ def _set_generation_max_new_tokens(config: "ModelConfig") -> None:
    config.generation.max_new_tokens = max_possible_sequence_length


def _set_mixtral_target_modules(config: "ModelConfig") -> None:
    """If the base model is Mixtral 8x7B, LoRA is enabled, and the target modules are not set, set the target
    modules to q_proj and v_proj."""
    if config.base_model not in {"mistralai/Mixtral-8x7B-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"}:
        return

    if not config.adapter:
        return

    if config.adapter.type != "lora" or config.adapter.target_modules:
        return

    logger.info("Setting adapter target modules to ['q_proj', 'v_proj'] for Mixtral 8x7B base model with LoRA adapter.")

    config.adapter.target_modules = ["q_proj", "v_proj"]


@DeveloperAPI
def contains_grid_search_parameters(hyperopt_config: HyperoptConfigDict) -> bool:
    """Returns True if any hyperopt parameter in the config is using the grid_search space."""
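To show what the new _set_mixtral_target_modules hook does end to end, here is a standalone sketch of the same guard logic. FakeAdapter and FakeConfig are hypothetical lightweight stand-ins for Ludwig's schema objects; the real function operates on ModelConfig and its adapter schema.

# Standalone sketch of the guard logic above. FakeAdapter/FakeConfig are
# hypothetical stand-ins used only for illustration.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeAdapter:
    type: str = "lora"
    target_modules: Optional[List[str]] = None


@dataclass
class FakeConfig:
    base_model: str = "mistralai/Mixtral-8x7B-v0.1"
    adapter: Optional[FakeAdapter] = field(default_factory=FakeAdapter)


MIXTRAL_MODELS = {"mistralai/Mixtral-8x7B-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1"}


def set_mixtral_target_modules(config: FakeConfig) -> None:
    # Only patch Mixtral base models that use LoRA and have no explicit target modules.
    if config.base_model not in MIXTRAL_MODELS:
        return
    if not config.adapter or config.adapter.type != "lora" or config.adapter.target_modules:
        return
    config.adapter.target_modules = ["q_proj", "v_proj"]


cfg = FakeConfig()
set_mixtral_target_modules(cfg)
assert cfg.adapter.target_modules == ["q_proj", "v_proj"]

# A config that already sets target_modules is left untouched.
cfg2 = FakeConfig(adapter=FakeAdapter(target_modules=["k_proj"]))
set_mixtral_target_modules(cfg2)
assert cfg2.adapter.target_modules == ["k_proj"]

Note that the defaults only apply when target_modules is left unset, so an explicit user choice is never overridden.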
