Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
[06/24/2023] VAE fine-tuning runs successfully, but the image results still need to be tested and evaluated.
```shell
    --dataset_name="<DATASET_NAME>" \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing
```
Suggested change:

```diff
-    --gradient_checkpointing
+    --gradient_checkpointing \
```
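The trailing backslash matters: it continues the command onto the next line, so every flag reaches the same invocation; without it, the shell starts a new command at the next line. A small sketch using `echo` as a stand-in for the real training script:

```shell
# Hypothetical demo (echo stands in for the actual launch script): the
# trailing backslashes join all three flags into a single command line.
cmd="$(echo --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --gradient_checkpointing)"
echo "$cmd"
```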
```python
with accelerator.main_process_first():
    # Split into train/test
    dataset = dataset["train"].train_test_split(test_size=args.test_samples)
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess)
    test_dataset = dataset["test"].with_transform(preprocess)
```
Support loading the test set from a dedicated test data folder.
Suggested change:

```diff
 with accelerator.main_process_first():
-    # Split into train/test
-    dataset = dataset["train"].train_test_split(test_size=args.test_samples)
-    # Set the training transforms
-    train_dataset = dataset["train"].with_transform(preprocess)
-    test_dataset = dataset["test"].with_transform(preprocess)
+    # Load test data from test_data_dir
+    if args.test_data_dir is not None and args.train_data_dir is not None:
+        logger.info(f"load test data from {args.test_data_dir}")
+        test_dir = os.path.join(args.test_data_dir, "**")
+        test_dataset = load_dataset(
+            "imagefolder",
+            data_files=test_dir,
+            cache_dir=args.cache_dir,
+        )
+        # Set the training transforms
+        train_dataset = dataset["train"].with_transform(preprocess)
+        test_dataset = test_dataset["train"].with_transform(preprocess)
+    # Split into train/test
+    else:
+        dataset = dataset["train"].train_test_split(test_size=args.test_samples)
+        # Set the training transforms
+        train_dataset = dataset["train"].with_transform(preprocess)
+        test_dataset = dataset["test"].with_transform(preprocess)
```
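The fallback branch relies on `train_test_split` accepting an integer `test_size`, i.e. an absolute number of held-out samples rather than a fraction. A dependency-free stand-in (this is not the `datasets` library's implementation, just a sketch of that semantic):

```python
import random

def train_test_split(items, test_size, seed=0):
    """Hypothetical stand-in for datasets.Dataset.train_test_split when
    test_size is an absolute number of samples: shuffle, then hold out
    the first test_size items as the test set."""
    rng = random.Random(seed)
    shuffled = items[:]
    rng.shuffle(shuffled)
    return {"train": shuffled[test_size:], "test": shuffled[:test_size]}

images = [f"img_{i}.png" for i in range(10)]
split = train_test_split(images, test_size=4)
print(len(split["train"]), len(split["test"]))  # 6 4
```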
```python
    type=int,
    default=4,
    help="Number of images to remove from training set to be used as validation.",
)
```
Add a new argument, `--test_data_dir`, for a dedicated test data folder.
Suggested change:

```diff
 )
+parser.add_argument(
+    "--test_data_dir",
+    type=str,
+    default=None,
+    help=(
+        "If not None, will override test_samples arg and use data inside this dir as test dataset."
+    ),
+)
```
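A self-contained sketch of how the two arguments interact; the parser below reproduces only the two options relevant to test-set selection, and the override rule (a dedicated test dir wins over the automatic split) is the point:

```python
import argparse

# Minimal parser mirroring the PR's two test-set options (names taken
# from the diff above; everything else in the real parser is omitted).
parser = argparse.ArgumentParser()
parser.add_argument("--test_samples", type=int, default=4,
                    help="Number of images held out from the training set.")
parser.add_argument("--test_data_dir", type=str, default=None,
                    help="If not None, overrides test_samples and uses this dir as the test dataset.")

args = parser.parse_args(["--test_data_dir", "my_test_images"])
# Override rule: when a dedicated test dir is given, skip the split.
use_dedicated_dir = args.test_data_dir is not None
print(use_dedicated_dir)  # True
```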
```python
if tracker.name == "tensorboard":
    np_images = np.stack([np.asarray(img) for img in images])
    tracker.writer.add_images(
        "Original (left) / Reconstruction (right)", np_images, epoch
```
Change the tag name so it is compatible with Windows file names.
Suggested change:

```diff
-        "Original (left) / Reconstruction (right)", np_images, epoch
+        "Original (left)-Reconstruction (right)", np_images, epoch
```
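A more general way to harden such tags is to strip every character that is invalid in Windows file names (TensorBoard can turn tag names into file or directory paths). The helper below is a hypothetical sketch, not part of the PR:

```python
import re

# Characters that are invalid in Windows file names; "/" and "\" are
# also path separators, which is what breaks the original tag.
_INVALID = re.compile(r'[\\/:*?"<>|]')

def safe_tag(tag: str) -> str:
    """Replace Windows-invalid file name characters with '-'."""
    return _INVALID.sub("-", tag)

print(safe_tag("Original (left) / Reconstruction (right)"))
```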
```python
progress_bar.set_description("Steps")

lpips_loss_fn = lpips.LPIPS(net="alex").to(accelerator.device)
```
Suggested change:

```diff
+# initial validation as baseline
+with torch.no_grad():
+    log_validation(test_dataloader, vae, accelerator, weight_dtype, 0)
```
Run one validation pass before training starts, as a baseline for comparison.
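A minimal, dependency-free sketch of this evaluate-before-training pattern (all names below are hypothetical stand-ins, not the PR's actual code): recording a metric at step 0 gives a fixed point that later epochs can be compared against.

```python
# Stand-in for the real log_validation(test_dataloader, vae, ...) call:
# it just records the model's current loss at a given step.
def log_validation(model, step):
    return {"step": step, "loss": model["loss"]}

model = {"loss": 1.0}
history = [log_validation(model, 0)]  # baseline before any training

for epoch in range(1, 4):
    model["loss"] *= 0.5              # pretend each epoch halves the loss
    history.append(log_validation(model, epoch))

print(history[0]["loss"], history[-1]["loss"])  # 1.0 0.125
```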
Co-authored-by: zhuliyi0 <48817897+zhuliyi0@users.noreply.github.com>
```python
pred = vae.decode(z).sample

kl_loss = posterior.kl().mean()
mse_loss = F.mse_loss(pred, target, reduction="mean")
```
In the original stable-diffusion repo and the SDXL repo, the VAE loss is averaged only over the batch dim, which means it is summed over the channel, height, and width dims. Is averaging over all dims the correct way to compute the reconstruction loss?
https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/losses/contperceptual.py#L58
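To make the reviewer's point concrete, here is a small dependency-free numeric sketch (values are made up): with a batch of B samples, each containing N = C*H*W elements, "mean over everything" (what `reduction="mean"` does) differs from "sum per sample, then mean over the batch" (the CompVis reduction) by a constant factor of N, which changes the effective weight of the reconstruction term relative to the KL term.

```python
# Hypothetical tiny example: B samples of N = C*H*W squared errors each.
B, N = 2, 6                     # e.g. a 1x2x3 "image" -> N = 6
errors = [[0.5] * N for _ in range(B)]  # every squared error is 0.5

# reduction="mean": average over all B*N elements.
mean_all = sum(sum(row) for row in errors) / (B * N)

# CompVis-style: sum over C,H,W per sample, then mean over the batch.
sum_then_batch_mean = sum(sum(row) for row in errors) / B

print(mean_all, sum_then_batch_mean, sum_then_batch_mean / mean_all)
```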
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Is there any progress on this?
PR for Issue #3726
Todos