Commit 1d3eaa6

Add training support for SigLIP (huggingface#31495)
* Add siglip loss function
* Update docs
* Enable training tests [experimental] enable GC training tests as it has worked for my own data
* Remove test_training* overrides to enable training tests [run_slow] siglip
* Skip training tests for Siglip text model and ImageClassificationModel [run_slow] siglip
* Skip GC training tests for SiglipForImageClassification
* Explicitly skip training tests for SiglipVisionModel
* Add skip reason for training tests for SiglipTextModel
* Remove copied from to fix CI

1 parent: 1556025

File tree: 3 files changed (+11, −30)
docs/source/en/model_doc/siglip.md (+1, −1)
```diff
@@ -27,7 +27,7 @@ The abstract from the paper is the following:
 ## Usage tips

 - Usage of SigLIP is similar to [CLIP](clip). The main difference is the training loss, which does not require a global view of all the pairwise similarities of images and texts within a batch. One needs to apply the sigmoid activation function to the logits, rather than the softmax.
-- Training is not yet supported. If you want to fine-tune SigLIP or train from scratch, refer to the loss function from [OpenCLIP](https://github.com/mlfoundations/open_clip/blob/73ad04ae7fb93ede1c02dc9040a828634cb1edf1/src/open_clip/loss.py#L307), which leverages various `torch.distributed` utilities.
+- Training is supported, but it does not use `torch.distributed` utilities, which may limit batch-size scalability. However, DDP and FSDP work in a single-node, multi-GPU setup.
 - When using the standalone [`SiglipTokenizer`] or [`SiglipProcessor`], make sure to pass `padding="max_length"`, as that's how the model was trained.
 - To get the same results as the pipeline, a prompt template of "This is a photo of {label}." should be used.
```

src/transformers/models/siglip/modeling_siglip.py (+6, −1)
```diff
@@ -1234,7 +1234,12 @@ def forward(

         loss = None
         if return_loss:
-            raise NotImplementedError("SigLIP loss to be implemented")
+            # Adapted from https://github.com/google-research/big_vision/blob/01edb81a4716f93a48be43b3a4af14e29cdb3a7f/big_vision/trainers/proj/image_text/siglip.py#L287
+            eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
+            m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
+            loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
+            nll = -torch.sum(loglik, dim=-1)
+            loss = nll.mean()

         if not return_dict:
             output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
```
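The added loss follows the big_vision reference: diagonal entries of the n×n `logits_per_text` matrix are positive pairs (label +1), every off-diagonal entry is a negative (label −1), and the loss is the mean over rows of the summed negative log-sigmoid likelihoods. A dependency-free sketch of the same computation (an illustration, not the shipped `torch` implementation above):

```python
import math

def log_sigmoid(z):
    # Numerically stable log(sigmoid(z)).
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))

def siglip_loss(logits_per_text):
    """Pairwise sigmoid loss over an n x n text-image logit matrix.

    Diagonal entries are matching pairs (label +1); all off-diagonal
    entries are non-matching (label -1). Returns the mean over rows
    of the summed negative log-likelihoods.
    """
    n = len(logits_per_text)
    total = 0.0
    for i in range(n):
        nll = 0.0
        for j in range(n):
            label = 1.0 if i == j else -1.0
            nll -= log_sigmoid(label * logits_per_text[i][j])
        total += nll
    return total / n
```

For a 2×2 matrix `[[2, -1], [-1, 2]]`, each row contributes `-(log_sigmoid(2) + log_sigmoid(1))`, since the off-diagonal logit −1 is multiplied by label −1.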

tests/models/siglip/test_modeling_siglip.py (+4, −28)
```diff
@@ -335,27 +335,19 @@ def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)

-    @unittest.skip
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training(self):
         pass

-    @unittest.skip
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training_gradient_checkpointing(self):
         pass

-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training_gradient_checkpointing_use_reentrant(self):
         pass

-    @unittest.skip(
-        reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
-    )
-    # Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant_false
+    @unittest.skip(reason="SiglipTextModel does not support standalone training")
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass

@@ -481,22 +473,6 @@ def test_retain_grad_hidden_states_attentions(self):
     def test_model_get_set_embeddings(self):
         pass

-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training(self):
-        pass
-
-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training_gradient_checkpointing(self):
-        pass
-
-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training_gradient_checkpointing_use_reentrant(self):
-        pass
-
-    @unittest.skip(reason="SiglipModel does not support training")
-    def test_training_gradient_checkpointing_use_reentrant_false(self):
-        pass
-
     @unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
     def test_initialization(self):
         pass
```
