Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hooks Part 2 - TransformerOptionsHook and AdditionalModelsHook #6377

Merged
merged 21 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
72bbf49
Add 'sigmas' to transformer_options so that downstream code can know …
Kosinkadink Dec 29, 2024
bf21be0
Merge branch 'master' into hooks_part2
Kosinkadink Dec 30, 2024
d44295e
Merge branch 'master' into hooks_part2
Kosinkadink Jan 4, 2025
5a2ad03
Cleaned up hooks.py, refactored Hook.should_register and add_hook_pat…
Kosinkadink Jan 4, 2025
776aa73
Refactor WrapperHook into TransformerOptionsHook, as there is no need…
Kosinkadink Jan 4, 2025
111fd0c
Refactored HookGroup to also store a dictionary of hooks separated by…
Kosinkadink Jan 4, 2025
6620d86
In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_o…
Kosinkadink Jan 5, 2025
db2d7ad
Merge branch 'add_sample_sigmas' into hooks_part2
Kosinkadink Jan 5, 2025
8270ff3
Refactored 'registered' to be HookGroup instead of a list of Hooks, m…
Kosinkadink Jan 6, 2025
4446c86
Made hook clone code sane, made clear ObjectPatchHook and SetInjectio…
Kosinkadink Jan 6, 2025
03a97b6
Fix performance of hooks when hooks are appended via Cond Pair Set Pr…
Kosinkadink Jan 6, 2025
0a7e2ae
Filter only registered hooks on self.conds in CFGGuider.sample
Kosinkadink Jan 6, 2025
6463c39
Merge branch 'master' into hooks_part2
Kosinkadink Jan 6, 2025
f48f90e
Make hook_scope functional for TransformerOptionsHook
Kosinkadink Jan 6, 2025
2724ac4
Merge branch 'master' into hooks_part2
Kosinkadink Jan 6, 2025
1b38f5b
removed 4 whitespace lines to satisfy Ruff,
Kosinkadink Jan 6, 2025
58bf881
Add a get_injections function to ModelPatcher
Kosinkadink Jan 7, 2025
216fea1
Made TransformerOptionsHook contribute to registered hooks properly, …
Kosinkadink Jan 7, 2025
11c6d56
Merge branch 'master' into hooks_part2
Kosinkadink Jan 7, 2025
3cd4c5c
Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHo…
Kosinkadink Jan 7, 2025
7333281
Clean up a typehint
Kosinkadink Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Cleaned up hooks.py, refactored Hook.should_register and add_hook_pat…
…ches to use target_dict instead of target so that more information can be provided about the current execution environment if needed
  • Loading branch information
Kosinkadink committed Jan 4, 2025
commit 5a2ad032cb09afcaf7fadf5cdfa20c2b0498aee5
148 changes: 96 additions & 52 deletions comfy/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,86 @@
import comfy.patcher_extension
from node_helpers import conditioning_set_values

# #######################################################################################################
# Hooks explanation
# -------------------
# The purpose of hooks is to allow conds to influence sampling without the need for ComfyUI core code to
# make explicit special cases like it does for ControlNet and GLIGEN.
#
# This is necessary for nodes/features that are intended for use with masked or scheduled conds, or those
# that should run special code when a 'marked' cond is used in sampling.
# #######################################################################################################

class EnumHookMode(enum.Enum):
'''
Priority of hook memory optimization vs. speed, mostly related to WeightHooks.

MinVram: No caching will occur for any operations related to hooks.
MaxSpeed: Excess VRAM (and RAM, once VRAM is sufficiently depleted) will be used to cache hook weights when switching hook groups.
'''
MinVram = "minvram"
MaxSpeed = "maxspeed"

class EnumHookType(enum.Enum):
'''
Hook types, each of which has different expected behavior.
'''
Weight = "weight"
Patch = "patch"
ObjectPatch = "object_patch"
AddModels = "add_models"
Callbacks = "callbacks"
Wrappers = "wrappers"
SetInjections = "add_injections"
Injections = "add_injections"

class EnumWeightTarget(enum.Enum):
Model = "model"
Clip = "clip"

class EnumHookScope(enum.Enum):
'''
Determines if hook should be limited in its influence over sampling.

AllConditioning: hook will affect all conds used in sampling.
HookedOnly: hook will only affect the conds it was attached to.
'''
AllConditioning = "all_conditioning"
HookedOnly = "hooked_only"


class _HookRef:
pass

# NOTE: this is an example of how the should_register function should look
def default_should_register(hook: 'Hook', model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):

def default_should_register(hook: Hook, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
'''Example for how should_register function should look like.'''
return True


def create_target_dict(target: EnumWeightTarget=None, **kwargs) -> dict[str]:
'''Creates base dictionary for use with Hooks' target param.'''
d = {}
if target is not None:
d['target'] = target
d.update(kwargs)
return d


class Hook:
def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_id: str=None,
hook_keyframe: 'HookKeyframeGroup'=None):
hook_keyframe: HookKeyframeGroup=None, hook_scope=EnumHookScope.AllConditioning):
self.hook_type = hook_type
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
self.hook_scope = hook_scope

@property
def strength(self):
return self.hook_keyframe.strength

def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
self.reset()
self.hook_keyframe.initialize_timesteps(model)

Expand All @@ -75,27 +115,32 @@ def clone(self, subtype: Callable=None):
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
return c

def should_register(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target, registered)
def should_register(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
return self.custom_should_register(self, model, model_options, target_dict, registered)

def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")

def on_apply(self, model: 'ModelPatcher', transformer_options: dict[str]):
def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
pass

def on_unapply(self, model: 'ModelPatcher', transformer_options: dict[str]):
def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]):
pass

def __eq__(self, other: 'Hook'):
def __eq__(self, other: Hook):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref

def __hash__(self):
return hash(self.hook_ref)

class WeightHook(Hook):
'''
Hook responsible for tracking weights to be applied to some model/clip.

Note, value of hook_scope is ignored and is treated as HookedOnly.
'''
def __init__(self, strength_model=1.0, strength_clip=1.0):
super().__init__(hook_type=EnumHookType.Weight)
super().__init__(hook_type=EnumHookType.Weight, hook_scope=EnumHookScope.HookedOnly)
self.weights: dict = None
self.weights_clip: dict = None
self.need_weight_init = True
Expand All @@ -110,27 +155,29 @@ def strength_model(self):
def strength_clip(self):
return self._strength_clip * self.strength

def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
if not self.should_register(model, model_options, target_dict, registered):
return False
weights = None
if target == EnumWeightTarget.Model:
strength = self._strength_model
else:

target = target_dict.get('target', None)
if target == EnumWeightTarget.Clip:
strength = self._strength_clip
else:
strength = self._strength_model

if self.need_weight_init:
key_map = {}
if target == EnumWeightTarget.Model:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
else:
if target == EnumWeightTarget.Clip:
key_map = comfy.lora.model_lora_keys_clip(model.model, key_map)
else:
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
weights = comfy.lora.load_lora(self.weights, key_map, log_missing=False)
else:
if target == EnumWeightTarget.Model:
weights = self.weights
else:
if target == EnumWeightTarget.Clip:
weights = self.weights_clip
else:
weights = self.weights
model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
Expand Down Expand Up @@ -174,7 +221,12 @@ def clone(self, subtype: Callable=None):
# TODO: add functionality

class AddModelsHook(Hook):
def __init__(self, key: str=None, models: list['ModelPatcher']=None):
'''
Hook responsible for telling model management any additional models that should be loaded.

Note, value of hook_scope is ignored and is treated as AllConditioning.
'''
def __init__(self, key: str=None, models: list[ModelPatcher]=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.key = key
self.models = models
Expand All @@ -188,24 +240,15 @@ def clone(self, subtype: Callable=None):
c.models = self.models.copy() if self.models else self.models
c.append_when_same = self.append_when_same
return c
# TODO: add functionality

class CallbackHook(Hook):
def __init__(self, key: str=None, callback: Callable=None):
super().__init__(hook_type=EnumHookType.Callbacks)
self.key = key
self.callback = callback

def clone(self, subtype: Callable=None):
if subtype is None:
subtype = type(self)
c: CallbackHook = super().clone(subtype)
c.key = self.key
c.callback = self.callback
return c
# TODO: add functionality
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
if not self.should_register(model, model_options, target_dict, registered):
return False

class WrapperHook(Hook):
'''
Hook responsible for adding wrappers, callbacks, or anything else onto transformer_options.
'''
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
super().__init__(hook_type=EnumHookType.Wrappers)
self.wrappers_dict = wrappers_dict
Expand All @@ -217,17 +260,18 @@ def clone(self, subtype: Callable=None):
c.wrappers_dict = self.wrappers_dict
return c

def add_hook_patches(self, model: 'ModelPatcher', model_options: dict, target: EnumWeightTarget, registered: list[Hook]):
if not self.should_register(model, model_options, target, registered):
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: list[Hook]):
if not self.should_register(model, model_options, target_dict, registered):
return False
add_model_options = {"transformer_options": self.wrappers_dict}
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
if self.hook_scope == EnumHookScope.AllConditioning:
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.append(self)
return True

class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list['PatcherInjection']=None):
super().__init__(hook_type=EnumHookType.SetInjections)
def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
super().__init__(hook_type=EnumHookType.Injections)
self.key = key
self.injections = injections

Expand All @@ -239,7 +283,7 @@ def clone(self, subtype: Callable=None):
c.injections = self.injections.copy() if self.injections else self.injections
return c

def add_hook_injections(self, model: 'ModelPatcher'):
def add_hook_injections(self, model: ModelPatcher):
# TODO: add functionality
pass

