3 changes: 3 additions & 0 deletions comfy/extra_config.json
@@ -0,0 +1,3 @@
+{
+    "unload_text_encoder_after_run": true
+}
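
For context, a minimal sketch of how a flag like this could be consumed; the loader helper and the `False` default below are assumptions for illustration, not code from this PR:

```python
import json

def load_extra_config(path="comfy/extra_config.json"):
    # Hypothetical loader: returns {} when the file is missing or invalid.
    try:
        with open(path) as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return {}

# Defaulting to False keeps existing installs on their current behavior.
UNLOAD_TEXT_ENCODER_AFTER_RUN = load_extra_config().get(
    "unload_text_encoder_after_run", False)
```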
33 changes: 28 additions & 5 deletions comfy/model_management.py
@@ -15,7 +15,8 @@
     You should have received a copy of the GNU General Public License
     along with this program. If not, see <https://www.gnu.org/licenses/>.
 """
+import gc
 import torch
 import psutil
 import logging
 from enum import Enum
@@ -1085,10 +1086,10 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
     if dtype is None or weight.dtype == dtype:
         return weight
     if stream is not None:
-        wf_context = stream
-        if hasattr(wf_context, "as_context"):
-            wf_context = wf_context.as_context(stream)
-        with wf_context:
+        if not hasattr(stream, "__enter__"):
+            logging.error(f"Stream object {stream} of type {type(stream)} does not have __enter__ method; casting without a stream context")
+            return weight.to(dtype=dtype, copy=copy)
+        with stream:
             return weight.to(dtype=dtype, copy=copy)
     return weight.to(dtype=dtype, copy=copy)

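Whether `with stream:` works depends on the PyTorch build, since `torch.cuda.Stream` has not always implemented the context manager protocol; the long-standing API is the `torch.cuda.stream()` wrapper. A hedged sketch of a fallback helper (not part of this PR, name invented here):

```python
import contextlib
import torch

def stream_context(stream):
    # Hypothetical helper: pick a usable context manager for `stream`.
    if stream is None:
        return contextlib.nullcontext()
    if hasattr(stream, "__enter__"):
        return stream  # builds where Stream is itself a context manager
    return torch.cuda.stream(stream)  # portable wrapper API

# Usage inside cast_to would then be:
#     with stream_context(stream):
#         return weight.to(dtype=dtype, copy=copy)
```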
@@ -1552,3 +1552,25 @@ def throw_exception_if_processing_interrupted():
     if interrupt_processing:
         interrupt_processing = False
         raise InterruptProcessingException()
+
+def cleanup_ram():
+    # Free Python-level garbage first, then release cached CUDA blocks.
+    gc.collect()
+    try:
+        torch.cuda.empty_cache()
+    except Exception:
+        pass
+
+def unload_text_encoder(encoder):
+    # Note: `del encoder` only clears this function's local reference;
+    # callers must also drop their own references for the weights to be freed.
+    if encoder is None:
+        return
+    try:
+        if hasattr(encoder, "model"):
+            del encoder.model
+        del encoder
+    except Exception:
+        pass
+    cleanup_ram()

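Caller-side usage would look roughly like the following; the helper and attribute names are illustrative, since `del encoder` inside `unload_text_encoder` cannot clear a reference the caller still holds:

```python
import comfy.model_management as mm

def release_text_encoder(owner, attr="clip"):
    # Illustrative helper: `owner` is whatever object holds the encoder
    # (e.g. a loaded checkpoint wrapper). Clearing the attribute removes
    # the last strong reference so cleanup_ram() can reclaim the memory.
    encoder = getattr(owner, attr, None)
    mm.unload_text_encoder(encoder)
    if encoder is not None:
        setattr(owner, attr, None)
```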
27 changes: 27 additions & 0 deletions execution.py
@@ -736,6 +736,33 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
             "outputs": ui_outputs,
             "meta": meta_outputs,
         }

+        # --- BEGIN: Text Encoder RAM Cleanup Patch ---
+        try:
+            import comfy.model_management as mm
+
+            # If this build exposes a registry of loaded text encoders,
+            # unload each one (guarded, since not all builds define it).
+            if hasattr(mm, "loaded_text_encoders"):
+                for enc in list(mm.loaded_text_encoders.values()):
+                    try:
+                        mm.unload_text_encoder(enc)
+                    except Exception:
+                        pass
+
+                mm.loaded_text_encoders.clear()
+
+            # Final RAM + VRAM cleanup
+            try:
+                mm.cleanup_models_gc()
+            except Exception:
+                pass
+
+            logging.info("[RAM Optimizer] Text encoders unloaded after run.")
+        except Exception as e:
+            logging.warning(f"[RAM Optimizer] Failed to unload text encoders: {e}")
+        # --- END: Text Encoder RAM Cleanup Patch ---
+
         self.server.last_node_id = None
         if comfy.model_management.DISABLE_SMART_MEMORY:
             comfy.model_management.unload_all_models()
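
Stock ComfyUI does not appear to define `loaded_text_encoders`, which is why the `hasattr` guard exists; a build adopting this patch would have to populate the registry wherever text encoders are instantiated. A minimal sketch (registry and function names are assumptions):

```python
# In comfy/model_management.py: hypothetical registry keyed by object id,
# called from whatever code path loads a CLIP/text encoder.
loaded_text_encoders = {}

def register_text_encoder(encoder):
    loaded_text_encoders[id(encoder)] = encoder
    return encoder
```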
4 changes: 2 additions & 2 deletions nodes.py
@@ -2247,7 +2247,7 @@ async def init_external_custom_nodes():
                 logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
                 continue
 
-            if args.enable_manager:
+            if getattr(args, "enable_manager", False):
                 if comfyui_manager.should_be_disabled(module_path):
                     logging.info(f"Blocked by policy: {module_path}")
                     continue
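The `getattr` guard means startup no longer raises `AttributeError` on builds whose argument parser never defines `enable_manager`; the behavior in isolation:

```python
import argparse

args = argparse.Namespace()  # simulates a build without --enable-manager
assert getattr(args, "enable_manager", False) is False  # no AttributeError
```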
@@ -2449,4 +2449,4 @@ async def init_extra_nodes(init_custom_nodes=True, init_api_nodes=True):
         logging.warning("Please do a: pip install -r requirements.txt")
         logging.warning("")
 
-    return import_failed
\ No newline at end of file
+    return import_failed
27 changes: 27 additions & 0 deletions reproduce_stream_error.py
@@ -0,0 +1,27 @@
+
+import torch
+import logging
+
+logging.basicConfig(level=logging.INFO)
+
+def test_stream():
+    if not torch.cuda.is_available():
+        print("CUDA not available, cannot test cuda stream")
+        return
+
+    device = torch.device("cuda")
+    stream = torch.cuda.Stream(device=device, priority=0)
+
+    print(f"Stream type: {type(stream)}")
+    print(f"Has __enter__: {hasattr(stream, '__enter__')}")
+
+    try:
+        with stream:
+            print("Stream context manager works")
+    except AttributeError as e:
+        print(f"AttributeError caught: {e}")
+    except Exception as e:
+        print(f"Other exception caught: {e}")
+
+if __name__ == "__main__":
+    test_stream()
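
For comparison, the portable `torch.cuda.stream()` wrapper can be probed the same way; a sketch that could be appended to the script, reusing its imports:

```python
def test_stream_wrapper():
    # torch.cuda.stream() has been a context manager across many PyTorch
    # releases, so this path should work even where `with stream:` fails.
    if not torch.cuda.is_available():
        print("CUDA not available, cannot test cuda stream wrapper")
        return
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        print("torch.cuda.stream() wrapper works")
```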