-
Notifications
You must be signed in to change notification settings - Fork 250
multi-modality model construction support #1068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0dae9ef
87397e3
0f61614
d7f3a88
994b148
d184e68
6bb2485
0d8e368
6c78850
2691bae
ba960f0
8b3a684
880dfe2
e7fa7b4
c179bcb
5ead73b
882c336
952b8bd
9679a5b
56006ea
59337a6
d0e2974
758af10
80b5481
1cc7909
95684d9
b96bf05
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -5,10 +5,12 @@ | |||||||
# LICENSE file in the root directory of this source tree. | ||||||||
import json | ||||||||
import os | ||||||||
import warnings | ||||||||
|
||||||||
from dataclasses import dataclass | ||||||||
from enum import Enum | ||||||||
from pathlib import Path | ||||||||
from typing import Dict, Optional | ||||||||
from typing import Callable, Dict, Optional, Union | ||||||||
|
||||||||
import torch | ||||||||
import torch.nn as nn | ||||||||
|
@@ -26,8 +28,72 @@ | |||||||
|
||||||||
from torchchat.utils.build_utils import find_multiple, get_precision | ||||||||
|
||||||||
# Best-effort import of the torchtune Flamingo building blocks. A failure here
# must not stop the module from loading, because text-only models do not need
# these names. (Presumably torchtune pulls in torchao, which is what breaks on
# macOS -- confirm.)
# TODO: remove this guard once torchao is ready on macOS.
try:
    from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
    from torchtune.modules.model_fusion import DeepFusionModel
except Exception:
    # Deliberately broad: the failure is not guaranteed to be an ImportError.
    # Unlike a bare ``except:``, this still lets KeyboardInterrupt and
    # SystemExit propagate.
    pass
|
||||||||
config_path = Path(f"{str(Path(__file__).parent)}/model_params") | ||||||||
|
||||||||
class ModelType(Enum):
    """Kinds of model architecture that can be constructed.

    Each member's value is its string identifier, so a member can be looked
    up from that string, e.g. ``ModelType("text_only")``.
    """

    TextOnly = "text_only"
    Flamingo = "flamingo"
|
||||||||
# Type for objects that can generate nn.Module instance | ||||||||
ModuleLike = Union[nn.Module, Callable[..., nn.Module]] | ||||||||
|
||||||||
@dataclass
class ModelRecipe:
    """
    Describes and contains all model structures supported in torchchat.

    ModelRecipe represents a model as a collection of Transformer modules and a
    fusion module, providing a standardized and centralized way to define and
    build models in torchchat.

    Attributes:
        model_type (ModelType):
            The type of the model.
        modules (Dict[str, ModuleLike]):
            A dictionary of ModuleLike objects, where each key is the module
            name and each value is a ModuleLike object that generates the
            transformer. The names of the Transformer modules should match the
            corresponding names in the fusion class and in the JSON file
            holding the model hyperparameters.
        fusion_class (ModuleLike):
            A ModuleLike object that generates a fusion module from the
            modules constructed above.
    """

    model_type: ModelType
    modules: Dict[str, ModuleLike]
    fusion_class: ModuleLike

    @classmethod
    def _text_only(cls):
        """Return the recipe for a text-only model (one Transformer, identity fusion)."""
        # NOTE(review): this key is "text_transformer", while the text-only
        # hyperparameters elsewhere in this file are stored under "text".
        # Confirm the names line up before routing text-only models through
        # the generic recipe-based builder.
        return cls(
            model_type=ModelType.TextOnly,
            modules={"text_transformer": Transformer},
            fusion_class=nn.Identity,
        )

    @classmethod
    def _flamingo(cls):
        """Return the recipe for the Flamingo model (vision encoder + decoder).

        Raises:
            ImportError: If the torchtune Flamingo components were not
                imported when this module was loaded.
        """
        try:
            modules = {
                "encoder": flamingo_vision_encoder,
                "decoder": flamingo_decoder,
            }
            fusion_class = DeepFusionModel
        except NameError as err:
            # The torchtune import at module load is best-effort; when it was
            # skipped, these names do not exist. Report that explicitly
            # instead of leaking a bare NameError to the caller.
            raise ImportError(
                "The Flamingo model requires torchtune, which could not be imported."
            ) from err

        return cls(
            model_type=ModelType.Flamingo,
            modules=modules,
            fusion_class=fusion_class,
        )

    @classmethod
    def get_recipe(cls, model_type):
        """Return the ModelRecipe registered for ``model_type``.

        Args:
            model_type: The ModelType whose recipe is requested.

        Raises:
            ValueError: If no recipe exists for ``model_type``.
        """
        if model_type == ModelType.TextOnly:
            return cls._text_only()
        if model_type == ModelType.Flamingo:
            return cls._flamingo()
        raise ValueError(f"Can not find the model recipe for {model_type}")