Expand All @@ -260,14 +304,14 @@ def clone(self):
c.add(hook.clone())
return c

def clone_and_combine(self, other: 'HookGroup'):
def clone_and_combine(self, other: HookGroup):
c = self.clone()
if other is not None:
for hook in other.hooks:
c.add(hook.clone())
return c

def set_keyframes_on_hooks(self, hook_kf: 'HookKeyframeGroup'):
def set_keyframes_on_hooks(self, hook_kf: HookKeyframeGroup):
if hook_kf is None:
hook_kf = HookKeyframeGroup()
else:
Expand Down Expand Up @@ -336,7 +380,7 @@ def reset(self):
hook.reset()

@staticmethod
def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
def combine_all_hooks(hooks_list: list[HookGroup], require_count=0) -> HookGroup:
actual: list[HookGroup] = []
for group in hooks_list:
if group is not None:
Expand Down Expand Up @@ -433,7 +477,7 @@ def clone(self):
c._set_first_as_current()
return c

def initialize_timesteps(self, model: 'BaseModel'):
def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

Expand Down Expand Up @@ -548,7 +592,7 @@ def create_hook_model_as_lora(weights_model, weights_clip, strength_model: float
hook.need_weight_init = False
return hook_group

def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=True):
def get_patch_weights_from_model(model: ModelPatcher, discard_model_sampling=True):
if model is None:
return None
patches_model: dict[str, torch.Tensor] = model.model.state_dict()
Expand All @@ -560,7 +604,7 @@ def get_patch_weights_from_model(model: 'ModelPatcher', discard_model_sampling=T
return patches_model

# NOTE: this function shows how to register weight hooks directly on the ModelPatchers
def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[str, torch.Tensor],
def load_hook_lora_for_models(model: ModelPatcher, clip: CLIP, lora: dict[str, torch.Tensor],
strength_model: float, strength_clip: float):
key_map = {}
if model is not None:
Expand Down
8 changes: 4 additions & 4 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,13 +940,13 @@ def prepare_hook_patches_current_keyframe(self, t: torch.Tensor, hook_group: com
if reset_current_hooks:
self.patch_hooks(None)

def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target: comfy.hooks.EnumWeightTarget, model_options: dict=None):
def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, dict[comfy.hooks.Hook, None]], target_dict: dict[str], model_options: dict=None):
self.restore_hook_patches()
registered_hooks: list[comfy.hooks.Hook] = []
# handle WrapperHooks, if model_options provided
if model_options is not None:
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Wrappers, {}):
hook.add_hook_patches(self, model_options, target, registered_hooks)
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
# handle WeightHooks
weight_hooks_to_register: list[comfy.hooks.WeightHook] = []
for hook in hooks_dict.get(comfy.hooks.EnumHookType.Weight, {}):
Expand All @@ -956,9 +956,9 @@ def register_all_hook_patches(self, hooks_dict: dict[comfy.hooks.EnumHookType, d
# clone hook_patches to become backup so that any non-dynamic hooks will return to their original state
self.hook_patches_backup = create_hook_patches_clone(self.hook_patches)
for hook in weight_hooks_to_register:
hook.add_hook_patches(self, model_options, target, registered_hooks)
hook.add_hook_patches(self, model_options, target_dict, registered_hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_REGISTER_ALL_HOOK_PATCHES):
callback(self, hooks_dict, target)
callback(self, hooks_dict, target_dict)

def add_hook_patches(self, hook: comfy.hooks.WeightHook, patches, strength_patch=1.0, strength_model=1.0):
with self.use_ejected():
Expand Down
2 changes: 1 addition & 1 deletion comfy/sampler_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ def prepare_model_patcher(model: 'ModelPatcher', conds, model_options: dict):
model_options["transformer_options"]["wrappers"] = comfy.patcher_extension.copy_nested_dicts(model.wrappers)
model_options["transformer_options"]["callbacks"] = comfy.patcher_extension.copy_nested_dicts(model.callbacks)
# register hooks on model/model_options
model.register_all_hook_patches(hooks, comfy.hooks.EnumWeightTarget.Model, model_options)
model.register_all_hook_patches(hooks, comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Model), model_options)
2 changes: 1 addition & 1 deletion comfy_extras/nodes_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def apply_hooks(self, clip: 'CLIP', schedule_clip: bool, apply_to_conds: bool, h
clip.use_clip_schedule = schedule_clip
if not clip.use_clip_schedule:
clip.patcher.forced_hooks.set_keyframes_on_hooks(None)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.EnumWeightTarget.Clip)
clip.patcher.register_all_hook_patches(hooks.get_dict_repr(), comfy.hooks.create_target_dict(comfy.hooks.EnumWeightTarget.Clip))
return (clip,)

class ConditioningTimestepsRange:
Expand Down