Skip to content

Commit 6aa6537

Browse files
committed
update
1 parent 3563439 commit 6aa6537

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

examples/vae/train_vae.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from packaging import version
2727
from torchvision import transforms
2828
from tqdm.auto import tqdm
29-
from transformers import CLIPTextModel, CLIPTokenizer
30-
from transformers.utils import ContextManagers
3129

3230
import diffusers
3331
from diffusers import AutoencoderKL
@@ -50,8 +48,8 @@ def log_validation(test_dataloader, vae, accelerator, weight_dtype, epoch):
5048
vae_model = accelerator.unwrap_model(vae)
5149
images = []
5250
for _, sample in enumerate(test_dataloader):
53-
images = sample["pixel_values"].to(weight_dtype)
54-
reconstructions = vae_model(images).sample
51+
x = sample["pixel_values"].to(weight_dtype)
52+
reconstructions = vae_model(x).sample
5553
images.append(
5654
torch.cat([sample["pixel_values"].cpu(), reconstructions.cpu()], axis=0)
5755
)

0 commit comments

Comments
 (0)