Skip to content

Commit

Permalink
just force train_step to only accept kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Mar 20, 2023
1 parent a9e8ed5 commit 4251f27
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
6 changes: 4 additions & 2 deletions imagen_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,11 +604,13 @@ def create_valid_iter(self):

self.valid_dl_iter = cycle(self.valid_dl)

def train_step(self, unet_number = None, **kwargs):
def train_step(self, *, unet_number = None, **kwargs):
if not self.prepared:
self.prepare()
self.create_train_iter()
loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)

kwargs = {'unet_number': unet_number, **kwargs}
loss = self.step_with_dl_iter(self.train_dl_iter, **kwargs)
self.update(unet_number = unet_number)
return loss

Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.23.1'
__version__ = '1.23.2'

0 comments on commit 4251f27

Please sign in to comment.