
multi-modality model construction support #1068


Merged — 27 commits, merged Sep 11, 2024

Changes from all commits (27 commits)
0dae9ef
added model source and type for torchtune flamingo support
Gasoonjia Aug 27, 2024
87397e3
added model source and type for torchtune flamingo support
Gasoonjia Aug 27, 2024
0f61614
grab missing enum
Gasoonjia Aug 27, 2024
d7f3a88
fix ModelArgs init
Gasoonjia Aug 27, 2024
994b148
create init func for ModelArgs for BC
Gasoonjia Aug 28, 2024
d184e68
update pipeline for ModelSource and ModelType
Gasoonjia Aug 28, 2024
6bb2485
Merge branch 'main' of github.com:pytorch/torchchat into flamingo_com…
Gasoonjia Aug 28, 2024
0d8e368
revert lintrunner update on ET
Gasoonjia Aug 28, 2024
6c78850
introduce flamingo modules from torchtune
Gasoonjia Aug 28, 2024
2691bae
back up to move to linux
Gasoonjia Aug 28, 2024
ba960f0
mitigate building issue
Gasoonjia Aug 29, 2024
8b3a684
pass local test
Gasoonjia Aug 30, 2024
880dfe2
merge solved
Gasoonjia Aug 30, 2024
e7fa7b4
structural model builder
Gasoonjia Sep 3, 2024
c179bcb
update torchtune address
Gasoonjia Sep 5, 2024
5ead73b
update install requirement
Gasoonjia Sep 6, 2024
882c336
support new torchtune flamingo component
Gasoonjia Sep 6, 2024
952b8bd
specific version for vision and ao
Gasoonjia Sep 6, 2024
9679a5b
convert installation back and bypass torchtune
Gasoonjia Sep 9, 2024
56006ea
Merge branch 'main' into flamingo_component
Gasoonjia Sep 9, 2024
59337a6
update exportation variable name
Gasoonjia Sep 9, 2024
d0e2974
solve merge conflict
Gasoonjia Sep 10, 2024
758af10
solve Jack's wonderful comments
Gasoonjia Sep 11, 2024
80b5481
remove extra dot
Gasoonjia Sep 11, 2024
1cc7909
add type.Callable
Gasoonjia Sep 11, 2024
95684d9
fix torchchat typos
Gasoonjia Sep 11, 2024
b96bf05
remove all .DS_Store
Gasoonjia Sep 11, 2024
2 changes: 1 addition & 1 deletion dist_run.py
@@ -122,7 +122,7 @@ def main():
gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

config = ModelArgs.from_name(MODEL_NAME).text_transformer_args
config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']
logger.info(f"Chat Model Config: {config}")

tokenizer = _build_chat_tokenizer()
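The same mechanical change repeats in most of the files below: per-module configuration now lives in a dictionary keyed by module name rather than in a single text_transformer_args attribute. A minimal sketch of the new access pattern (not part of the diff; the model name is a placeholder):

```python
# Sketch of the new config access pattern; "llama3" is a placeholder model name.
from torchchat.model import ModelArgs

args = ModelArgs.from_name("llama3")
text_config = args.transformer_args["text"]   # previously: args.text_transformer_args
print(text_config.max_seq_length)
```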
2 changes: 1 addition & 1 deletion distributed/parallelize_llama.py
@@ -61,7 +61,7 @@ def apply_tp(
# after we apply TP to the model. Because we don't want to change model code
# when applying TP. We need to have change to ensure KVCache has the correct
# size as k and v.
model.config.text_transformer_args.n_local_heads = model.config.text_transformer_args.n_local_heads // tp_mesh.size()
model.config.transformer_args["text"].n_local_heads = model.config.transformer_args["text"].n_local_heads // tp_mesh.size()

# Apply tensor parallelism to every transformer block
for transformer_block in model.layers:
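For context on the division above: each tensor-parallel rank keeps only its shard of the KV heads, so the cached config must shrink by the mesh size. Illustrative arithmetic with assumed numbers:

```python
# Assumed values, purely illustrative of the n_local_heads adjustment above.
n_local_heads = 8      # heads before sharding
tp_degree = 2          # assumed tp_mesh.size()
assert n_local_heads % tp_degree == 0
n_local_heads //= tp_degree   # each rank sizes its KVCache for 4 local heads
```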
4 changes: 2 additions & 2 deletions torchchat/cli/builder.py
@@ -230,7 +230,7 @@ def validate_model(

is_tiktoken = self.is_tiktoken
is_sentencepiece = self.is_sentencepiece
use_tiktoken = model.config.text_transformer_args.use_tiktoken
use_tiktoken = model.config.transformer_args["text"].use_tiktoken

if not (is_tiktoken == use_tiktoken) or not (is_sentencepiece != use_tiktoken):
raise RuntimeError(
@@ -534,7 +534,7 @@ def _initialize_model(
if builder_args.setup_caches:
with torch.device(builder_args.device):
model.setup_caches(
max_batch_size=1, max_seq_length=max_seq_length or model.config.text_transformer_args.max_seq_length
max_batch_size=1, max_seq_length=max_seq_length or model.config.transformer_args["text"].max_seq_length
)

model.to(dtype=builder_args.precision)
2 changes: 1 addition & 1 deletion torchchat/cli/convert_hf_checkpoint.py
@@ -32,7 +32,7 @@ def convert_hf_checkpoint(
if model_name is None:
model_name = model_dir.name

config = ModelArgs.from_name(model_name).text_transformer_args
config = ModelArgs.from_name(model_name).transformer_args['text']
print(f"Model config {config.__dict__}")

# Load the json file containing weight mapping
4 changes: 2 additions & 2 deletions torchchat/export.py
@@ -54,9 +54,9 @@ def export_for_server(
torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device),
)

seq = Dim("seq", min=1, max=model.config.text_transformer_args.max_seq_length)
seq = Dim("seq", min=1, max=model.config.transformer_args["text"].max_seq_length)
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"idx": {1: seq}, "input_pos": {0: seq}}
dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
else:
input = (
torch.tensor([[1]], dtype=torch.int, device=device),
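The second hunk here is not cosmetic: with torch.export, the keys of dynamic_shapes must match the parameter names of the model's forward(), and this PR renames that parameter from idx to tokens. A minimal sketch under that assumption:

```python
# Sketch: dynamic_shapes is keyed by forward()'s parameter names, so renaming
# the argument from `idx` to `tokens` forces the export spec to follow suit.
from torch.export import Dim

max_seq_length = 2048                         # placeholder for the model's real limit
seq = Dim("seq", min=1, max=max_seq_length)
dynamic_shapes = {"tokens": {1: seq}, "input_pos": {0: seq}}
# torch.export.export(model, example_inputs, dynamic_shapes=dynamic_shapes)
```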
4 changes: 2 additions & 2 deletions torchchat/generate.py
@@ -687,7 +687,7 @@ def chat(
self.system_prompt = None
# Set up our max_seq_length
if generator_args.chat_mode:
max_seq_length = self.model.config.text_transformer_args.max_seq_length
max_seq_length = self.model.config.transformer_args["text"].max_seq_length
print(
f"Entering Chat Mode. Will continue chatting back and forth with the language model until the models max context length of {max_seq_length} tokens is hit or until the user says /bye"
)
@@ -700,7 +700,7 @@
else:
max_seq_length = min(
encoded.size(0) + generator_args.max_new_tokens,
self.model.config.text_transformer_args.block_size,
self.model.config.transformer_args["text"].block_size,
)

max_seq_length = (
178 changes: 155 additions & 23 deletions torchchat/model.py
@@ -5,10 +5,12 @@
# LICENSE file in the root directory of this source tree.
import json
import os
import warnings

from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, Optional
from typing import Callable, Dict, Optional, Union

import torch
import torch.nn as nn
@@ -26,8 +28,72 @@

from torchchat.utils.build_utils import find_multiple, get_precision

# bypass the import issue, if any
# TODO: remove this once the torchao is ready on macos
try:
from torchtune.models.flamingo import flamingo_decoder, flamingo_vision_encoder
from torchtune.modules.model_fusion import DeepFusionModel
except:
pass

config_path = Path(f"{str(Path(__file__).parent)}/model_params")

class ModelType(Enum):
TextOnly = "text_only"
Flamingo = "flamingo"

# Type for objects that can generate nn.Module instance
ModuleLike = Union[nn.Module, Callable[..., nn.Module]]

@dataclass
class ModelRecipe:
Review comment (Contributor): nit: Docstring

"""
The class describes and contains all supported model structures in torchchat.

ModelRecipe represents a model as a collection of Transformer modules and a fusion module,
providing a standardized and centralized way to define and build models in torchchat.
Attributes:
model_type (ModelType):
The type of the model.
modules (Dict[str, ModuleLike]):
A dictionary of ModuleLike modules, where each key is the module name and each
value is a ModuleLike object that generates the transformer.
The names of the Transformer modules should match the corresponding names in the
fusion class and the JSON file holding model hyperparameters.
fusion_class (ModuleLike):
A ModuleLike object that generates a fusion module by taking the constructed modules above.
"""

model_type: ModelType
modules: Dict[str, ModuleLike]
fusion_class: ModuleLike

@classmethod
def _text_only(cls):
return cls(
model_type=ModelType.TextOnly,
modules={'text_transformer': Transformer},
fusion_class=nn.Identity,
)
@classmethod
def _flamingo(cls):
return cls(
model_type=ModelType.Flamingo,
modules={
'encoder': flamingo_vision_encoder,
'decoder': flamingo_decoder
},
fusion_class=DeepFusionModel,
)

@classmethod
def get_recipe(cls, model_type):
if model_type == ModelType.TextOnly:
return cls._text_only()
elif model_type == ModelType.Flamingo:
return cls._flamingo()
else:
raise ValueError(f"Can not find the model recipe for {model_type}")
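For orientation, a sketch (not part of the diff) of how a recipe is consumed; it mirrors Model.build_model further down, and the placeholder kwargs stand in for the real torchtune builder arguments:

```python
# Module names in the recipe must match both the keys of ModelArgs.transformer_args
# and the keyword arguments of the fusion class.
recipe = ModelRecipe.get_recipe(ModelType.Flamingo)
transformer_args = {"encoder": {}, "decoder": {}}     # placeholder kwargs dicts
modules = {
    name: module_like(**transformer_args[name])       # e.g. flamingo_vision_encoder(...)
    for name, module_like in recipe.modules.items()
}
model = recipe.fusion_class(**modules)                # DeepFusionModel(encoder=..., decoder=...)
```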

@dataclass
class TransformerArgs:
@@ -77,13 +143,33 @@ def from_params(cls, params):
params[_to] = params.pop(_from)
return cls(**params)


@dataclass
class ModelArgs:
Review comment (Contributor): Docstring

text_transformer_args: TransformerArgs
model_type: ModelType
transformer_args: Dict[str, Union[Dict, TransformerArgs]]
Review comment (Contributor): Is this type hint outdated?
Suggested change:
- transformer_args: Dict[str, Union[Dict, TransformerArgs]]
+ transformer_args: Dict[str, TransformerArgs]

Reply (Contributor Author): No, it is up to date. For the torchchat native builder we need TransformerArgs; for torchtune we need a Dict.

def __post_init__(self):
assert self.text_transformer_args is not None
assert type(self.text_transformer_args) == TransformerArgs
def __init__(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
model_type: ModelType = ModelType.TextOnly,
) -> None:
self._sanity_check(transformer_args, model_type)

self.model_type = model_type
if isinstance(transformer_args, TransformerArgs):
assert model_type == ModelType.TextOnly
self.transformer_args = {"text": transformer_args}
Review comment (Contributor): Suggested change:
- self.transformer_args = {"text": transformer_args}
+ assert model_type == ModelType.TextOnly
+ self.transformer_args = {"text": transformer_args}

Review comment (Contributor): Let's make "text" a constant as well; we use it in a lot of places.

Reply (Contributor Author): Maybe in a different PR, if you think it is a good idea. I have some plans to make the configuration more concise and structured.

else:
self.transformer_args = transformer_args

def _sanity_check(
self,
transformer_args: Union[TransformerArgs, Dict[str, TransformerArgs]],
model_type: ModelType,
) -> None:
assert isinstance(model_type, ModelType)
assert isinstance(transformer_args, (TransformerArgs, dict))

@classmethod
def from_params(cls, params_path):
@@ -92,18 +178,18 @@ def from_params(cls, params_path):

try:
# try to interpret as a single transformer config
text_transformer_args = TransformerArgs.from_params(
loaded_params
)
transformer_args: Dict[str, TransformerArgs] = {}
transformer_args["text"] = TransformerArgs.from_params(loaded_params)
model_type = ModelType.TextOnly
except TypeError:
# try to interpret as a dict of transformer configs
for name, params in loaded_params.items():
if name == "text":
text_transformer_args = TransformerArgs.from_params(params)
else:
raise ValueError(f"Unknown transformer name {name}")
model_type = ModelType(loaded_params["model_type"])

return cls(text_transformer_args)
# Currently only supporting flamingo model
assert model_type == ModelType.Flamingo
transformer_args = {k: v for k, v in loaded_params.items() if k != "model_type"}

return cls(transformer_args, model_type)
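To summarize the two shapes the constructor and from_params now accept, a short sketch; the *_params names are placeholders, and the Flamingo JSON layout is an assumption inferred from the keys used above:

```python
# 1) Legacy text-only path: a single TransformerArgs is wrapped into {"text": ...}.
args = ModelArgs(TransformerArgs.from_params(text_params), ModelType.TextOnly)

# 2) Multi-module path (Flamingo): a dict keyed by module name, loaded from a
#    params file assumed to look roughly like
#    {"model_type": "flamingo", "encoder": {...}, "decoder": {...}}.
args = ModelArgs({"encoder": encoder_params, "decoder": decoder_params}, ModelType.Flamingo)
```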

@classmethod
def from_table(cls, name: str):
@@ -181,16 +267,61 @@ def update(self, input_pos, k_val, v_val):


class Model(nn.Module):
"""
The entrance for model construction in torchchat.
"""
def __init__(self, config: ModelArgs) -> None:
Review comment (Contributor): Since we have the legacy text_transformer and the new model, let's add a quick description until we unify them later.

Review comment (Contributor): Or link to your other PR where you fix this.

Reply (Contributor Author): That will happen adjacent to this PR (the 3.1 torchtune one), so I'd like to keep it as it is here.

super().__init__()
self.config = config
self.text_transformer = Transformer(config.text_transformer_args)

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
return self.text_transformer(idx, input_pos)
# TODO: unify the model init logic
if config.model_type == ModelType.TextOnly:
self.text_transformer = Transformer(config.transformer_args["text"])
else:
self.model = self.build_model()

def setup_caches(self, max_batch_size, max_seq_length):
self.text_transformer.setup_caches(max_batch_size, max_seq_length)
def build_model(self) -> nn.Module:
"""
Builds a model based on the provided configuration.
This method retrieves a ModelRecipe instance corresponding to the specified model type,
constructs the required Transformer modules, and combines them using the fusion class.
Returns:
The constructed model instance.
"""
recipe = ModelRecipe.get_recipe(self.config.model_type)
modules = {}
for name, module_class in recipe.modules.items():
modules[name] = module_class(**self.config.transformer_args[name])

return recipe.fusion_class(**modules)

def forward(self,
tokens: Optional[Tensor] = None,
input_pos: Optional[Tensor] = None,
encoder_input: Optional[Dict[str, Tensor]] = None,
encoder_mask: Optional[Tensor] = None) -> Tensor:

if self.config.model_type == ModelType.TextOnly:
return self.text_transformer(tokens, input_pos)
else:
assert self.config.model_type == ModelType.Flamingo
if input_pos:
warnings.warn("input_pos is not used for Flamingo model. Ignoring it.")
if encoder_input is None:
return self.model(tokens, encoder_mask = encoder_mask)
return self.model(tokens, encoder_input=encoder_input, encoder_mask = encoder_mask)

def setup_caches(self, max_batch_size, max_seq_length=None, dtype=None):
if self.config.model_type == ModelType.TextOnly:
self.text_transformer.setup_caches(max_batch_size, max_seq_length)
else:
assert self.config.model_type == ModelType.Flamingo
if max_seq_length is not None:
warnings.warn("max_seq_length is not used for Flamingo model. Ignoring it.")
self.model.setup_caches(max_batch_size, dtype=dtype)

def reset_caches(self):
assert self.config.model_type == ModelType.Flamingo
self.model.reset_caches()
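A usage sketch for the new multi-modal path; the params file name, the tensor shapes, and the contents of encoder_input are assumptions, while the argument names come from forward() above:

```python
import torch
from torchchat.model import Model, ModelArgs

# "flamingo.json" is an assumed params file under torchchat/model_params.
model = Model(ModelArgs.from_params("torchchat/model_params/flamingo.json"))
model.setup_caches(max_batch_size=1, dtype=torch.bfloat16)   # Flamingo path takes dtype, not max_seq_length

tokens = torch.zeros(1, 32, dtype=torch.int)                 # placeholder token ids
logits = model(
    tokens=tokens,
    encoder_input=encoder_input,    # dict of vision tensors from the image transform (placeholder)
    encoder_mask=encoder_mask,      # cross-attention mask (placeholder)
)
```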

@classmethod
def from_name(cls, name: str):
@@ -564,11 +695,11 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
# ExecuTorch model components
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

try:
try:
from executorch.extension.pybindings import portable_lib as exec_lib

# ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately.
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa
from executorch.examples.models.llama2.custom_ops import sdpa_with_kv_cache # no-qa

class PTEModel(nn.Module):
def __init__(self, config, path) -> None:
Expand All @@ -589,5 +720,6 @@ def forward(self, x, input_pos):

def setup_caches(self, max_batch_size, max_seq_length):
pass

except:
pass
2 changes: 1 addition & 1 deletion torchchat/usages/eval.py
@@ -59,7 +59,7 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill(
T = prompt.size(0)
T_new = T + max_new_tokens
if max_seq_length is None:
max_seq_length = min(T_new, model.config.text_transformer_args.block_size)
max_seq_length = min(T_new, model.config.transformer_args["text"].block_size)

device, dtype = prompt.device, prompt.dtype
# create an empty tensor of the expected final shape and
4 changes: 2 additions & 2 deletions torchchat/usages/openai_api.py
@@ -233,11 +233,11 @@ def __init__(self, *args, **kwargs):

super().__init__(*args, **kwargs)
self.max_seq_length = (
self.model.config.text_transformer_args.max_seq_length
self.model.config.transformer_args["text"].max_seq_length
+ self.speculative_builder_args.speculate_k
+ 1
if self.draft_model is not None
else self.model.config.text_transformer_args.max_seq_length
else self.model.config.transformer_args["text"].max_seq_length
)
# The System fingerprint is a unique identifier for the model and its configuration.
self.system_fingerprint = (
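The arithmetic above presumably leaves headroom for speculative decoding: each step may append up to speculate_k drafted tokens plus the one verified token. With assumed numbers:

```python
# Assumed values, purely illustrative of the max_seq_length adjustment above.
base_max_seq_length, speculate_k = 2048, 5
max_seq_length = base_max_seq_length + speculate_k + 1      # 2054 cache slots with a draft model
```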