Skip to content

Commit 3ebe1ac

Browse files
authored
Disable dynamic_vram when using torch compiler (Comfy-Org#12612)
* mp: attach re-construction arguments to model patcher. When making a model patcher from a unet or ckpt, attach a callable function that can be called to replay the model construction. This can be used to deep-clone a model patcher with respect to the actual model. Originally written by Kosinkadink (Comfy-Org@f4b99bc). * mp: Add disable_dynamic clone argument. Add a clone argument that lets a caller clone a ModelPatcher but disable dynamic_vram, demoting the clone to a regular ModelPatcher. This is useful for legacy features where dynamic_vram support is missing or TBD. * torch_compile: disable dynamic_vram. This is a bigger feature; disable it for the interim to preserve existing functionality.
1 parent befa83d commit 3ebe1ac

File tree

3 files changed

+34
-11
lines changed

3 files changed

+34
-11
lines changed

comfy/model_patcher.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
271271
self.is_clip = False
272272
self.hook_mode = comfy.hooks.EnumHookMode.MaxSpeed
273273

274+
self.cached_patcher_init: tuple[Callable, tuple] | None = None
274275
if not hasattr(self.model, 'model_loaded_weight_memory'):
275276
self.model.model_loaded_weight_memory = 0
276277

@@ -307,8 +308,15 @@ def lowvram_patch_counter(self):
307308
def get_free_memory(self, device):
308309
return comfy.model_management.get_free_memory(device)
309310

310-
def clone(self):
311-
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
311+
def clone(self, disable_dynamic=False):
312+
class_ = self.__class__
313+
model = self.model
314+
if self.is_dynamic() and disable_dynamic:
315+
class_ = ModelPatcher
316+
temp_model_patcher = self.cached_patcher_init[0](*self.cached_patcher_init[1], disable_dynamic=True)
317+
model = temp_model_patcher.model
318+
319+
n = class_(model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
312320
n.patches = {}
313321
for k in self.patches:
314322
n.patches[k] = self.patches[k][:]
@@ -362,6 +370,8 @@ def clone(self):
362370
n.is_clip = self.is_clip
363371
n.hook_mode = self.hook_mode
364372

373+
n.cached_patcher_init = self.cached_patcher_init
374+
365375
for callback in self.get_all_callbacks(CallbacksMP.ON_CLONE):
366376
callback(self, n)
367377
return n

comfy/sd.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1530,14 +1530,24 @@ class ModelSamplingAdvanced(comfy.model_sampling.ModelSamplingDiscrete, comfy.mo
15301530

15311531
return (model, clip, vae)
15321532

1533-
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
1533+
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
15341534
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
1535-
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
1535+
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
15361536
if out is None:
15371537
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(ckpt_path, model_detection_error_hint(ckpt_path, sd)))
1538+
if output_model:
1539+
out[0].cached_patcher_init = (load_checkpoint_guess_config_model_only, (ckpt_path, embedding_directory, model_options, te_model_options))
15381540
return out
15391541

1540-
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
1542+
def load_checkpoint_guess_config_model_only(ckpt_path, embedding_directory=None, model_options={}, te_model_options={}, disable_dynamic=False):
1543+
model, *_ = load_checkpoint_guess_config(ckpt_path, False, False, False,
1544+
embedding_directory=embedding_directory,
1545+
model_options=model_options,
1546+
te_model_options=te_model_options,
1547+
disable_dynamic=disable_dynamic)
1548+
return model
1549+
1550+
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None, disable_dynamic=False):
15411551
clip = None
15421552
clipvision = None
15431553
vae = None
@@ -1586,7 +1596,8 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
15861596
if output_model:
15871597
inital_load_device = model_management.unet_inital_load_device(parameters, unet_dtype)
15881598
model = model_config.get_model(sd, diffusion_model_prefix, device=inital_load_device)
1589-
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
1599+
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
1600+
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
15901601
model.load_model_weights(sd, diffusion_model_prefix, assign=model_patcher.is_dynamic())
15911602

15921603
if output_vae:
@@ -1637,7 +1648,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
16371648
return (model_patcher, clip, vae, clipvision)
16381649

16391650

1640-
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
1651+
def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable_dynamic=False):
16411652
"""
16421653
Loads a UNet diffusion model from a state dictionary, supporting both diffusers and regular formats.
16431654
@@ -1721,7 +1732,8 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
17211732
model_config.optimizations["fp8"] = True
17221733

17231734
model = model_config.get_model(new_sd, "")
1724-
model_patcher = comfy.model_patcher.CoreModelPatcher(model, load_device=load_device, offload_device=offload_device)
1735+
ModelPatcher = comfy.model_patcher.ModelPatcher if disable_dynamic else comfy.model_patcher.CoreModelPatcher
1736+
model_patcher = ModelPatcher(model, load_device=load_device, offload_device=offload_device)
17251737
if not model_management.is_device_cpu(offload_device):
17261738
model.to(offload_device)
17271739
model.load_model_weights(new_sd, "", assign=model_patcher.is_dynamic())
@@ -1730,12 +1742,13 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
17301742
logging.info("left over keys in diffusion model: {}".format(left_over))
17311743
return model_patcher
17321744

1733-
def load_diffusion_model(unet_path, model_options={}):
1745+
def load_diffusion_model(unet_path, model_options={}, disable_dynamic=False):
17341746
sd, metadata = comfy.utils.load_torch_file(unet_path, return_metadata=True)
1735-
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata)
1747+
model = load_diffusion_model_state_dict(sd, model_options=model_options, metadata=metadata, disable_dynamic=disable_dynamic)
17361748
if model is None:
17371749
logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path))
17381750
raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd)))
1751+
model.cached_patcher_init = (load_diffusion_model, (unet_path, model_options))
17391752
return model
17401753

17411754
def load_unet(unet_path, dtype=None):

comfy_extras/nodes_torch_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def define_schema(cls) -> io.Schema:
2525

2626
@classmethod
2727
def execute(cls, model, backend) -> io.NodeOutput:
28-
m = model.clone()
28+
m = model.clone(disable_dynamic=True)
2929
set_torch_compile_wrapper(model=m, backend=backend, options={"guard_filter_fn": skip_torch_compile_dict})
3030
return io.NodeOutput(m)
3131

0 commit comments

Comments (0)