UIUC VLM PR compressed fixed #511
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Force-pushed 052a484 to 03a1804
Force-pushed 03a1804 to 0b1d000
@immuntasir - let me know when this is ready for review!
@Tianjiao-Yu confirmed that this is ready for review.
examples/dpo_demo_gemma3.ipynb
Outdated
@Tianjiao-Yu I think this should be removed from this PR.
abheesht17
left a comment
Quick review, I'll do another pass tomorrow
examples/vl_dpo_demo_gemma3.ipynb
Outdated
"source": [
"# Fine-tuning a Visual Language Model (VLM) using DPO\n",
"\n",
"This notebook demonstrates how to fine-tune a Visual Language Model (VLM), specifically the Gemma 3-1B-it model, using the Direct Preference Optimization (DPO) algorithm.\n",
Gemma 3-1B-it model

This is a text-only model, though. 4B onwards are VLMs.
tunix/sft/dpo/dpo_trainer.py
Outdated
  This can be used when inputs are raw strings. Tokenization, padding and
- preprocessing is taken care of by `DPOTrainer`.
+ preprocessing is taken care of by `DpoTrainer`.
examples/dpo_demo_gemma3.ipynb
Outdated
tunix/generate/tokenizer_adapter.py
Outdated
elif self._tokenizer_type == TokenizerType.HFP:
  inputs = self._tokenizer(text=text, **kwargs)
  if 'images' in kwargs:
    return inputs['input_ids'], inputs['pixel_values']
Better to return a dictionary here rather than a tuple (in case we add more modalities later)?
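A minimal sketch of what the suggested dict return could look like. The function name and shape are illustrative, not the actual `tokenizer_adapter` API; the point is that new modalities become additive keys rather than a breaking change to a tuple's arity:

```python
def tokenize(inputs: dict) -> dict:
    """Return tokenized outputs keyed by modality (hypothetical sketch).

    A dict lets callers look up only the modalities they need, and adding
    e.g. audio later does not break existing unpacking code the way
    extending a (input_ids, pixel_values) tuple would.
    """
    out = {"input_ids": inputs["input_ids"]}
    if "pixel_values" in inputs:
        out["pixel_values"] = inputs["pixel_values"]
    # Future modalities slot in here, e.g. out["audio_values"] = ...
    return out
```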
tunix/generate/tokenizer_adapter.py
Outdated
HF: str = 'hf'  # huggingface tokenizer
HFP: str = 'hfp'  # huggingface processor
Is the only difference between these two that the processor can take images, and other modalities too? If yes, do you think we should just use HF processor everywhere (and remove HF tokeniser)?
Because if processor(text) works, we can just use processor everywhere
I don't think every tokenizer has an associated processor definition, so it probably makes sense to have both.
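The compromise the thread lands on (prefer the processor, fall back to the tokenizer when no processor config exists) could be sketched as a hypothetical helper. The loader classes are passed in as parameters here purely so the sketch stays generic; in practice they would be `transformers.AutoProcessor` and `transformers.AutoTokenizer`:

```python
def load_preprocessor(model_id, processor_cls, tokenizer_cls):
    """Try a multimodal processor first; fall back to a plain tokenizer.

    Returns (object, kind) where kind is 'hfp' for a processor or 'hf'
    for a tokenizer, mirroring the TokenizerType values in the diff.
    Hypothetical sketch, not Tunix code.
    """
    try:
        # Processors handle text plus images (and other modalities).
        return processor_cls.from_pretrained(model_id), "hfp"
    except Exception:
        # Text-only models often ship no processor config at all.
        return tokenizer_cls.from_pretrained(model_id), "hf"
```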
tunix/generate/utils.py
Outdated
# Defaults compatible with CLIP / many SigLIP configs; override if needed.
_CLIP_MEAN = jnp.array([0.48145466, 0.4578275, 0.40821073], dtype=jnp.float32)
_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)
Do you think we can move it to models/siglip?
tunix/generate/utils.py
Outdated
    mean: Iterable[float] = _CLIP_MEAN,
    std: Iterable[float] = _CLIP_STD,
) -> jnp.ndarray:
  """Resize + normalize images for SigLIP.
Just SigLIP? Does it not work for other vision models? In generate/utils.py, we should have generic functions (as much as possible)
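A generic version of the helper, per the comment, would take `mean`/`std` as required arguments so nothing in `generate/utils.py` is tied to one vision tower; model-specific constants would live next to the model (e.g. models/siglip). Sketch below, with NumPy standing in for `jnp` and `normalize_images` as an illustrative name:

```python
import numpy as np

def normalize_images(images, mean, std):
    """Channel-wise normalize a batch of [B, H, W, 3] images in [0, 1].

    mean/std are required per-channel values supplied by the caller, so the
    same helper serves SigLIP, CLIP, or any other vision encoder; no
    model-specific defaults are baked in. Illustrative sketch only.
    """
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    images = np.asarray(images, dtype=np.float32)
    return (images - mean) / std  # broadcasts over the trailing channel dim
```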
tunix/models/gemma3/model.py
Outdated
if self.config.multimodal:
  assert pixel_values is not None
  image_mask = last_tokens == 262144  # 262144: <image_soft_token>
Better to define this somewhere instead of hardcoding
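The suggested fix could look like the sketch below: name the magic number once at module (or config) level instead of hardcoding it at the call site. `GEMMA3_IMAGE_SOFT_TOKEN_ID` is an illustrative name; the value 262144 is taken from the diff above:

```python
# Token id of <image_soft_token> in the Gemma 3 vocabulary (value from the
# diff above; the constant name here is illustrative, not Tunix's).
GEMMA3_IMAGE_SOFT_TOKEN_ID = 262144

def image_mask(last_tokens):
    """Mark which positions hold image soft tokens.

    Call sites then read `tok == GEMMA3_IMAGE_SOFT_TOKEN_ID` instead of a
    bare 262144, so the id is defined (and changeable) in one place.
    """
    return [tok == GEMMA3_IMAGE_SOFT_TOKEN_ID for tok in last_tokens]
```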
Oh, I didn't mean to request so many reviews. Not sure how that happened. Maybe from the CLA failing?

Looks good. Could you please give me edit access to this branch? I'll resolve merge conflicts and make a few changes (especially regarding the multiple images point). Thanks!
@abheesht17 For the multiple image support, you may want to look at this commit as a reference point (mostly files ). Another change that you might be interested in is saving LoRA params for multimodal Gemma. Alternatively, I can create a separate pull request for it after the current PR is merged.
I was going through this again, and found a few issues:
- Image tokens should have bidirectional attention, but I don't see that in the code.
- We should support multiple images.
- We have a Hugging Face preprocessor ("hfp"), but we don't seem to be using it. Also, the special tokens in the HF preprocessor/tokeniser are different from the upstream GDM implementation.
- Gemma 3 uses special start of image tokens, end of image tokens, etc., which are not there in the code.
I have a WIP PR for resolving some of these issues. Give me some time.
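The bidirectional-attention point can be illustrated with a small mask-building sketch. For simplicity this treats all image soft tokens as one group that attends to itself in both directions; the actual Gemma 3 implementation scopes the bidirectional block per image, so take this as a conceptual sketch only:

```python
import numpy as np

def gemma3_style_mask(is_image):
    """Causal attention mask, with bidirectional attention among image tokens.

    is_image: boolean sequence of length L marking image soft tokens.
    Returns a [L, L] boolean mask where mask[q, k] means query position q
    may attend to key position k. Text tokens attend causally; image tokens
    additionally attend to every other image token, forward as well as
    backward. Simplified: assumes a single image block.
    """
    is_image = np.asarray(is_image, dtype=bool)
    L = len(is_image)
    causal = np.tril(np.ones((L, L), dtype=bool))   # standard causal mask
    bidir = is_image[:, None] & is_image[None, :]    # image <-> image pairs
    return causal | bidir
```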
tunix/models/siglip/preprocess.py
Outdated
_CLIP_STD = jnp.array([0.26862954, 0.26130258, 0.27577711], dtype=jnp.float32)

def preprocess(
I don't see this function being used anywhere
},
"outputs": [],
"source": [
"gemma_tokenizer = tokenizer_lib.Tokenizer(tokenizer_path=GEMMA_TOKENIZER_PATH)"
Why can we not use the HF processor directly?
examples/vl_dpo_demo_gemma3.ipynb
Outdated
"model_config = dataclasses.replace(\n",
"    model_config, multimodal=True, num_embed=262208\n",
Why don't we just expose multimodal as an arg in gemma3_model_lib.ModelConfig.gemma3_4b(multimodal=True)?
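The suggested API could be sketched with a simplified stand-in for `gemma3_model_lib.ModelConfig` (fields reduced to the two relevant ones; the `num_embed` values are taken from the notebook diff and the image-token discussion above, so treat them as assumptions):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class ModelConfig:
    """Simplified stand-in for gemma3_model_lib.ModelConfig (sketch only)."""
    num_embed: int
    multimodal: bool = False

    @classmethod
    def gemma3_4b(cls, multimodal: bool = False):
        # Multimodal checkpoints carry extra vision special tokens, hence
        # the larger embedding table; values come from the notebook diff.
        return cls(
            num_embed=262208 if multimodal else 262144,
            multimodal=multimodal,
        )
```

This replaces the `dataclasses.replace(model_config, multimodal=True, num_embed=262208)` dance in the notebook with a single `ModelConfig.gemma3_4b(multimodal=True)` call, keeping the coupled field values in one place.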
tunix/models/gemma3/model.py
Outdated
@@ -927,18 +1001,26 @@ def __call__(
    positions: jaxtyping.Array,  # [B, L]
    cache: Cache | None,  # (sequence length L')
    attention_mask: jaxtyping.Array,  # [B, L, L']
Gemma 3 is supposed to have bidirectional attention for image tokens, but I don't see that here, or in the VLM DPO notebook.
…ed SigLIP preprocess
PiperOrigin-RevId: 884468159
Resolves #510
This PR introduces multimodal support to Tunix's Gemma3 model and adds a new vision-language DPO demonstration notebook (vl_dpo_demo_gemma3.ipynb), extending the framework to handle image-text reasoning and multimodal alignment. Key changes include:
@Tianjiao-Yu led this effort and @jxiong21029 contributed to the Gemma3 integration. Please also mention @Tianjiao-Yu if you have any questions/comments/feedback.
Colab Notebook
vl_dpo_demo_gemma3.ipynb
Checklist