VAE training sample script #3801

Closed

aandyw wants to merge 29 commits into huggingface:main from aandyw:vae-training
Conversation

@aandyw
Contributor

@aandyw aandyw commented Jun 15, 2023

PR for Issue #3726

Todos

  • implement training loop for VAE
  • KL loss implementation (see the sketch after this list)
  • evaluate performance of VAE training
  • fix script to work for mixed precision
  • integration with a1111
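
For reference, below is a minimal sketch of the training step these todos describe, assuming diffusers' AutoencoderKL and the loss terms discussed later in this thread (KL + MSE + LPIPS). The dataloader, optimizer, and the weight names args.kl_weight and args.lpips_weight are illustrative placeholders, not the script's final API.

import lpips
import torch.nn.functional as F

# Perceptual loss; net="alex" matches the choice visible later in this PR.
lpips_loss_fn = lpips.LPIPS(net="alex").to(accelerator.device)

for batch in train_dataloader:
    target = batch["pixel_values"]  # images scaled to [-1, 1]
    posterior = vae.encode(target).latent_dist
    z = posterior.sample()
    pred = vae.decode(z).sample

    kl_loss = posterior.kl().mean()  # KL against the unit Gaussian prior
    mse_loss = F.mse_loss(pred, target, reduction="mean")
    lpips_loss = lpips_loss_fn(pred, target).mean()

    # kl_weight / lpips_weight are hypothetical knobs for this sketch
    loss = mse_loss + args.kl_weight * kl_loss + args.lpips_weight * lpips_loss
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()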

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@aandyw
Contributor Author

aandyw commented Jun 24, 2023

[06/24/2023] VAE fine-tuning runs successfully but will need to test/evaluate image results.

@aandyw aandyw marked this pull request as ready for review June 24, 2023 19:13
--dataset_name="<DATASET_NAME>" \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing


Suggested change
--gradient_checkpointing
--gradient_checkpointing \

@aandyw aandyw changed the title [WIP] VAE training sample script VAE training sample script Jul 27, 2023
Comment on lines +390 to +395
with accelerator.main_process_first():
    # Split into train/test
    dataset = dataset["train"].train_test_split(test_size=args.test_samples)
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess)
    test_dataset = dataset["test"].with_transform(preprocess)

@zhuliyi0 zhuliyi0 Jul 31, 2023


Support loading the test set from a dedicated test_data_dir folder.

Suggested change
with accelerator.main_process_first():
    # Split into train/test
    dataset = dataset["train"].train_test_split(test_size=args.test_samples)
    # Set the training transforms
    train_dataset = dataset["train"].with_transform(preprocess)
    test_dataset = dataset["test"].with_transform(preprocess)
with accelerator.main_process_first():
    if args.test_data_dir is not None and args.train_data_dir is not None:
        # Load test data from test_data_dir
        logger.info(f"load test data from {args.test_data_dir}")
        test_dir = os.path.join(args.test_data_dir, "**")
        test_dataset = load_dataset(
            "imagefolder",
            data_files=test_dir,
            cache_dir=args.cache_dir,
        )
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess)
        test_dataset = test_dataset["train"].with_transform(preprocess)
    else:
        # Split into train/test
        dataset = dataset["train"].train_test_split(test_size=args.test_samples)
        # Set the training transforms
        train_dataset = dataset["train"].with_transform(preprocess)
        test_dataset = dataset["test"].with_transform(preprocess)
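
Note the design choice here: when a dedicated test_data_dir is supplied alongside train_data_dir, the evaluation set stays fixed across runs instead of being re-sampled by train_test_split, which keeps validation results comparable between experiments.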

    type=int,
    default=4,
    help="Number of images to remove from training set to be used as validation.",
)


Add a new argument, test_data_dir, for a dedicated test data folder.

Suggested change
)
)
parser.add_argument(
    "--test_data_dir",
    type=str,
    default=None,
    help=(
        "If not None, will override test_samples arg and use data inside this dir as test dataset."
    ),
)
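
For context, with this suggestion applied the dedicated test set would be selected on the command line; the script name and paths below are placeholders:

accelerate launch train_vae.py \
  --train_data_dir="<TRAIN_DIR>" \
  --test_data_dir="<TEST_DIR>" \
  --train_batch_size=1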

if tracker.name == "tensorboard":
    np_images = np.stack([np.asarray(img) for img in images])
    tracker.writer.add_images(
        "Original (left) / Reconstruction (right)", np_images, epoch


Change the file name to be compatible with Windows.

Suggested change
"Original (left) / Reconstruction (right)", np_images, epoch
"Original (left)-Reconstruction (right)", np_images, epoch

progress_bar.set_description("Steps")

lpips_loss_fn = lpips.LPIPS(net="alex").to(accelerator.device)



Suggested change
# initial validation as baseline
with torch.no_grad():
    log_validation(test_dataloader, vae, accelerator, weight_dtype, 0)

Run one validation before training starts, as a baseline for comparison.
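
Logging one validation pass at step 0 records the pretrained VAE's reconstructions before any fine-tuning, so later epochs can be compared against that baseline rather than only against each other.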

aandyw and others added 2 commits August 5, 2023 18:39
Co-authored-by: zhuliyi0 <48817897+zhuliyi0@users.noreply.github.com>
pred = vae.decode(z).sample

kl_loss = posterior.kl().mean()
mse_loss = F.mse_loss(pred, target, reduction="mean")


In the original stable-diffusion repo and the SDXL repo, the VAE loss is averaged over the batch dim only, which means the reconstruction term is summed over the channel, height, and width dims. Is this the correct way to average the reconstruction loss?
https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/losses/contperceptual.py#L58
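
For reference, the two reductions differ by a constant factor of C * H * W for a fixed image size, but that factor changes how strongly the reconstruction term weighs against the KL term. A minimal sketch of the difference, with dummy tensors standing in for pred and target:

import torch
import torch.nn.functional as F

pred = torch.randn(4, 3, 256, 256)   # dummy reconstruction, shape [B, C, H, W]
target = torch.randn_like(pred)      # dummy target images

# Reduction used in this PR: mean over every element (batch, channel, H, W).
mse_mean = F.mse_loss(pred, target, reduction="mean")

# CompVis/SDXL-style reduction: sum over C, H, W, then mean over the batch dim.
mse_sum_chw = ((pred - target) ** 2).sum(dim=[1, 2, 3]).mean()

# They differ by exactly C * H * W, so switching between them effectively
# rescales the reconstruction loss relative to the KL weight.
assert torch.allclose(mse_sum_chw, mse_mean * 3 * 256 * 256, rtol=1e-4)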

@github-actions
Contributor

github-actions bot commented Sep 2, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot added the stale Issues that haven't received updates label Sep 2, 2023
@github-actions github-actions bot closed this Sep 12, 2023

@JunzheJosephZhu JunzheJosephZhu left a comment


lgtm

@lavinal712
Contributor

Is there any progress now?
