Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hooks Part 2 - TransformerOptionsHook and AdditionalModelsHook #6377

Merged
merged 21 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
72bbf49
Add 'sigmas' to transformer_options so that downstream code can know …
Kosinkadink Dec 29, 2024
bf21be0
Merge branch 'master' into hooks_part2
Kosinkadink Dec 30, 2024
d44295e
Merge branch 'master' into hooks_part2
Kosinkadink Jan 4, 2025
5a2ad03
Cleaned up hooks.py, refactored Hook.should_register and add_hook_pat…
Kosinkadink Jan 4, 2025
776aa73
Refactor WrapperHook into TransformerOptionsHook, as there is no need…
Kosinkadink Jan 4, 2025
111fd0c
Refactored HookGroup to also store a dictionary of hooks separated by…
Kosinkadink Jan 4, 2025
6620d86
In inner_sample, change "sigmas" to "sampler_sigmas" in transformer_o…
Kosinkadink Jan 5, 2025
db2d7ad
Merge branch 'add_sample_sigmas' into hooks_part2
Kosinkadink Jan 5, 2025
8270ff3
Refactored 'registered' to be HookGroup instead of a list of Hooks, m…
Kosinkadink Jan 6, 2025
4446c86
Made hook clone code sane, made clear ObjectPatchHook and SetInjectio…
Kosinkadink Jan 6, 2025
03a97b6
Fix performance of hooks when hooks are appended via Cond Pair Set Pr…
Kosinkadink Jan 6, 2025
0a7e2ae
Filter only registered hooks on self.conds in CFGGuider.sample
Kosinkadink Jan 6, 2025
6463c39
Merge branch 'master' into hooks_part2
Kosinkadink Jan 6, 2025
f48f90e
Make hook_scope functional for TransformerOptionsHook
Kosinkadink Jan 6, 2025
2724ac4
Merge branch 'master' into hooks_part2
Kosinkadink Jan 6, 2025
1b38f5b
removed 4 whitespace lines to satisfy Ruff,
Kosinkadink Jan 6, 2025
58bf881
Add a get_injections function to ModelPatcher
Kosinkadink Jan 7, 2025
216fea1
Made TransformerOptionsHook contribute to registered hooks properly, …
Kosinkadink Jan 7, 2025
11c6d56
Merge branch 'master' into hooks_part2
Kosinkadink Jan 7, 2025
3cd4c5c
Rename AddModelsHooks to AdditionalModelsHook, rename SetInjectionsHo…
Kosinkadink Jan 7, 2025
7333281
Clean up a typehint
Kosinkadink Jan 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Make hook_scope functional for TransformerOptionsHook
  • Loading branch information
