Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revision of the MoCo SSL model #928

Merged
merged 25 commits into from
Jul 10, 2023

Conversation

senarvi
Copy link
Contributor

@senarvi senarvi commented Oct 31, 2022

What does this PR do?

Related to issue #839 that discusses a major revision of Bolts, this pull request refactors the Momentum Contrast (MoCo) SSL model. That issue will probably become impossible to follow if all the details are discussed there, so I'll write my suggestions below. I would be happy to hear some feedback.

Before submitting

  • Was this discussed/approved via a Github issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

@github-actions github-actions bot added documentation Improvements or additions to documentation model labels Oct 31, 2022
@senarvi
Copy link
Contributor Author

senarvi commented Oct 31, 2022

@Atharva-Phatak SimCLR is a very similar method, so it makes sense to coordinate our efforts, and maybe other SSL methods too, so that we could make the interfaces as consistent as possible. We could also place both models in pl_bolts.models.ssl.contrastive module in order to be able to share code, if that's a good idea.

I think it would be important to make the models generic so that one can train

  • any network, instead of selecting one of a set of options
  • using any projection head
  • using any optimizer and learning rate schedule
  • using any data loader and transforms

Issues #904 already suggests moving transforms to a central location. I don't know if we need separate transforms for each model. More or less identical transforms are used with both methods, and the user doesn't necessarily want to stick to the default transforms anyway.

We should fix the format that the data loader should produce for contrastive learning. The format could be as similar as possible to the format that we use for classification and object detection. At least for object detection we use a tuple (images, targets), where images is a list of image tensors and SSL models can simply ignore targets. In this pull request images is expected to be a nested list [[image1A, image1B], [image2A, image2B], [image3A, image3B], ...].

I added an example CLI application using LightningCLI, but IMO there's no point in trying to create a comprehensive training tool in each module.

I renamed the class to MoCo. I think v2 just adds a projection head, so the same class supports both versions.

@senarvi
Copy link
Contributor Author

senarvi commented Oct 31, 2022

The old way of detecting whether DDP is in use didn't work, since the DDP strategy class is now DDPStrategy, not DDPPlugin. Using isinstance() is pretty error-prone. I wonder if we could use if torch.distributed.is_available() and torch.distributed.is_initialized(), or is it important to detect DDP as opposed to other distributed training strategies?

@Atharva-Phatak
Copy link
Contributor

@senarvi Perfect, that is what my intention is essentially we add Projection heads and Losses required by SSL models. Backbones can be implemented by the user if they are custom or else imported from timm or something similar.

All in all here is what I suggest

  1. Single place for SSL projection heads
  2. Single place for SSL losses
  3. Single place for transforms as well.
  4. General PL module for training any custom SSL model :)

@mergify mergify bot added the has conflicts label Nov 1, 2022
@otaj
Copy link
Contributor

otaj commented Nov 2, 2022

Hi, @senarvi, @Atharva-Phatak. I fully support @Atharva-Phatak's suggestion, if you can coordinate on that, that would be great!

Btw, @senarvi, since your fork is on an organization account (groke-technologies) and not on personal account, noone except people belonging to that organization have write access to it. While that makes sense, I'd say it makes it hard for us maintainers to help you with those PRs since we cannot merge master.

@mergify mergify bot removed the has conflicts label Nov 3, 2022
@senarvi
Copy link
Contributor Author

senarvi commented Nov 3, 2022

@otaj ok that explains it. But it's fine, I'll just keep merging master myself.

@mergify mergify bot added the has conflicts label Nov 4, 2022
@ArnolFokam
Copy link
Contributor

@senarvi Perfect, that is what my intention is essentially we add Projection heads and Losses required by SSL models. Backbones can be implemented by the user if they are custom or else imported from timm or something similar.

All in all here is what I suggest

  1. Single place for SSL projection heads
  2. Single place for SSL losses
  3. Single place for transforms as well.
  4. General PL module for training any custom SSL model :)

Happy to help with this as well. Something we might want to be cautious of (if we follow this route) is the compatibility between the transforms and losses of various SSL methods.

@stale stale bot added the won't fix This will not be worked on label Feb 18, 2023
@stale stale bot closed this Mar 18, 2023
@Borda Borda reopened this Mar 18, 2023
@Lightning-Universe Lightning-Universe deleted a comment from stale bot Mar 18, 2023
@stale stale bot removed the won't fix This will not be worked on label Mar 18, 2023
@senarvi
Copy link
Contributor Author

senarvi commented Jul 2, 2023

@Borda the problem was that the data module used in the unit tests provided 4-dimensional tensors of image pairs, while the default ImageNet data module provided tuples of 3-dimensional tensors and that's what the input validator expected too. Now it supports both.

There was also a new unit test that checks the CLI scripts. It failed because ImageNet cannot be downloaded automatically, so I changed the CLI application to use CIFAR10 instead.

If I'm correct, now all the unit tests pass. I can only see some Docker errors now. But let me know if there's still something that needs fixing.

@mergify mergify bot added the ready label Jul 2, 2023
@Borda
Copy link
Member

Borda commented Jul 4, 2023

@senarvi the GPU start issue is fixed on master...

@senarvi
Copy link
Contributor Author

senarvi commented Jul 4, 2023

@Borda merged master and now everything passes.


from pl_bolts.datamodules import VOCDetectionDataModule

LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)
cli_class(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cli_class(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)
LightningCLI(RetinaNet, VOCDetectionDataModule, seed_everything_default=42)

@senarvi
Copy link
Contributor Author

senarvi commented Jul 4, 2023

@Borda the thing that you commented on is a real PITA to get working so that mypy won't complain in any case. There's a lot of talk about this on GitHub, see for example this. The problem is that mypy doesn't consider the try..except and complains about LightningCLI being already defined. If # type: ignore comments are added, it will complain about an unused 'type: ignore' comment, because always one of the comments is unnecessary. I'm not sure if there's a fix in the latest mypy. At least with mypy 1.1.1 the code that you suggested won't pass.

@senarvi
Copy link
Contributor Author

senarvi commented Jul 4, 2023

@Borda also, the exception can be either an ImportError or an AttributeError, so it's an ugly workaround indeed.

@senarvi
Copy link
Contributor Author

senarvi commented Jul 7, 2023

@Borda should we go like this? I'm fine with anything that makes mypy happy.

@Borda
Copy link
Member

Borda commented Jul 7, 2023

I'm fine with anything that makes mypy happy.

lets mypy/typing address in separate PR :)

@mergify mergify bot removed the ready label Jul 8, 2023
@mergify mergify bot added the ready label Jul 8, 2023
@senarvi
Copy link
Contributor Author

senarvi commented Jul 8, 2023

@Borda fine. Actually it was necessary to only support the current location (pytorch_lightning.cli) to make the tests pass. Now it's clean.

@Borda Borda merged commit ae4c342 into Lightning-Universe:master Jul 10, 2023
Borda pushed a commit that referenced this pull request Jul 10, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation model ready
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants