diff --git a/imagen_pytorch/trainer.py b/imagen_pytorch/trainer.py index 0597330..d3f4f23 100644 --- a/imagen_pytorch/trainer.py +++ b/imagen_pytorch/trainer.py @@ -604,11 +604,13 @@ def create_valid_iter(self): self.valid_dl_iter = cycle(self.valid_dl) - def train_step(self, unet_number = None, **kwargs): + def train_step(self, *, unet_number = None, **kwargs): if not self.prepared: self.prepare() self.create_train_iter() - loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) + + kwargs = {'unet_number': unet_number, **kwargs} + loss = self.step_with_dl_iter(self.train_dl_iter, **kwargs) self.update(unet_number = unet_number) return loss diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py index a039e2f..375471f 100644 --- a/imagen_pytorch/version.py +++ b/imagen_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.23.1' +__version__ = '1.23.2'