Skip to content

Commit

Permalink
Added experimental node to inject PIA component into any AnimateDiff …
Browse files Browse the repository at this point in the history
…model (not very useful currently)
  • Loading branch information
Kosinkadink committed Jun 17, 2024
1 parent f01d1f2 commit 62b0dae
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 3 deletions.
8 changes: 8 additions & 0 deletions animatediff/model_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1262,6 +1262,14 @@ def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: M
motion_model.model.img_encoder.load_state_dict(w_encoder.model.img_encoder.state_dict())


def inject_pia_conv_in_into_model(motion_model: MotionModelPatcher, w_pia: MotionModelPatcher):
motion_model.model.init_conv_in(w_pia.model.state_dict())
motion_model.model.conv_in.to(comfy.model_management.unet_dtype())
motion_model.model.conv_in.to(comfy.model_management.unet_offload_device())
motion_model.model.conv_in.load_state_dict(w_pia.model.conv_in.state_dict())
motion_model.model.mm_info.mm_format = AnimateDiffFormat.PIA


def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ctrl_name: str):
camera_ctrl_path = get_motion_model_path(camera_ctrl_name)
full_state_dict = comfy.utils.load_torch_file(camera_ctrl_path, safe_load=True)
Expand Down
4 changes: 3 additions & 1 deletion animatediff/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .nodes_cameractrl import (LoadAnimateDiffModelWithCameraCtrl, ApplyAnimateDiffWithCameraCtrl, CameraCtrlADKeyframeNode, LoadCameraPoses,
CameraCtrlPoseBasic, CameraCtrlPoseCombo, CameraCtrlPoseAdvanced, CameraCtrlManualAppendPose,
CameraCtrlReplaceCameraParameters, CameraCtrlSetOriginalAspectRatio)
from .nodes_pia import (ApplyAnimateDiffPIAModel, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode)
from .nodes_pia import (ApplyAnimateDiffPIAModel, LoadAnimateDiffAndInjectPIANode, InputPIA_MultivalNode, InputPIA_PaperPresetsNode, PIA_ADKeyframeNode)
from .nodes_multival import MultivalDynamicNode, MultivalScaledMaskNode, MultivalDynamicFloatInputNode, MultivalConvertToMaskNode
from .nodes_conditioning import (MaskableLoraLoader, MaskableLoraLoaderModelOnly, MaskableSDModelLoader, MaskableSDModelLoaderModelOnly,
SetModelLoraHook, SetClipLoraHook,
Expand Down Expand Up @@ -142,6 +142,7 @@
"ADE_InputPIA_Multival": InputPIA_MultivalNode,
"ADE_InputPIA_PaperPresets": InputPIA_PaperPresetsNode,
"ADE_PIA_AnimateDiffKeyframe": PIA_ADKeyframeNode,
"ADE_InjectPIAIntoAnimateDiffModel": LoadAnimateDiffAndInjectPIANode,
# Deprecated Nodes
"AnimateDiffLoaderV1": AnimateDiffLoader_Deprecated,
"ADE_AnimateDiffLoaderV1Advanced": AnimateDiffLoaderAdvanced_Deprecated,
Expand Down Expand Up @@ -254,6 +255,7 @@
"ADE_InputPIA_Multival": "PIA Input [Multival] 🎭🅐🅓②",
"ADE_InputPIA_PaperPresets": "PIA Input [Paper Presets] 🎭🅐🅓②",
"ADE_PIA_AnimateDiffKeyframe": "AnimateDiff-PIA Keyframe 🎭🅐🅓",
"ADE_InjectPIAIntoAnimateDiffModel": "🧪Inject PIA into AnimateDiff Model 🎭🅐🅓②",
# Deprecated Nodes
"AnimateDiffLoaderV1": "🚫AnimateDiff Loader [DEPRECATED] 🎭🅐🅓",
"ADE_AnimateDiffLoaderV1Advanced": "🚫AnimateDiff Loader (Advanced) [DEPRECATED] 🎭🅐🅓",
Expand Down
34 changes: 32 additions & 2 deletions animatediff/nodes_pia.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@

from comfy.sd import VAE

from .ad_settings import AnimateDiffSettings
from .logger import logger
from .utils_model import BIGMIN, BIGMAX
from .utils_model import BIGMIN, BIGMAX, get_available_motion_models
from .utils_motion import ADKeyframeGroup, InputPIA, InputPIA_Multival, extend_list_to_batch_size, extend_to_batch_size, prepare_mask_batch
from .motion_lora import MotionLoraList
from .model_injection import MotionModelGroup, MotionModelPatcher
from .model_injection import MotionModelGroup, MotionModelPatcher, load_motion_module_gen2, inject_pia_conv_in_into_model
from .motion_module_ad import AnimateDiffFormat
from .nodes_gen2 import ApplyAnimateDiffModelNode, ADKeyframeNode

Expand Down Expand Up @@ -153,6 +154,35 @@ def apply_motion_model(self, motion_model: MotionModelPatcher, image: Tensor, va
return new_m_models


class LoadAnimateDiffAndInjectPIANode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_name": (get_available_motion_models(),),
"motion_model": ("MOTION_MODEL_ADE",),
},
"optional": {
"ad_settings": ("AD_SETTINGS",),
}
}

RETURN_TYPES = ("MOTION_MODEL_ADE",)
RETURN_NAMES = ("MOTION_MODEL",)

CATEGORY = "Animate Diff 🎭🅐🅓/② Gen2 nodes ②/PIA/🧪experimental"
FUNCTION = "load_motion_model"

def load_motion_model(self, model_name: str, motion_model: MotionModelPatcher, ad_settings: AnimateDiffSettings=None):
# make sure model actually has PIA conv_in
if motion_model.model.conv_in is None:
raise Exception("Passed-in motion model was expected to be PIA (contain conv_in), but did not.")
# load motion module and motion settings, if included
loaded_motion_model = load_motion_module_gen2(model_name=model_name, motion_model_settings=ad_settings)
inject_pia_conv_in_into_model(motion_model=loaded_motion_model, w_pia=motion_model)
return (loaded_motion_model,)


class PIA_ADKeyframeNode:
@classmethod
def INPUT_TYPES(s):
Expand Down

0 comments on commit 62b0dae

Please sign in to comment.