Skip to content

Commit

Permalink
Update nodes.py
Browse files Browse the repository at this point in the history
  • Loading branch information
kijai committed Dec 4, 2024
1 parent 16fca6a commit ef8d609
Showing 1 changed file with 31 additions and 8 deletions.
39 changes: 31 additions & 8 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,18 @@ def loadmodel(self, model, base_precision, load_device, quantization,
#compile
if compile_args is not None:
torch._dynamo.config.cache_size_limit = compile_args["dynamo_cache_size_limit"]
for i, block in enumerate(transformer.single_blocks):
transformer.single_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
for i, block in enumerate(transformer.double_blocks):
transformer.double_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if compile_args["compile_single_blocks"]:
for i, block in enumerate(transformer.single_blocks):
transformer.single_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if compile_args["compile_double_blocks"]:
for i, block in enumerate(transformer.double_blocks):
transformer.double_blocks[i] = torch.compile(block, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if compile_args["compile_txt_in"]:
transformer.txt_in = torch.compile(transformer.txt_in, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if compile_args["compile_vector_in"]:
transformer.vector_in = torch.compile(transformer.vector_in, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])
if compile_args["compile_final_layer"]:
transformer.final_layer = torch.compile(transformer.final_layer, fullgraph=compile_args["fullgraph"], dynamic=compile_args["dynamic"], backend=compile_args["backend"], mode=compile_args["mode"])

if "torchao" in quantization:
try:
Expand Down Expand Up @@ -294,6 +302,12 @@ def INPUT_TYPES(s):
"mode": (["default", "max-autotune", "max-autotune-no-cudagraphs", "reduce-overhead"], {"default": "default"}),
"dynamic": ("BOOLEAN", {"default": False, "tooltip": "Enable dynamic mode"}),
"dynamo_cache_size_limit": ("INT", {"default": 64, "min": 0, "max": 1024, "step": 1, "tooltip": "torch._dynamo.config.cache_size_limit"}),
"compile_single_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile single blocks"}),
"compile_double_blocks": ("BOOLEAN", {"default": True, "tooltip": "Compile double blocks"}),
"compile_txt_in": ("BOOLEAN", {"default": False, "tooltip": "Compile txt_in layers"}),
"compile_vector_in": ("BOOLEAN", {"default": False, "tooltip": "Compile vector_in layers"}),
"compile_final_layer": ("BOOLEAN", {"default": False, "tooltip": "Compile final layer"}),

},
}
RETURN_TYPES = ("COMPILEARGS",)
Expand All @@ -302,14 +316,19 @@ def INPUT_TYPES(s):
CATEGORY = "HunyuanVideoWrapper"
DESCRIPTION = "torch.compile settings, when connected to the model loader, torch.compile of the selected layers is attempted. Requires Triton and torch 2.5.0 is recommended"

def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit):
def loadmodel(self, backend, fullgraph, mode, dynamic, dynamo_cache_size_limit, compile_single_blocks, compile_double_blocks, compile_txt_in, compile_vector_in, compile_final_layer):

compile_args = {
"backend": backend,
"fullgraph": fullgraph,
"mode": mode,
"dynamic": dynamic,
"dynamo_cache_size_limit": dynamo_cache_size_limit,
"compile_single_blocks": compile_single_blocks,
"compile_double_blocks": compile_double_blocks,
"compile_txt_in": compile_txt_in,
"compile_vector_in": compile_vector_in,
"compile_final_layer": compile_final_layer
}

return (compile_args, )
Expand Down Expand Up @@ -414,7 +433,7 @@ def INPUT_TYPES(s):
return {"required": {
"text_encoders": ("HYVIDTEXTENCODER",),
"prompt": ("STRING", {"default": "", "multiline": True} ),
"negative_prompt": ("STRING", {"default": "", "multiline": True}),
#"negative_prompt": ("STRING", {"default": "", "multiline": True}),
},
"optional": {
"force_offload": ("BOOLEAN", {"default": True}),
Expand All @@ -426,19 +445,21 @@ def INPUT_TYPES(s):
FUNCTION = "process"
CATEGORY = "HunyuanVideoWrapper"

def process(self, text_encoders, prompt, negative_prompt, force_offload=True):
def process(self, text_encoders, prompt, force_offload=True):
device = mm.text_encoder_device()
offload_device = mm.text_encoder_offload_device()

text_encoder_1 = text_encoders["text_encoder"]
text_encoder_2 = text_encoders["text_encoder_2"]

negative_prompt = None

def encode_prompt(self, prompt, negative_prompt, text_encoder):
batch_size = 1
num_videos_per_prompt = 1
do_classifier_free_guidance = True
data_type = "video"

text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)

prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
Expand Down Expand Up @@ -860,6 +881,7 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
"DownloadAndLoadHyVideoTextEncoder": DownloadAndLoadHyVideoTextEncoder,
"HyVideoEncode": HyVideoEncode,
"HyVideoBlockSwap": HyVideoBlockSwap,
"HyVideoTorchCompileSettings": HyVideoTorchCompileSettings,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"HyVideoSampler": "HunyuanVideo Sampler",
Expand All @@ -870,4 +892,5 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
"DownloadAndLoadHyVideoTextEncoder": "(Down)Load HunyuanVideo TextEncoder",
"HyVideoEncode": "HunyuanVideo Encode",
"HyVideoBlockSwap": "HunyuanVideo BlockSwap",
"HyVideoTorchCompileSettings": "HunyuanVideo Torch Compile Settings",
}

0 comments on commit ef8d609

Please sign in to comment.