Support channels_last with training #15175

Comments
how about doing it manually?

    class LitModel(LightningModule):
        def __init__(self):
            ...

        def on_fit_start(self):
            self = self.to(memory_format=torch.channels_last)
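(A side note on the snippet above: nn.Module.to() converts the parameters and buffers in place and returns the module itself, so assigning the result back to self is redundant, though harmless.)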
I think it would probably work, but:
I think it's very domain specific, plus it's just a one-liner, so I don't think this should be added to the trainer.
How about a plugin?
Keep in mind you don't just call .to(memory_format=...) on the model; the input batches need to be converted as well.
    class LitModel(LightningModule):
        def __init__(self):
            ...

        def on_fit_start(self):
            self = self.to(memory_format=torch.channels_last)

        def on_after_batch_transfer(self, batch, *args, **kwargs):
            batch = batch.to(memory_format=torch.channels_last)
            return batch

I haven't checked in more detail, but based on your info, I guess this should work.
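One caveat with the hook-based snippet above (an editorial sketch, not part of the original thread): on_after_batch_transfer often receives a tuple or dict rather than a single tensor, and channels_last only applies to 4D (N, C, H, W) tensors, so a slightly more defensive version could look like this:

    import torch
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        def on_fit_start(self):
            # nn.Module.to() converts parameters and buffers in place.
            self.to(memory_format=torch.channels_last)

        def on_after_batch_transfer(self, batch, *args, **kwargs):
            # Batches are frequently (inputs, targets) tuples or dicts, and only
            # 4D tensors can be laid out as channels_last.
            def convert(item):
                if isinstance(item, torch.Tensor) and item.dim() == 4:
                    return item.to(memory_format=torch.channels_last)
                return item

            if isinstance(batch, (tuple, list)):
                return type(batch)(convert(item) for item in batch)
            if isinstance(batch, dict):
                return {key: convert(value) for key, value in batch.items()}
            return convert(batch)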
Yeah, you could also write it that way.
Right? That's not really the point, though. This isn't some niche technology; Hugging Face used it to speed up Stable Diffusion inference by ~2.8x (https://huggingface.co/docs/diffusers/optimization/fp16). I assume that's a pie Lightning wants a piece of. So does anyone besides you have opinions on this?
cc: @JustinGoheen @justusschock
Hello all. Any update on this? I've been using a simple callback and getting around a 30-40% speedup while training a torchvision ResNet50. I could open a PR if anyone is interested.
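For reference, a minimal version of such a callback might look like the sketch below; this is my own guess at what a "simple callback" could be, not necessarily the code referred to here:

    import torch
    from pytorch_lightning import Callback

    class ChannelsLastCallback(Callback):
        """Convert the model to the channels_last memory format before fitting."""

        def on_fit_start(self, trainer, pl_module):
            # nn.Module.to() converts parameters and buffers in place.
            pl_module.to(memory_format=torch.channels_last)

It would then be used as Trainer(callbacks=[ChannelsLastCallback()]).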
Sounds good to me.
@Pedrexus Should the callback also do the input (batch) conversion discussed above?
If the callback handles this boilerplate, then I see value in a callback like this, but otherwise not so much.
Another alternative would be to have this in the docs as an example of a callback.
Sure, I agree with you. The one I use in my codebase does a bit more than just the one-liner; I think I can add these soon. However, I'm not sure input conversion is necessary: I tried it and saw no positive performance effect. I could add it behind a feature flag anyway. BTW, I could make it a bit more general as a "MemoryFormat" callback and not only "ChannelsLast". I believe it wouldn't be over-engineering.
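A rough sketch of what a more general "MemoryFormat" callback could look like; the class name, the constructor argument, and leaving input conversion out entirely are my assumptions, not the contents of the eventual PR:

    import torch
    from pytorch_lightning import Callback

    class MemoryFormatCallback(Callback):
        """Convert the model to an arbitrary memory format, e.g. channels_last or channels_last_3d."""

        def __init__(self, memory_format=torch.channels_last):
            self.memory_format = memory_format

        def setup(self, trainer, pl_module, stage=None):
            pl_module.to(memory_format=self.memory_format)

Input (batch) conversion is omitted here since it showed no measurable benefit in the experiments mentioned above; it could be added behind a feature flag as discussed.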
@Pedrexus any updates on the PR?
@TezRomacH I'm sorry. I got busy with things, but I will try to work on the failing checks soon. |
OK, I fixed some problems, but it seems mypy is failing. From what I can tell, it is picking the wrong overload of torch.Tensor.to() (https://pytorch.org/docs/stable/generated/torch.Tensor.to.html).
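For anyone hitting the same check failure, one common way to work around an overload mismatch like this is a narrowly scoped ignore on the offending line; this is only a guess at the situation, since the failing code isn't shown here:

    import torch

    # Hypothetical helper mirroring the conversion step; the error code below
    # assumes mypy reports a "call-overload" mismatch on Tensor.to().
    def to_channels_last(batch: torch.Tensor) -> torch.Tensor:
        return batch.to(memory_format=torch.channels_last)  # type: ignore[call-overload]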
🚀 Feature
I'd like to try out some channels_last training to see if it improves performance (https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html). I'm not entirely sure what the best way to do this with Lightning is, but I also think it should probably be one of those features that you set on the trainer and it just magically works.
Motivation
Using the channels_last memory format can improve performance in some cases.
Pitch
Add a trainer flag that does whatever is needed for channels_last, or something like that, and then before anything happens with training/testing the module needs to be converted, along with each batch in the train/test/val loops.
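Spelled out in plain PyTorch (following the linked tutorial), the conversions such a flag would automate amount to roughly the following; resnet50 and the input shape are just placeholders:

    import torch
    import torchvision

    # Convert the module once, before training/testing starts...
    model = torchvision.models.resnet50().to(memory_format=torch.channels_last)

    # ...and convert every batch inside the train/val/test loops.
    images = torch.randn(8, 3, 224, 224).to(memory_format=torch.channels_last)
    output = model(images)  # forward pass with channels_last tensors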
Alternatives
I have no idea, but I assume I could do this without changing Lightning, although I'm not sure how yet.
Additional context
I am not sure I will be able to PR this one, but I'm not opposed to trying.
cc @Borda @carmocca @justusschock @awaelchli