Skip to content

Commit 0bfb936

Browse files
authored
comfy-aimdo 0.2 - Improved pytorch allocator integration (Comfy-Org#12557)
Integrate comfy-aimdo 0.2, which takes a different approach to installing the memory allocator hook. Instead of using the complicated and buggy pytorch MemPool+CudaPluggableAllocator, cuda is directly hooked, making the process much more transparent to both comfy and pytorch. As far as pytorch knows, aimdo doesn't exist anymore, and it just operates behind the scenes. Remove all the mempool setup stuff for dynamic_vram and bump the comfy-aimdo version. Remove the allocator object from memory_management and demote its use as an enablement check to a boolean flag. Comfy-aimdo 0.2 also supports the pytorch cuda async allocator, so remove the dynamic_vram-based forced disablement of cuda_malloc and just go back to the old settings of allocators based on command line input.
1 parent 602b250 commit 0bfb936

File tree

7 files changed

+18
-32
lines changed

7 files changed

+18
-32
lines changed

comfy/memory_management.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,4 @@ def interpret_gathered_like(tensors, gathered):
7878

7979
return dest_views
8080

81-
aimdo_allocator = None
81+
aimdo_enabled = False

comfy/model_management.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ def unet_inital_load_device(parameters, dtype):
836836

837837
mem_dev = get_free_memory(torch_dev)
838838
mem_cpu = get_free_memory(cpu_dev)
839-
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_allocator is None:
839+
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
840840
return torch_dev
841841
else:
842842
return cpu_dev
@@ -1121,7 +1121,6 @@ def get_cast_buffer(offload_stream, device, size, ref):
11211121
synchronize()
11221122
del STREAM_CAST_BUFFERS[offload_stream]
11231123
del cast_buffer
1124-
#FIXME: This doesn't work in Aimdo because mempool cant clear cache
11251124
soft_empty_cache()
11261125
with wf_context:
11271126
cast_buffer = torch.empty((size), dtype=torch.int8, device=device)

comfy/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1154,7 +1154,7 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am
11541154
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
11551155

11561156
def model_trange(*args, **kwargs):
1157-
if comfy.memory_management.aimdo_allocator is None:
1157+
if not comfy.memory_management.aimdo_enabled:
11581158
return trange(*args, **kwargs)
11591159

11601160
pbar = trange(*args, **kwargs, smoothing=1.0)

cuda_malloc.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import os
22
import importlib.util
3-
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
3+
from comfy.cli_args import args, PerformanceFeature
44
import subprocess
55

6-
import comfy_aimdo.control
7-
86
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
97
def get_gpu_names():
108
if os.name == 'nt':
@@ -87,10 +85,6 @@ def cuda_malloc_supported():
8785
except:
8886
pass
8987

90-
if enables_dynamic_vram() and comfy_aimdo.control.init():
91-
args.cuda_malloc = False
92-
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = ""
93-
9488
if args.disable_cuda_malloc:
9589
args.cuda_malloc = False
9690

execution.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from enum import Enum
1010
from typing import List, Literal, NamedTuple, Optional, Union
1111
import asyncio
12-
from contextlib import nullcontext
1312

1413
import torch
1514

@@ -521,19 +520,14 @@ def pre_execute_cb(call_index):
521520
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
522521
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
523522

524-
#Do comfy_aimdo mempool chunking here on the per-node level. Multi-model workflows
525-
#will cause all sorts of incompatible memory shapes to fragment the pytorch alloc
526-
#that we just want to cull out each model run.
527-
allocator = comfy.memory_management.aimdo_allocator
528-
with nullcontext() if allocator is None else torch.cuda.use_mem_pool(torch.cuda.MemPool(allocator.allocator())):
529-
try:
530-
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
531-
finally:
532-
if allocator is not None:
533-
if args.verbose == "DEBUG":
534-
comfy_aimdo.model_vbar.vbars_analyze()
535-
comfy.model_management.reset_cast_buffers()
536-
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
523+
try:
524+
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
525+
finally:
526+
if comfy.memory_management.aimdo_enabled:
527+
if args.verbose == "DEBUG":
528+
comfy_aimdo.control.analyze()
529+
comfy.model_management.reset_cast_buffers()
530+
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
537531

538532
if has_pending_tasks:
539533
pending_async_nodes[unique_id] = output_data

main.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ def execute_script(script_path):
173173
if 'torch' in sys.modules:
174174
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
175175

176+
import comfy_aimdo.control
177+
178+
if enables_dynamic_vram():
179+
comfy_aimdo.control.init()
176180

177181
import comfy.utils
178182

@@ -188,13 +192,9 @@ def execute_script(script_path):
188192
import comfy.memory_management
189193
import comfy.model_patcher
190194

191-
import comfy_aimdo.control
192-
import comfy_aimdo.torch
193-
194195
if enables_dynamic_vram():
195196
if comfy.model_management.torch_version_numeric < (2, 8):
196197
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
197-
comfy.memory_management.aimdo_allocator = None
198198
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
199199
if args.verbose == 'DEBUG':
200200
comfy_aimdo.control.set_log_debug()
@@ -208,11 +208,10 @@ def execute_script(script_path):
208208
comfy_aimdo.control.set_log_info()
209209

210210
comfy.model_patcher.CoreModelPatcher = comfy.model_patcher.ModelPatcherDynamic
211-
comfy.memory_management.aimdo_allocator = comfy_aimdo.torch.get_torch_allocator()
211+
comfy.memory_management.aimdo_enabled = True
212212
logging.info("DynamicVRAM support detected and enabled")
213213
else:
214214
logging.warning("No working comfy-aimdo install detected. DynamicVRAM support disabled. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
215-
comfy.memory_management.aimdo_allocator = None
216215

217216

218217
def cuda_malloc_warning():

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ alembic
2222
SQLAlchemy
2323
av>=14.2.0
2424
comfy-kitchen>=0.2.7
25-
comfy-aimdo>=0.1.8
25+
comfy-aimdo>=0.2.0
2626
requests
2727

2828
#non essential dependencies:

0 commit comments

Comments
 (0)