Skip to content

Commit d297a74

Browse files
authored
dynamic_vram: Fix windows Aimdo crash + Fix LLM performance (Comfy-Org#12408)
* model_management: lazy-cache aimdo_tensor. These tensors constructed from aimdo allocations are CPU-expensive to create on the PyTorch side. Add a cached version that is reused when the signature matches, fast-pathing past whatever torch is doing. * dynamic_vram: Minimize fast-path CPU work. Move as much as possible inside the not-resident if block and cache the formed weight and bias rather than the flat intermediates. At extreme layer-weight rates this adds up.
1 parent 2b7cc7e commit d297a74

File tree

3 files changed

+20
-11
lines changed

3 files changed

+20
-11
lines changed

comfy/model_management.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
12131213

12141214
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
12151215
if signature is not None:
1216-
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
1217-
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
1216+
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
1217+
v_tensor = weight._v_tensor
1218+
else:
1219+
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
1220+
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
1221+
weight._v_tensor = v_tensor
12181222
weight._v_signature = signature
12191223
#Send it over
12201224
v_tensor.copy_(weight, non_blocking=non_blocking)

comfy/model_patcher.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,7 +1542,6 @@ def setup_param(self, m, n, param_key):
15421542

15431543
if vbar is not None and not hasattr(m, "_v"):
15441544
m._v = vbar.alloc(v_weight_size)
1545-
m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
15461545
allocated_size += v_weight_size
15471546

15481547
else:
@@ -1557,7 +1556,6 @@ def setup_param(self, m, n, param_key):
15571556
weight_size = geometry.numel() * geometry.element_size()
15581557
if vbar is not None and not hasattr(weight, "_v"):
15591558
weight._v = vbar.alloc(weight_size)
1560-
weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
15611559
weight._model_dtype = model_dtype
15621560
allocated_size += weight_size
15631561
vbar.set_watermark_limit(allocated_size)

comfy/ops.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
8383
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
8484
offload_stream = None
8585
xfer_dest = None
86-
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
8786

8887
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
89-
if signature is not None:
90-
xfer_dest = s._v_tensor
9188
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
89+
if signature is not None:
90+
if resident:
91+
weight = s._v_weight
92+
bias = s._v_bias
93+
else:
94+
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
9295

9396
if not resident:
97+
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
9498
cast_dest = None
9599

96100
xfer_source = [ s.weight, s.bias ]
@@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
140144
post_cast.copy_(pre_cast)
141145
xfer_dest = cast_dest
142146

143-
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
144-
weight = params[0]
145-
bias = params[1]
147+
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
148+
weight = params[0]
149+
bias = params[1]
150+
if signature is not None:
151+
s._v_weight = weight
152+
s._v_bias = bias
153+
s._v_signature=signature
146154

147155
def post_cast(s, param_key, x, dtype, resident, update_weight):
148156
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -182,7 +190,6 @@ def to_dequant(tensor, dtype):
182190
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
183191
if s.bias is not None:
184192
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
185-
s._v_signature=signature
186193

187194
#FIXME: weird offload return protocol
188195
return weight, bias, (offload_stream, device if signature is not None else None, None)

0 commit comments

Comments
 (0)