|
||||||||
@dataclass | ||||||||
class TransformerArgs: | ||||||||
|
@@ -77,13 +143,33 @@ def from_params(cls, params): | |||||||
params[_to] = params.pop(_from) | ||||||||
return cls(**params) | ||||||||
|
||||||||
|
||||||||
@dataclass | ||||||||
class ModelArgs: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Docstring: please add one here. |
||||||||
text_transformer_args: TransformerArgs | ||||||||
model_type: ModelType | ||||||||
transformer_args: Dict[str, Union[Dict, TransformerArgs]] | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this type hint outdated?
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, it is up to date. |
||||||||
|
||||||||
def __post_init__(self): | ||||||||
assert self.text_transformer_args is not None | ||||||||
assert type(self.text_transformer_args) == TransformerArgs | ||||||||
def __init__( | ||||||||
self, | ||||||||
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], | ||||||||
model_type: ModelType = ModelType.TextOnly, | ||||||||
) -> None: | ||||||||
self._sanity_check(transformer_args, model_type) | ||||||||
|
||||||||
self.model_type = model_type | ||||||||
if isinstance(transformer_args, TransformerArgs): | ||||||||
assert model_type == ModelType.TextOnly | ||||||||
self.transformer_args = {"text": transformer_args} | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's make "text" a constant as well; we use it in a lot of places. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe in a different PR, if you think it is a good idea. I have some plans to make the configuration more concise and structured. |
||||||||
else: | ||||||||
self.transformer_args = transformer_args | ||||||||
|
||||||||
def _sanity_check( | ||||||||
self, | ||||||||
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]], | ||||||||
model_type: ModelType, | ||||||||
) -> None: | ||||||||
assert isinstance(model_type, ModelType) | ||||||||
assert isinstance(transformer_args, (TransformerArgs, dict)) | ||||||||
|
||||||||
@classmethod | ||||||||
def from_params(cls, params_path): | ||||||||
|
@@ -92,18 +178,18 @@ def from_params(cls, params_path): | |||||||
|
||||||||
try: | ||||||||
# try to interpret as a single transformer config | ||||||||
text_transformer_args = TransformerArgs.from_params( | ||||||||
loaded_params | ||||||||
) | ||||||||
transformer_args: Dict[str, TransformerArgs] = {} | ||||||||
transformer_args["text"] = TransformerArgs.from_params(loaded_params) | ||||||||
model_type = ModelType.TextOnly | ||||||||
except TypeError: | ||||||||
# try to interpret as a dict of transformer configs | ||||||||
for name, params in loaded_params.items(): | ||||||||
if name == "text": | ||||||||
text_transformer_args = TransformerArgs.from_params(params) | ||||||||
else: | ||||||||
raise ValueError(f"Unknown transformer name {name}") | ||||||||
model_type = ModelType(loaded_params["model_type"]) | ||||||||
|
||||||||
return cls(text_transformer_args) | ||||||||
# Currently only supporting flamingo model | ||||||||
assert model_type == ModelType.Flamingo | ||||||||
transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"} | ||||||||
|
||||||||
return cls(transformer_args, model_type) | ||||||||
|
||||||||
@classmethod | ||||||||
def from_table(cls, name: str): | ||||||||
|
@@ -181,16 +267,61 @@ def update(self, input_pos, k_val, v_val): | |||||||
|
||||||||
|
||||||||
class Model(nn.Module): | ||||||||
""" | ||||||||
The entrance for model construction in torchchat. | ||||||||
""" | ||||||||
def __init__(self, config: ModelArgs) -> None: | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we have the legacy […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Or link to your other PR where you fix this. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That will happen in a PR adjacent to this one (the 3.1 torchtune one), so I'd like to keep it as it is here. |
||||||||
super().__init__() | ||||||||
self.config = config | ||||||||
self.text_transformer = Transformer(config.text_transformer_args) | ||||||||
|
||||||||
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: | ||||||||
return self.text_transformer(idx, input_pos) | ||||||||
# TODO: unify the model init logic | ||||||||
if config.model_type == ModelType.TextOnly: | ||||||||
self.text_transformer = Transformer(config.transformer_args["text"]) | ||||||||
else: | ||||||||
self.model = self.build_model() | ||||||||
|
||||||||
def setup_caches(self, max_batch_size, max_seq_length): | ||||||||
self.text_transformer.setup_caches(max_batch_size, max_seq_length) | ||||||||
def build_model(self) -> nn.Module: | ||||||||
""" | ||||||||
Builds a model based on the provided configuration. | ||||||||
This method retrieves a ModelRecipe instance corresponding to the specified model type, | ||||||||
constructs the required Transformer modules, and combines them using the fusion class. | ||||||||
Returns: | ||||||||
The constructed model instance. | ||||||||
""" | ||||||||
recipe = ModelRecipe.get_recipe(self.config.model_type) | ||||||||
modules = {} | ||||||||
for name, module_class in recipe.modules.items(): | ||||||||
modules[name] = module_class(**self.config.transformer_args[name]) | ||||||||
|
||||||||
return recipe.fusion_class(**modules) | ||||||||
|
||||||||
def forward(self, | ||||||||
tokens: Optional[Tensor] = None, | ||||||||
input_pos: Optional[Tensor] = None, | ||||||||
encoder_input: Optional[Dict[str, Tensor]] = None, | ||||||||
encoder_mask: Optional[Tensor] = None) -> Tensor: | ||||||||
|
||||||||
if self.config.model_type == ModelType.TextOnly: | ||||||||
return self.text_transformer(tokens, input_pos) | ||||||||
else: | ||||||||
assert self.config.model_type == ModelType.Flamingo | ||||||||
if input_pos: | ||||||||
warnings.warn("input_pos is not used for Flamingo model. Ignoring it.") | ||||||||
if encoder_input is None: | ||||||||
return self.model(tokens, encoder_mask = encoder_mask) | ||||||||
return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask) | ||||||||
|
||||||||
def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None): | ||||||||
if self.config.model_type == ModelType.TextOnly: | ||||||||
self.text_transformer.setup_caches(max_batch_size, max_seq_length) | ||||||||
else: | ||||||||
assert self.config.model_type == ModelType.Flamingo | ||||||||
if max_seq_length is not None: | ||||||||
warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.") | ||||||||
self.model.setup_caches(max_batch_size, dtype=dtype) | ||||||||
|
||||||||
def reset_caches(self): | ||||||||
assert self.config.model_type == ModelType.Flamingo | ||||||||
self.model.reset_caches() | ||||||||
|
||||||||
@classmethod | ||||||||
def from_name(cls, name: str): | ||||||||
|
@@ -564,11 +695,11 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: | |||||||
# ExecuTorch model components | ||||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||||||||
|
||||||||
try: | ||||||||
try: | ||||||||
from executorch.extension.pybindings import portable_lib as exec_lib | ||||||||
|
||||||||
# ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately. | ||||||||
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa | ||||||||
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa | ||||||||
|
||||||||
class PTEModel(nn.Module): | ||||||||
def __init__(self, config, path) -> None: | ||||||||
|
@@ -589,5 +720,6 @@ def forward(self, x, input_pos): | |||||||
|
||||||||
def setup_caches(self, max_batch_size, max_seq_length): | ||||||||
pass | ||||||||
|
||||||||
except: | ||||||||
pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Docstring