WIP: multimodal support #227

Draft: wants to merge 83 commits into base: main

sohamparikh
Member

✨ Description

Please provide a brief summary of the changes, relevant motivation, and context.
Include any related issue numbers or links to discussions, and explain why this change is necessary.

Closes #

🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

List the key changes introduced in this PR:

  1. Change A
  2. Change B

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

📊 Performance Impact Details

If there is any impact on performance, describe it and provide benchmark results, if applicable:


🗒️ Additional Notes

Include any additional context, information, or considerations here, such as known issues, follow-up tasks, or backward compatibility concerns.

@tscholak tscholak mentioned this pull request May 9, 2025
Contributor

@RaymondLi0 RaymondLi0 left a comment

Thanks for the great work! 🚀
I couldn't look at everything yet, but here are some first comments. I'll continue tomorrow.

@@ -298,11 +305,7 @@ def run(self) -> None:
            raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")

        # route tokenize function
        if self._config.dataset.loss_masking_spans is not None:
            if self._config.dataset.loss_masking_spans not in dataset.column_names:
                raise ValueError(f"Dataset does not have spans field '{self._config.dataset.loss_masking_spans}'.")
Contributor

Why remove this?

Member Author

I moved everything into a single tokenize function, since the number of combinations was getting out of hand (images, loss spans, and audio next).
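
To illustrate the single-function approach, here is a minimal sketch. It is not the code in this PR: `TokenizedSample`, `toy_tokenize`, `tokenize_sample`, and all field names are hypothetical, and the tokenizer is a toy stand-in. The point is only that each optional input (loss spans, images, later audio) becomes one optional branch instead of one tokenize function per combination.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TokenizedSample:
    # Hypothetical container; the field names are illustrative only.
    token_ids: list
    loss_masking_spans: list = field(default_factory=list)
    num_images: int = 0


def toy_tokenize(text: str) -> list:
    # Stand-in for a real tokenizer: one "id" per whitespace-separated word.
    return [len(word) for word in text.split()]


def tokenize_sample(
    sample: dict,
    text_field: str = "text",
    spans_field: Optional[str] = None,
    images_field: Optional[str] = None,
) -> TokenizedSample:
    # Text is always tokenized; every other input is handled only if configured.
    result = TokenizedSample(token_ids=toy_tokenize(sample[text_field]))
    if spans_field is not None:
        result.loss_masking_spans = [tuple(span) for span in sample[spans_field]]
    if images_field is not None:
        result.num_images = len(sample[images_field])
    return result
```

Adding another modality in this layout would mean one more optional argument and one more branch, rather than doubling the number of tokenize functions.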

Contributor

Should we still keep the column-name check?
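
One way the check could be kept alongside a single tokenize function is sketched below. `check_required_columns` is a hypothetical helper, not part of this PR; it only assumes that the dataset exposes its column names as an iterable of strings, as `dataset.column_names` does in the removed lines.

```python
from typing import Iterable, Optional


def check_required_columns(column_names: Iterable[str], **fields: Optional[str]) -> None:
    # `fields` maps a role (e.g. "spans") to the configured column name,
    # or None when that input is not configured and needs no check.
    available = set(column_names)
    for role, name in fields.items():
        if name is not None and name not in available:
            raise ValueError(f"Dataset does not have {role} field '{name}'.")
```

Called once before tokenization, for example as `check_required_columns(dataset.column_names, spans=..., images=...)`, this would validate every configured column up front regardless of which combination of inputs is active.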

Co-authored-by: sohamparikh <sohamparikh47@gmail.com>
@sohamparikh sohamparikh mentioned this pull request Jun 11, 2025
25 tasks
@@ -548,6 +563,350 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig
]


class PixtralHuggingfaceCheckpointHandler(WeightAndBiasConverterMixin, HuggingfaceStateDictCheckpointHandler):
Contributor

Is my understanding correct that this class handles conversion for the vision encoder (to a PixtralVisionModel model),
whereas LlavaHuggingfaceCheckpointHandler handles conversion of the full model and is the one to use to convert pixtral-12b?

Member Author

Yes, exactly! It's inspired by the config.json on HF, and more and more omni models seem to be converging on a similar format.
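
As a rough illustration of that format: a Llava-style config.json on HF nests a `text_config` and a `vision_config` under the top-level config, which is what makes a per-component handler natural. The sketch below assumes that layout; `split_llava_style_config` and the example values are hypothetical and are not taken from this PR or from a real checkpoint.

```python
def split_llava_style_config(config: dict) -> dict:
    # Return one sub-config per component so that each component can be
    # handed to its own converter (e.g. a vision-encoder handler).
    return {
        "text": dict(config.get("text_config", {})),
        "vision": dict(config.get("vision_config", {})),
    }


# Illustrative values only; a real config.json carries many more keys.
example_config = {
    "model_type": "llava",
    "text_config": {"model_type": "mistral"},
    "vision_config": {"model_type": "pixtral"},
}
components = split_llava_style_config(example_config)
```

Under this layout the full-model handler owns the top-level config and delegates the `vision` part to the vision-encoder handler.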

3 participants