|
23 | 23 | from comfy.weight_adapter import adapters |
24 | 24 |
|
25 | 25 |
|
def make_batch_extra_option_dict(d, indicies, full_size=None):
    """Build a per-batch copy of an extra-options dict.

    Tensors whose leading dimension matches *full_size* (or all tensors when
    *full_size* is None) are reduced to the rows named by *indicies*; lists
    and tuples of exactly *full_size* elements are likewise subselected (as
    lists); nested dicts are processed recursively; every other value is
    passed through unchanged.
    """
    # NOTE(review): "indicies" is a misspelling of "indices", kept verbatim
    # because it is part of this function's public signature.
    def _select(value):
        if isinstance(value, dict):
            # Recurse so nested option dicts are batched the same way.
            return make_batch_extra_option_dict(value, indicies, full_size=full_size)
        if isinstance(value, torch.Tensor):
            # Only slice tensors that span the whole dataset; anything of a
            # different leading size is assumed to be shared across items.
            if full_size is None or value.size(0) == full_size:
                return value[indicies]
            return value
        if isinstance(value, (list, tuple)) and len(value) == full_size:
            # Sequences matching the dataset length are subselected; note the
            # result is always a list, even for tuple input.
            return [value[i] for i in indicies]
        return value

    return {key: _select(value) for key, value in d.items()}
| 39 | + |
| 40 | + |
class TrainSampler(comfy.samplers.Sampler):
    """Sampler that runs a LoRA training loop instead of sampling.

    Each call to :meth:`sample` performs ``total_steps`` optimizer steps,
    drawing a random mini-batch from the dataset packed into
    ``latent_image`` on every step. The sampler interface is only borrowed
    for plumbing; the returned "sample" is all zeros.
    """

    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
        # loss_fn: criterion comparing the model's x0 prediction to the target.
        self.loss_fn = loss_fn
        # optimizer: pre-built optimizer over the trainable (LoRA) parameters.
        self.optimizer = optimizer
        # loss_callback: optional callable invoked with each step's scalar loss.
        self.loss_callback = loss_callback
        # batch_size: number of dataset items drawn per training step.
        self.batch_size = batch_size
        # total_steps: number of optimizer steps to run inside sample().
        self.total_steps = total_steps
        # seed: base RNG seed; each step i uses seed + i * 1000 for its noise.
        self.seed = seed
        # training_dtype: autocast dtype used for the forward pass.
        self.training_dtype = training_dtype

    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
        # Repurposed arguments: `sigmas` is a dummy tensor whose length encodes
        # the dataset size, and `latent_image` holds the entire dataset of
        # latents (one per item). `noise`, `callback`, `denoise_mask` and
        # `disable_pbar` are unused here.
        cond = model_wrap.conds["positive"]
        dataset_size = sigmas.size(0)
        torch.cuda.empty_cache()
        for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
            # Deterministic per-step noise: step i gets seed + i * 1000.
            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
            # Sample a random mini-batch without replacement within this step.
            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()

            batch_latent = torch.stack([latent_image[i] for i in indicies])
            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
            # One uniformly-random timestep (as a sigma) per batch element.
            batch_sigmas = [
                model_wrap.inner_model.model_sampling.percent_to_sigma(
                    torch.rand((1,)).item()
                ) for _ in range(min(self.batch_size, dataset_size))
            ]
            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)

            # xt: latents noised to the sampled sigmas (training input).
            xt = model_wrap.inner_model.model_sampling.noise_scaling(
                batch_sigmas,
                batch_noise,
                batch_latent,
                False
            )
            # x0: the clean-latent target, passed through the same scaling at
            # sigma=0 so it lives in the same space as the model's prediction.
            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
                torch.zeros_like(batch_sigmas),
                torch.zeros_like(batch_noise),
                batch_latent,
                False
            )

            # Restrict conditioning to the batch items.
            # NOTE(review): model_wrap.conds["positive"] is overwritten each
            # step and never restored to the full `cond` list after the loop —
            # confirm callers do not reuse model_wrap afterwards.
            model_wrap.conds["positive"] = [
                cond[i] for i in indicies
            ]
            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)

            # Mixed-precision forward pass; x0_pred is the model's x0 estimate.
            with torch.autocast(xt.device.type, dtype=self.training_dtype):
                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
                loss = self.loss_fn(x0_pred, x0)
            loss.backward()
            if self.loss_callback:
                self.loss_callback(loss.item())
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

            self.optimizer.step()
            self.optimizer.zero_grad()
            # Free cached GPU memory between steps to keep peak usage down.
            torch.cuda.empty_cache()
        # The sampler contract requires returning a latent; training has no
        # meaningful sample, so return zeros shaped like the input dataset.
        return torch.zeros_like(latent_image)
59 | 99 |
|
60 | 100 |
|
@@ -584,36 +624,34 @@ def train( |
584 | 624 | loss_map = {"loss": []} |
585 | 625 | def loss_callback(loss): |
586 | 626 | loss_map["loss"].append(loss) |
587 | | - pbar.set_postfix({"loss": f"{loss:.4f}"}) |
588 | 627 | train_sampler = TrainSampler( |
589 | | - criterion, optimizer, loss_callback=loss_callback |
| 628 | + criterion, |
| 629 | + optimizer, |
| 630 | + loss_callback=loss_callback, |
| 631 | + batch_size=batch_size, |
| 632 | + total_steps=steps, |
| 633 | + seed=seed, |
| 634 | + training_dtype=dtype |
590 | 635 | ) |
591 | 636 | guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp) |
592 | 637 | guider.set_conds(positive) # Set conditioning from input |
593 | 638 |
|
594 | | - # yoland: this currently resize to the first image in the dataset |
595 | | - |
596 | 639 | # Training loop |
597 | | - torch.cuda.empty_cache() |
598 | 640 | try: |
599 | | - for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)): |
600 | | - # Generate random sigma |
601 | | - sigmas = [mp.model.model_sampling.percent_to_sigma( |
602 | | - torch.rand((1,)).item() |
603 | | - ) for _ in range(min(batch_size, num_images))] |
604 | | - sigmas = torch.tensor(sigmas) |
605 | | - |
606 | | - noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed) |
607 | | - |
608 | | - indices = torch.randperm(num_images)[:batch_size] |
609 | | - batch_latent = latents[indices].clone() |
610 | | - guider.set_conds([positive[i] for i in indices]) # Set conditioning from input |
611 | | - guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed) |
| 641 | + # Generate dummy sigmas and noise |
| 642 | + sigmas = torch.tensor(range(num_images)) |
| 643 | + noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) |
| 644 | + guider.sample( |
| 645 | + noise.generate_noise({"samples": latents}), |
| 646 | + latents, |
| 647 | + train_sampler, |
| 648 | + sigmas, |
| 649 | + seed=noise.seed |
| 650 | + ) |
612 | 651 | finally: |
613 | 652 | for m in mp.model.modules(): |
614 | 653 | unpatch(m) |
615 | 654 | del train_sampler, optimizer |
616 | | - torch.cuda.empty_cache() |
617 | 655 |
|
618 | 656 | for adapter in all_weight_adapters: |
619 | 657 | adapter.requires_grad_(False) |
|
0 commit comments