Kosinkadink committed Jan 6, 2025
commit f48f90e471fc5440135e7886d712518467c59c00
41 changes: 26 additions & 15 deletions comfy/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,9 @@ def __init__(self, hook_type: EnumHookType=None, hook_ref: _HookRef=None, hook_i
self.hook_ref = hook_ref if hook_ref else _HookRef()
self.hook_id = hook_id
self.hook_keyframe = hook_keyframe if hook_keyframe else HookKeyframeGroup()
self.hook_scope = hook_scope
self.custom_should_register = default_should_register
self.auto_apply_to_nonpositive = False
self.hook_scope = hook_scope

@property
def strength(self):
Expand All @@ -107,6 +107,7 @@ def clone(self):
c.hook_ref = self.hook_ref
c.hook_id = self.hook_id
c.hook_keyframe = self.hook_keyframe
c.hook_scope = self.hook_scope
c.custom_should_register = self.custom_should_register
# TODO: make this do something
c.auto_apply_to_nonpositive = self.auto_apply_to_nonpositive
Expand All @@ -118,12 +119,6 @@ def should_register(self, model: ModelPatcher, model_options: dict, target_dict:
def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
raise NotImplementedError("add_hook_patches should be defined for Hook subclasses")

def on_apply(self, model: ModelPatcher, transformer_options: dict[str]):
pass

def on_unapply(self, model: ModelPatcher, transformer_options: dict[str]):
pass

def __eq__(self, other: Hook):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref

Expand All @@ -143,6 +138,7 @@ def __init__(self, strength_model=1.0, strength_clip=1.0):
self.need_weight_init = True
self._strength_model = strength_model
self._strength_clip = strength_clip
self.hook_scope = EnumHookScope.HookedOnly # this value does not matter for WeightHooks, just for docs

@property
def strength_model(self):
Expand Down Expand Up @@ -190,9 +186,11 @@ def clone(self):
return c

class ObjectPatchHook(Hook):
def __init__(self, object_patches: dict[str]=None):
def __init__(self, object_patches: dict[str]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.ObjectPatch)
self.object_patches = object_patches
self.hook_scope = hook_scope

def clone(self):
c: ObjectPatchHook = super().clone()
Expand All @@ -216,14 +214,11 @@ def __init__(self, models: list[ModelPatcher]=None, key: str=None):
super().__init__(hook_type=EnumHookType.AddModels)
self.models = models
self.key = key
self.append_when_same = True
'''Curently does nothing.'''

def clone(self):
c: AddModelsHook = super().clone()
c.models = self.models.copy() if self.models else self.models
c.key = self.key
c.append_when_same = self.append_when_same
return c

def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict: dict[str], registered: HookGroup):
Expand All @@ -236,9 +231,11 @@ class TransformerOptionsHook(Hook):
'''
Hook responsible for adding wrappers, callbacks, patches, or anything else related to transformer_options.
'''
def __init__(self, wrappers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None):
def __init__(self, transformers_dict: dict[str, dict[str, dict[str, list[Callable]]]]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.TransformerOptions)
self.transformers_dict = wrappers_dict
self.transformers_dict = transformers_dict
self.hook_scope = hook_scope

def clone(self):
c: TransformerOptionsHook = super().clone()
Expand All @@ -254,8 +251,9 @@ def add_hook_patches(self, model: ModelPatcher, model_options: dict, target_dict
"to_load_options": self.transformers_dict}
else:
add_model_options = {"to_load_options": self.transformers_dict}
# only register if will not be included in AllConditioning to avoid double loading
registered.add(self)
comfy.patcher_extension.merge_nested_dicts(model_options, add_model_options, copy_dict1=False)
registered.add(self)
return True

def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
Expand All @@ -265,10 +263,12 @@ def on_apply_hooks(self, model: ModelPatcher, transformer_options: dict[str]):
'''Only here for backwards compatibility, WrapperHook is identical to TransformerOptionsHook.'''

class SetInjectionsHook(Hook):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None):
def __init__(self, key: str=None, injections: list[PatcherInjection]=None,
hook_scope=EnumHookScope.AllConditioning):
super().__init__(hook_type=EnumHookType.Injections)
self.key = key
self.injections = injections
self.hook_scope = hook_scope

def clone(self):
c: SetInjectionsHook = super().clone()
Expand Down Expand Up @@ -590,6 +590,17 @@ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
sorted_list.extend(object_list)
return sorted_list

def create_transformer_options_from_hooks(model: ModelPatcher, hooks: HookGroup, transformer_options: dict[str]=None):
# if no hooks or is not a ModelPatcher for sampling, return empty dict
if hooks is None or model.is_clip:
return {}
if transformer_options is None:
transformer_options = {}
for hook in hooks.get_type(EnumHookType.TransformerOptions):
hook: TransformerOptionsHook
hook.on_apply_hooks(model, transformer_options)
return transformer_options

def create_hook_lora(lora: dict[str, torch.Tensor], strength_model: float, strength_clip: float):
hook_group = HookGroup()
hook = WeightHook(strength_model=strength_model, strength_clip=strength_clip)
Expand Down
4 changes: 2 additions & 2 deletions comfy/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,11 +1010,11 @@ def get_combined_hook_patches(self, hooks: comfy.hooks.HookGroup):
def apply_hooks(self, hooks: comfy.hooks.HookGroup, transformer_options: dict=None, force_apply=False):
# TODO: return transformer_options dict with any additions from hooks
if self.current_hooks == hooks and (not force_apply or (not self.is_clip and hooks is None)):
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)
self.patch_hooks(hooks=hooks)
for callback in self.get_all_callbacks(CallbacksMP.ON_APPLY_HOOKS):
callback(self, hooks)
return {}
return comfy.hooks.create_transformer_options_from_hooks(self, hooks, transformer_options)

def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
Expand Down
Loading