Skip to content

Commit 2a4328d

Browse files
authored
ace15: Use dynamic_vram friendly trange (Comfy-Org#12409)
Factor out the ksampler trange and use it in ACE LLM to prevent the silent stall at 0 and rate distortion due to first-step model load.
1 parent d297a74 commit 2a4328d

File tree

3 files changed

+30
-32
lines changed

3 files changed

+30
-32
lines changed

comfy/k_diffusion/sampling.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import math
2-
import time
32
from functools import partial
43

54
from scipy import integrate
65
import torch
76
from torch import nn
87
import torchsde
9-
from tqdm.auto import trange as trange_, tqdm
8+
from tqdm.auto import tqdm
109

1110
from . import utils
1211
from . import deis
@@ -15,34 +14,7 @@
1514
import comfy.model_sampling
1615

1716
import comfy.memory_management
18-
19-
20-
def trange(*args, **kwargs):
21-
if comfy.memory_management.aimdo_allocator is None:
22-
return trange_(*args, **kwargs)
23-
24-
pbar = trange_(*args, **kwargs, smoothing=1.0)
25-
pbar._i = 0
26-
pbar.set_postfix_str(" Model Initializing ... ")
27-
28-
_update = pbar.update
29-
30-
def warmup_update(n=1):
31-
pbar._i += 1
32-
if pbar._i == 1:
33-
pbar.i1_time = time.time()
34-
pbar.set_postfix_str(" Model Initialization complete! ")
35-
elif pbar._i == 2:
36-
#bring forward the effective start time based on the diff between first and second iteration
37-
#to attempt to remove load overhead from the final step rate estimate.
38-
pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
39-
pbar.set_postfix_str("")
40-
41-
_update(n)
42-
43-
pbar.update = warmup_update
44-
return pbar
45-
17+
from comfy.utils import model_trange as trange
4618

4719
def append_zero(x):
4820
return torch.cat([x, x.new_zeros([1])])

comfy/text_encoders/ace15.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from comfy import sd1_clip
44
import torch
55
import math
6-
from tqdm.auto import trange
76
import yaml
87
import comfy.utils
98

@@ -52,7 +51,7 @@ def sample_manual_loop_no_classes(
5251

5352
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
5453

55-
for step in trange(max_new_tokens, desc="LM sampling"):
54+
for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
5655
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
5756
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
5857
past_key_values = outputs[2]

comfy/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import logging
2828
import itertools
2929
from torch.nn.functional import interpolate
30+
from tqdm.auto import trange
3031
from einops import rearrange
3132
from comfy.cli_args import args, enables_dynamic_vram
3233
import json
@@ -1155,6 +1156,32 @@ def mult_list_upscale(a):
11551156
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
11561157
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
11571158

1159+
def model_trange(*args, **kwargs):
    """Return a tqdm ``trange`` bar that tolerates a slow first step.

    When ``comfy.memory_management.aimdo_allocator`` is ``None`` this is a
    plain ``trange(*args, **kwargs)``. Otherwise the returned bar has its
    ``update`` method wrapped so that:

    * until the first update, the postfix reads " Model Initializing ... ";
    * on the first update, the current time is stored on the bar and the
      postfix becomes " Model Initialization complete! ";
    * on the second update, ``start_t`` is rewritten so that the elapsed
      time equals twice the gap between the first and second updates, and
      the postfix is cleared.

    Args:
        *args: Positional arguments forwarded to ``trange``.
        **kwargs: Keyword arguments forwarded to ``trange``. On the wrapped
            path ``smoothing`` defaults to 1.0; an explicit value supplied
            by the caller is kept.

    Returns:
        The progress bar object produced by ``trange``.
    """
    if comfy.memory_management.aimdo_allocator is None:
        return trange(*args, **kwargs)

    # Writing ``trange(*args, **kwargs, smoothing=1.0)`` raises TypeError
    # (duplicate keyword) when the caller also passes ``smoothing``; default
    # it instead so callers can still forward their own value.
    kwargs.setdefault("smoothing", 1.0)
    pbar = trange(*args, **kwargs)
    pbar._i = 0  # number of update() calls seen so far
    pbar.set_postfix_str(" Model Initializing ... ")

    _update = pbar.update

    def warmup_update(n=1):
        pbar._i += 1
        if pbar._i == 1:
            pbar.i1_time = time.time()
            pbar.set_postfix_str(" Model Initialization complete! ")
        elif pbar._i == 2:
            # Bring forward the effective start time based on the diff between
            # the first and second iteration, to attempt to remove load
            # overhead from the final step rate estimate.
            pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
            pbar.set_postfix_str("")

        # Forward the wrapped method's result so the replacement keeps the
        # same return contract as the bar's own update().
        return _update(n)

    pbar.update = warmup_update
    return pbar
1184+
11581185
PROGRESS_BAR_ENABLED = True
11591186
def set_progress_bar_enabled(enabled):
11601187
global PROGRESS_BAR_ENABLED

0 commit comments

Comments
 (0)