Commit 1205afc

Better training loop implementation (Comfy-Org#8820)
1 parent 5612670 commit 1205afc

File tree

1 file changed: comfy_extras/nodes_train.py (+81 additions, -43 deletions)
@@ -23,38 +23,78 @@
 from comfy.weight_adapter import adapters
 
 
+def make_batch_extra_option_dict(d, indicies, full_size=None):
+    new_dict = {}
+    for k, v in d.items():
+        newv = v
+        if isinstance(v, dict):
+            newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
+        elif isinstance(v, torch.Tensor):
+            if full_size is None or v.size(0) == full_size:
+                newv = v[indicies]
+        elif isinstance(v, (list, tuple)) and len(v) == full_size:
+            newv = [v[i] for i in indicies]
+        new_dict[k] = newv
+    return new_dict
+
+
 class TrainSampler(comfy.samplers.Sampler):
 
-    def __init__(self, loss_fn, optimizer, loss_callback=None):
+    def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, total_steps=1, seed=0, training_dtype=torch.bfloat16):
         self.loss_fn = loss_fn
         self.optimizer = optimizer
         self.loss_callback = loss_callback
+        self.batch_size = batch_size
+        self.total_steps = total_steps
+        self.seed = seed
+        self.training_dtype = training_dtype
 
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
-        self.optimizer.zero_grad()
-        noise = model_wrap.inner_model.model_sampling.noise_scaling(sigmas, noise, latent_image, False)
-        latent = model_wrap.inner_model.model_sampling.noise_scaling(
-            torch.zeros_like(sigmas),
-            torch.zeros_like(noise, requires_grad=True),
-            latent_image,
-            False
-        )
+        cond = model_wrap.conds["positive"]
+        dataset_size = sigmas.size(0)
+        torch.cuda.empty_cache()
+        for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
+            noisegen = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(self.seed + i * 1000)
+            indicies = torch.randperm(dataset_size)[:self.batch_size].tolist()
+
+            batch_latent = torch.stack([latent_image[i] for i in indicies])
+            batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device)
+            batch_sigmas = [
+                model_wrap.inner_model.model_sampling.percent_to_sigma(
+                    torch.rand((1,)).item()
+                ) for _ in range(min(self.batch_size, dataset_size))
+            ]
+            batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device)
+
+            xt = model_wrap.inner_model.model_sampling.noise_scaling(
+                batch_sigmas,
+                batch_noise,
+                batch_latent,
+                False
+            )
+            x0 = model_wrap.inner_model.model_sampling.noise_scaling(
+                torch.zeros_like(batch_sigmas),
+                torch.zeros_like(batch_noise),
+                batch_latent,
+                False
+            )
 
-        # Ensure model is in training mode and computing gradients
-        # x0 pred
-        denoised = model_wrap(noise, sigmas, **extra_args)
-        try:
-            loss = self.loss_fn(denoised, latent.clone())
-        except RuntimeError as e:
-            if "does not require grad and does not have a grad_fn" in str(e):
-                logging.info("WARNING: This is likely due to the model is loaded in inference mode.")
-        loss.backward()
-        if self.loss_callback:
-            self.loss_callback(loss.item())
-
-        self.optimizer.step()
-        # torch.cuda.memory._dump_snapshot("trainn.pickle")
-        # torch.cuda.memory._record_memory_history(enabled=None)
+            model_wrap.conds["positive"] = [
+                cond[i] for i in indicies
+            ]
+            batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size)
+
+            with torch.autocast(xt.device.type, dtype=self.training_dtype):
+                x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args)
+                loss = self.loss_fn(x0_pred, x0)
+            loss.backward()
+            if self.loss_callback:
+                self.loss_callback(loss.item())
+            pbar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+            self.optimizer.step()
+            self.optimizer.zero_grad()
+            torch.cuda.empty_cache()
         return torch.zeros_like(latent_image)
 
 
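The `make_batch_extra_option_dict` helper added at the top of this hunk recursively slices every batch-aligned entry of `extra_args` down to the sampled mini-batch: tensors whose leading dimension matches the dataset size are indexed, lists and tuples of that length are re-gathered, nested dicts are recursed into, and everything else passes through unchanged. A minimal standalone sketch of that behavior, assuming only `torch`; the function body is copied from the diff above, while the sample dict and its keys are made up for illustration:

    import torch

    def make_batch_extra_option_dict(d, indicies, full_size=None):
        # Recursively pick out the rows belonging to the sampled mini-batch.
        new_dict = {}
        for k, v in d.items():
            newv = v
            if isinstance(v, dict):
                newv = make_batch_extra_option_dict(v, indicies, full_size=full_size)
            elif isinstance(v, torch.Tensor):
                if full_size is None or v.size(0) == full_size:
                    newv = v[indicies]
            elif isinstance(v, (list, tuple)) and len(v) == full_size:
                newv = [v[i] for i in indicies]
            new_dict[k] = newv
        return new_dict

    # Hypothetical extra_args for a dataset of 4 items:
    extra_args = {
        "cond_scale": 1.0,                     # scalar: passed through unchanged
        "masks": torch.zeros(4, 8, 8),         # leading dim == dataset size: sliced
        "model_options": {"flags": [True, False, True, False]},  # nested list: sliced
    }
    batch = make_batch_extra_option_dict(extra_args, [2, 0], full_size=4)
    print(batch["masks"].shape)                # torch.Size([2, 8, 8])
    print(batch["model_options"]["flags"])     # [True, True]

Because anything that is not batch-aligned survives untouched, the sampler can apply the helper blindly to the whole `extra_args` tree each step.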

@@ -584,36 +624,34 @@ def train(
         loss_map = {"loss": []}
         def loss_callback(loss):
             loss_map["loss"].append(loss)
-            pbar.set_postfix({"loss": f"{loss:.4f}"})
         train_sampler = TrainSampler(
-            criterion, optimizer, loss_callback=loss_callback
+            criterion,
+            optimizer,
+            loss_callback=loss_callback,
+            batch_size=batch_size,
+            total_steps=steps,
+            seed=seed,
+            training_dtype=dtype
         )
         guider = comfy_extras.nodes_custom_sampler.Guider_Basic(mp)
         guider.set_conds(positive) # Set conditioning from input
 
-        # yoland: this currently resize to the first image in the dataset
-
         # Training loop
-        torch.cuda.empty_cache()
         try:
-            for step in (pbar:=tqdm.trange(steps, desc="Training LoRA", smoothing=0.01, disable=not comfy.utils.PROGRESS_BAR_ENABLED)):
-                # Generate random sigma
-                sigmas = [mp.model.model_sampling.percent_to_sigma(
-                    torch.rand((1,)).item()
-                ) for _ in range(min(batch_size, num_images))]
-                sigmas = torch.tensor(sigmas)
-
-                noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(step * 1000 + seed)
-
-                indices = torch.randperm(num_images)[:batch_size]
-                batch_latent = latents[indices].clone()
-                guider.set_conds([positive[i] for i in indices]) # Set conditioning from input
-                guider.sample(noise.generate_noise({"samples": batch_latent}), batch_latent, train_sampler, sigmas, seed=noise.seed)
+            # Generate dummy sigmas and noise
+            sigmas = torch.tensor(range(num_images))
+            noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
+            guider.sample(
+                noise.generate_noise({"samples": latents}),
+                latents,
+                train_sampler,
+                sigmas,
+                seed=noise.seed
+            )
         finally:
             for m in mp.model.modules():
                 unpatch(m)
             del train_sampler, optimizer
-            torch.cuda.empty_cache()
 
         for adapter in all_weight_adapters:
             adapter.requires_grad_(False)
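Taken together with the `TrainSampler` hunk above, this is an inversion of control: `train` no longer loops over steps itself but hands the full latent stack, with dummy per-image sigmas that only convey the dataset size (`dataset_size = sigmas.size(0)`), to a single `guider.sample` call, and the per-step loop (sigma/noise sampling, batching, autocast forward, backward, optimizer step) now lives inside `TrainSampler.sample`. A toy sketch of the same structure, with every name hypothetical and no ComfyUI dependencies:

    import torch

    class LoopingSampler:
        """Toy stand-in for TrainSampler: the whole training loop lives in sample()."""

        def __init__(self, loss_fn, optimizer, total_steps=1, batch_size=1,
                     training_dtype=torch.bfloat16):
            self.loss_fn = loss_fn
            self.optimizer = optimizer
            self.total_steps = total_steps
            self.batch_size = batch_size
            self.training_dtype = training_dtype

        def sample(self, model, latents):
            # One call drives every optimization step, mirroring the single
            # guider.sample(...) call in the new train() body.
            for _ in range(self.total_steps):
                idx = torch.randperm(latents.size(0))[:self.batch_size]
                batch = latents[idx]
                # Forward and loss under autocast, backward outside, as in the diff.
                with torch.autocast(batch.device.type, dtype=self.training_dtype):
                    loss = self.loss_fn(model(batch), batch)
                loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()
            return torch.zeros_like(latents)

    model = torch.nn.Linear(8, 8)
    latents = torch.randn(16, 8)
    sampler = LoopingSampler(
        torch.nn.functional.mse_loss,
        torch.optim.SGD(model.parameters(), lr=1e-2),
        total_steps=10,
        batch_size=4,
    )
    sampler.sample(model, latents)  # single entry point; the loop is inside

One practical consequence visible in the diff: per-step state such as the progress bar and the loss postfix moves into the sampler, so the outer `loss_callback` only has to record losses.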
