multi-modality model construction support #1068


Merged: 27 commits merged into main on Sep 11, 2024

Conversation

Contributor

@Gasoonjia Gasoonjia commented Aug 27, 2024

This PR adds multi-modality model definition and construction support to torchchat. To showcase our capabilities in the multi-modality area, we integrate the Flamingo components into the system.
Note that this is only bare-minimum support for model definition. Please check the openai_api_multimodal branch for e2e support, and #1123 (comment) for a better structure and llama3.1 support.


pytorch-bot bot commented Aug 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchchat/1068

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b96bf05 with merge base c272df4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Aug 27, 2024

@Jack-Khuu Jack-Khuu left a comment


This is looking great!! Thanks for making it so easy to review

Mostly nits, but I know this is part of a PR stack we're landing, so we have some leeway

.DS_Store Outdated

You might have to manually tell git not to track


There are a few DS Stores

@dataclass
class ModelRecipe:
model_type: ModelType
modules: dict

Suggested change
modules: dict
modules: Dict[str, Any]
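The typed annotation makes the recipe's contract explicit. A minimal sketch of how the annotated dataclass might look — the `ModelType` values and the placeholder field contents below are illustrative, not torchchat's actual definitions:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict


class ModelType(Enum):  # illustrative values, not torchchat's enum
    TextOnly = "text_only"
    Flamingo = "flamingo"


@dataclass
class ModelRecipe:
    model_type: ModelType
    # Dict[str, Any] per the suggestion: module name -> transformer class.
    modules: Dict[str, Any]
    fusion_class: Any


recipe = ModelRecipe(ModelType.TextOnly, {"text": object}, fusion_class=None)
```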



@dataclass
class ModelRecipe:

nit: Docstring

fusion_class: torch.nn.Module

@classmethod
def text_only(cls):

Suggested change
def text_only(cls):
def _text_only(cls):

fusion_class=nn.Identity,
)
@classmethod
def flamingo(cls):

Suggested change
def flamingo(cls):
def _flamingo(cls):


self.model_type = model_type
if isinstance(transformer_args, TransformerArgs):
self.transformer_args = {"text": transformer_args}

Let's make "text" a constant as well, we use it in a lot of places

Contributor Author

Maybe in a different PR, if you think that's good. I have some plans to make the configuration more concise and structured.
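For reference, the constant the reviewer has in mind might look like the sketch below. The name `TEXT_MODULE_KEY` and the toy stand-in classes are illustrative assumptions, not identifiers torchchat actually adopted:

```python
# Hypothetical name for the shared constant replacing the bare "text" literal.
TEXT_MODULE_KEY = "text"


class TransformerArgs:
    """Toy stand-in for torchchat's TransformerArgs dataclass."""


class ModelArgs:
    def __init__(self, transformer_args, model_type=None):
        self.model_type = model_type
        if isinstance(transformer_args, TransformerArgs):
            # Wrap the legacy single-config form in the dict layout,
            # keyed by the constant instead of a repeated string literal.
            self.transformer_args = {TEXT_MODULE_KEY: transformer_args}
        else:
            self.transformer_args = transformer_args
```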


return cls(text_transformer_args)
# now only support flamingo model

Suggested change
# now only support flamingo model
# Currently only supporting flamingo model

@@ -77,13 +124,32 @@ def from_params(cls, params):
params[_to] = params.pop(_from)
return cls(**params)


@dataclass
class ModelArgs:

Docstring


def setup_caches(self, max_batch_size, max_seq_length):
self.text_transformer.setup_caches(max_batch_size, max_seq_length)
def build_model(self):

This is where all the magic comes together; Let's add a docstring here
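As a sketch of what such a docstring, and the assembly it documents, could cover — all class and field names below are toy stand-ins, not torchchat's actual API:

```python
class TextDecoder:
    """Toy stand-in for the text transformer module."""

    def __init__(self, args):
        self.args = args


class IdentityFusion:
    """Toy stand-in for identity-style fusion used by text-only models."""

    def __init__(self, text):
        self.text = text


# Hypothetical recipe table mapping a model type to its parts.
RECIPES = {
    "text_only": {"modules": {"text": TextDecoder}, "fusion_class": IdentityFusion},
}


class Model:
    def __init__(self, config):
        self.config = config
        self.model = self.build_model()

    def build_model(self):
        """Assemble the network described by ``self.config``.

        Looks up the recipe for ``config.model_type``, instantiates each
        named module from the matching entry in ``config.transformer_args``,
        and wires the instances together with the recipe's fusion class.
        """
        recipe = RECIPES[self.config.model_type]
        modules = {
            name: module_cls(self.config.transformer_args[name])
            for name, module_cls in recipe["modules"].items()
        }
        return recipe["fusion_class"](**modules)


class Config:
    model_type = "text_only"
    transformer_args = {"text": {"dim": 8}}


model = Model(Config())
```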

@@ -184,13 +250,48 @@ class Model(nn.Module):
def __init__(self, config: ModelArgs) -> None:

Since we have the legacy text_transformer and new model, let's add a quick description until we unify them later


Or link to your other PR where you fix this

Contributor Author

That will happen adjacent to this PR (the 3.1 torchtune one), so I'd like to keep it as is here.


Jack-Khuu commented Sep 10, 2024

Before landing:

  • Update the PR title and description
  • Check that the existing text-only model flow works (walk through the README)

dist_run.py Outdated
@@ -122,7 +122,7 @@ def main():
gpu_memory_monitor = GPUMemoryMonitor("cuda")
logger.info(f"{color.yellow} {gpu_memory_monitor.get_device_info()}{color.reset}")

config = ModelArgs.from_name(MODEL_NAME).text_transformer_args
config = ModelArgs.from_name(MODEL_NAME)..transformer_args['text']

Double dot


@dataclass
class ModelRecipe:
"""
A class in TorchChat that describes and contains all supported model structures in TorchChat.

Suggested change
A class in TorchChat that describes and contains all supported model structures in TorchChat.
A class in torchchat that describes and contains all supported model structures in torchchat.

A class in TorchChat that describes and contains all supported model structures in TorchChat.

ModelRecipe represents a model as a collection of Transformer modules and a fusion module,
providing a standardized and centralized way to define and build models in TorchChat.

Suggested change
providing a standardized and centralized way to define and build models in TorchChat.
providing a standardized and centralized way to define and build models in torchchat.

@@ -247,6 +267,9 @@ def update(self, input_pos, k_val, v_val):


class Model(nn.Module):
"""
The entrance for model construction in tochchat.

Suggested change
The entrance for model construction in tochchat.
The entrance for model construction in torchchat.

@Gasoonjia Gasoonjia changed the title Flamingo component multi-modality model construction support Sep 11, 2024
@Gasoonjia Gasoonjia merged commit 964d437 into main Sep 11, 2024
51 checks passed