
Support channels_last with training #15175

Open
Queuecumber opened this issue Oct 18, 2022 · 16 comments · May be fixed by #17680
Labels
feature Is an improvement or enhancement lightningmodule pl.LightningModule trainer: argument

Comments

@Queuecumber
Contributor

Queuecumber commented Oct 18, 2022

🚀 Feature

I'd like to try out some channels_last training to see if it improves performance (https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html)

I'm not entirely sure what the best way to do this with lightning is, but I also think it should probably be one of those features that you set on the trainer and it just magically works.

Motivation

Using the channels_last memory format can improve performance in some cases

Pitch

Add a trainer flag that does whatever is needed for channels_last so

trainer = pl.Trainer( ..., memory_format='channels_last')

or something like that, and then before anything happens with training/testing you need to convert the module:

if self.memory_format == 'channels_last':
    lightning_module = lightning_module.to(memory_format=torch.channels_last)

and convert each batch in the train/test/val loops:

if self.memory_format == 'channels_last':
    batch = batch.to(memory_format=torch.channels_last)
lightning_module(batch)

Alternatives

I have no idea, but I assume I could do this without changing lightning, although I'm not sure how yet.
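
One rough idea (an untested sketch; it assumes current Lightning import paths, older releases would use pytorch_lightning instead) would be a small Callback that converts the module, with batch conversion still handled separately:

import lightning.pytorch as pl
import torch
from lightning.pytorch.callbacks import Callback

class ChannelsLast(Callback):
    def setup(self, trainer, pl_module, stage=None):
        # nn.Module.to() converts parameters and buffers in place
        pl_module.to(memory_format=torch.channels_last)

trainer = pl.Trainer(callbacks=[ChannelsLast()])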

Additional context

I am not sure I will be able to PR this one but I'm not opposed to trying

cc @Borda @carmocca @justusschock @awaelchli

@Queuecumber Queuecumber added the needs triage Waiting to be triaged by maintainers label Oct 18, 2022
@awaelchli awaelchli added feature Is an improvement or enhancement lightningmodule pl.LightningModule trainer: argument and removed needs triage Waiting to be triaged by maintainers labels Oct 18, 2022
@rohitgr7
Contributor

how about doing it manually?

class LitModel(LightningModule):
    def __init__(self):
        ...

    def on_fit_start(self):
        # nn.Module.to() converts parameters and buffers in place
        self = self.to(memory_format=torch.channels_last)

@Queuecumber
Contributor Author

I think it would probably work but:

I also think it should probably be one of those features that you set on the trainer and it just magically works.

@rohitgr7
Copy link
Contributor

I think it's very domain specific, plus it's just a one-liner, so I don't think this should be added to the trainer

@Queuecumber
Contributor Author

How about a plugin?

@Queuecumber
Contributor Author

Keep in mind you don't just call .to() on the model itself, you need to call it on all the inputs too

@rohitgr7
Contributor

class LitModel(LightningModule):
    def __init__(self):
        ...

    def on_fit_start(self):
        self = self.to(memory_format=torch.channels_last)

    def on_after_batch_transfer(self, batch, *args, **kwargs):
        batch = batch.to(memory_format=torch.channels_last)
        return batch

haven't checked in more detail, but based on your info, I guess this should work
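
One caveat with batch.to(...): it assumes the batch is a single 4D tensor. For dict/tuple batches (and to skip tensors where channels_last doesn't apply), a small recursive helper could be used instead, roughly like this (a sketch only):

import torch

def to_channels_last(batch):
    # channels_last is only defined for 4D (NCHW) tensors, so leave everything else untouched
    if isinstance(batch, torch.Tensor):
        return batch.to(memory_format=torch.channels_last) if batch.dim() == 4 else batch
    if isinstance(batch, (list, tuple)):
        return type(batch)(to_channels_last(b) for b in batch)
    if isinstance(batch, dict):
        return {k: to_channels_last(v) for k, v in batch.items()}
    return batch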

@Queuecumber
Contributor Author

Yeah you could also write

class LitModel(LightningModule):
    def __init__(self):
        ...

    def training_step(self, batch, batch_idx):
        with torch.autocast(device_type=self.device.type):
            ...

Right? That's not really the point though

This isn't some niche technology; huggingface used it to speed up Stable Diffusion inference by ~2.8x (https://huggingface.co/docs/diffusers/optimization/fp16). I assume that's a pie lightning wants a piece of.

So does anyone besides you have opinions on this?

@stale stale bot added the won't fix This will not be worked on label Apr 15, 2023
@Borda Borda changed the title Support channels_last Support channels_last with training Apr 17, 2023
@stale stale bot removed the won't fix This will not be worked on label Apr 17, 2023
@Lightning-AI Lightning-AI deleted a comment from stale bot Apr 17, 2023
@Borda
Member

Borda commented Apr 17, 2023

This isn't some niche technology, huggingface used it to speed up stable diffusion inference by ~2.8x

cc: @JustinGoheen @justusschock

@Pedrexus

Pedrexus commented May 1, 2023

Hello all. Any update on this? I've been using a simple callback and getting around 30%-40% speedup while training a torchvision ResNet50. I could open a PR if anyone is interested.

@HelixPiano

Sounds good to me

@Pedrexus Pedrexus linked a pull request May 23, 2023 that will close this issue
@awaelchli
Contributor

@Pedrexus Should the callback also do the following?

  • undo the transformation in teardown
  • handle input conversion
  • warn if no conv layers in model?

If the callback handles this boilerplate, then I see value in a callback like this, but otherwise, model.to(memory_format=torch.channels_last) is a one-liner and should work as expected already.
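
For reference, a sketch of what such a callback could look like with the teardown/warning points covered (hypothetical, assumes current Lightning import paths, and is not the code from the linked PR):

import torch
import torch.nn as nn
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities import rank_zero_warn

class ChannelsLastCallback(Callback):
    def setup(self, trainer, pl_module, stage=None):
        # warn if the model has no conv layers, where channels_last brings no benefit
        if not any(isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)) for m in pl_module.modules()):
            rank_zero_warn("Model has no 2D convolution layers; channels_last is unlikely to help.")
        pl_module.to(memory_format=torch.channels_last)

    def teardown(self, trainer, pl_module, stage=None):
        # undo the conversion so the module is left in the default memory format
        pl_module.to(memory_format=torch.contiguous_format)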

@carmocca
Contributor

Another alternative would be to have this in the docs as an example of a callback

@Pedrexus

Pedrexus commented May 30, 2023

@Pedrexus Should the callback also do the following?

  • undo the transformation in teardown
  • handle input conversion
  • warn if no conv layers in model?

If the callback handles this boilerplate, then I see value in a callback like this, but otherwise, model.to(memory_format=torch.channels_last) is a one-liner and should work as expected already.

Sure, I agree with you. The one I use in my codebase does a bit more than just the one-liner; I think I can add these soon.

However, I'm not sure input conversion is necessary. I tried this and saw no positive performance effect. I could add it behind a feature flag anyway.

BTW, I could make it a bit more general as a "MemoryFormat" callback and not only "ChannelsLast". I don't think that would be over-engineering.

@TezRomacH
Contributor

@Pedrexus
Hi! Any updates on the callback? :)

@Pedrexus

@TezRomacH I'm sorry. I got busy with things, but I will try to work on the failing checks soon.

@Pedrexus

Pedrexus commented Dec 13, 2023

Ok, I fixed some problems, but it seems mypy is failing. From a quick glance, it seems it is picking the wrong overload of torch.Tensor.to() (https://pytorch.org/docs/stable/generated/torch.Tensor.to.html).
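
If the complaint is a call-overload error, one common workaround (just a suggestion, not necessarily the right fix for the PR) is a targeted ignore on that line:

batch = batch.to(memory_format=torch.channels_last)  # type: ignore[call-overload]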
