multi-modality model construction support #1068
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1068
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit b96bf05 with merge base c272df4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is looking great!! Thanks for making it so easy to review
Mostly nits, but I know this is part of a PR stack we're landing, so we have some leeway
.DS_Store (Outdated)
You might have to manually tell git not to track these.
There are a few .DS_Store files.
torchchat/model.py (Outdated)
@dataclass
class ModelRecipe:
    model_type: ModelType
    modules: dict
Suggested change: modules: dict → modules: Dict[str, Any]
@dataclass
class ModelRecipe:
nit: Docstring
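As a rough sketch, this is what the typed field plus docstring could look like once both nits are addressed; the field list mirrors the diff, but the docstring wording and import placement are assumptions:

    from dataclasses import dataclass
    from typing import Any, Dict

    import torch

    @dataclass
    class ModelRecipe:
        """Describes how a supported model is assembled from its component modules."""
        model_type: "ModelType"  # ModelType is defined elsewhere in torchchat/model.py
        modules: Dict[str, Any]
        fusion_class: torch.nn.Module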
torchchat/model.py (Outdated)
    fusion_class: torch.nn.Module

    @classmethod
    def text_only(cls):
Suggested change: def text_only(cls): → def _text_only(cls):
torchchat/model.py (Outdated)
            fusion_class=nn.Identity,
        )

    @classmethod
    def flamingo(cls):
Suggested change: def flamingo(cls): → def _flamingo(cls):
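For context, a minimal sketch of why the underscore helps: callers can go through a single public lookup rather than calling the per-model constructors directly. The get_recipe helper and the ModelType member names below are assumptions for illustration, not the PR's actual API:

    @classmethod
    def get_recipe(cls, model_type: "ModelType") -> "ModelRecipe":
        # Hypothetical public entry point dispatching to the private recipe constructors.
        if model_type == ModelType.TextOnly:
            return cls._text_only()
        if model_type == ModelType.Flamingo:
            return cls._flamingo()
        raise ValueError(f"Unsupported model type: {model_type}")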
        self.model_type = model_type
        if isinstance(transformer_args, TransformerArgs):
            self.transformer_args = {"text": transformer_args}
Let's make "text" a constant as well; we use it in a lot of places.
Maybe in a different PR, if you think it's worthwhile. I have some plans to make the configuration more concise and structured.
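For reference, a minimal sketch of the suggested constant (the name TEXT_MODALITY is an assumption):

    # Hypothetical module-level constant replacing the repeated "text" literal.
    TEXT_MODALITY = "text"

    ...
    if isinstance(transformer_args, TransformerArgs):
        self.transformer_args = {TEXT_MODALITY: transformer_args}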
torchchat/model.py (Outdated)
return cls(text_transformer_args)
# now only support flamingo model
Suggested change: # now only support flamingo model → # Currently only supporting flamingo model
@@ -77,13 +124,32 @@ def from_params(cls, params):
    params[_to] = params.pop(_from)
    return cls(**params)


@dataclass
class ModelArgs:
Docstring
torchchat/model.py (Outdated)
    def setup_caches(self, max_batch_size, max_seq_length):
        self.text_transformer.setup_caches(max_batch_size, max_seq_length)

    def build_model(self):
This is where all the magic comes together; let's add a docstring here.
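One possible shape for that docstring, sketched from this PR's diff; the described flow and attribute names are assumptions:

    def build_model(self) -> nn.Module:
        """Assemble the model described by self.config.

        Looks up the ModelRecipe registered for the configured ModelType,
        instantiates each transformer module from its TransformerArgs, and
        combines the pieces with the recipe's fusion class (nn.Identity for
        text-only models).
        """
        ...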
@@ -184,13 +250,48 @@ class Model(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
Since we have the legacy text_transformer and the new model, let's add a quick description until we unify them later.
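A rough sketch of the kind of note being requested; the attribute names and wording are assumptions:

    class Model(nn.Module):
        """The entrance for model construction in torchchat.

        NOTE: text_transformer is the legacy text-only path; multi-modality
        models are assembled from a ModelRecipe in build_model(). The two
        paths will be unified in a follow-up.
        """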
Or link to your other PR where you fix this
That will happen adjacent to this PR (the 3.1 torchtune one), so I'd like to keep it as it is here.
Before landing:
dist_run.py (Outdated)
@@ -122,7 +122,7 @@ def main():
    gpu_memory_monitor = GPUMemoryMonitor("cuda")
    logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

-   config = ModelArgs.from_name(MODEL_NAME).text_transformer_args
+   config = ModelArgs.from_name(MODEL_NAME)..transformer_args['text']
Double dot
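For reference, the intended line is presumably:

    config = ModelArgs.from_name(MODEL_NAME).transformer_args['text']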
torchchat/model.py (Outdated)
@dataclass
class ModelRecipe:
    """
    A class in TorchChat that describes and contains all supported model structures in TorchChat.
Suggested change: A class in TorchChat that describes and contains all supported model structures in TorchChat. → A class in torchchat that describes and contains all supported model structures in torchchat.
torchchat/model.py (Outdated)
    A class in TorchChat that describes and contains all supported model structures in TorchChat.

    ModelRecipe represents a model as a collection of Transformer modules and a fusion module,
    providing a standardized and centralized way to define and build models in TorchChat.
Suggested change: providing a standardized and centralized way to define and build models in TorchChat. → providing a standardized and centralized way to define and build models in torchchat.
torchchat/model.py (Outdated)
@@ -247,6 +267,9 @@ def update(self, input_pos, k_val, v_val):


class Model(nn.Module):
    """
    The entrance for model construction in tochchat.
Suggested change: The entrance for model construction in tochchat. → The entrance for model construction in torchchat.
This PR makes torchchat support multi-modality model definition and construction. To demonstrate this capability, we integrate the flamingo components into our system.
Note that this is only bare-minimum support for model definition. Please check the openai_api_multimodal branch for the e2e flow, and #1123 (comment) for a better structure and llama3.1 support.