Skip to content

Commit a7bfbac

Browse files
authored
convert Video nodes to V3 schema (comfyanonymous#9489)
1 parent 85f7594 commit a7bfbac

File tree

1 file changed

+134
-156
lines changed

1 file changed

+134
-156
lines changed

comfy_extras/nodes_video.py

Lines changed: 134 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -5,52 +5,49 @@
55
import torch
66
import folder_paths
77
import json
8-
from typing import Optional, Literal
8+
from typing import Optional
9+
from typing_extensions import override
910
from fractions import Fraction
10-
from comfy.comfy_types import IO, FileLocator, ComfyNodeABC
11-
from comfy_api.latest import Input, InputImpl, Types
11+
from comfy_api.input import AudioInput, ImageInput, VideoInput
12+
from comfy_api.input_impl import VideoFromComponents, VideoFromFile
13+
from comfy_api.util import VideoCodec, VideoComponents, VideoContainer
14+
from comfy_api.latest import ComfyExtension, io, ui
1215
from comfy.cli_args import args
1316

14-
class SaveWEBM:
15-
def __init__(self):
16-
self.output_dir = folder_paths.get_output_directory()
17-
self.type = "output"
18-
self.prefix_append = ""
19-
17+
class SaveWEBM(io.ComfyNode):
2018
@classmethod
21-
def INPUT_TYPES(s):
22-
return {"required":
23-
{"images": ("IMAGE", ),
24-
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
25-
"codec": (["vp9", "av1"],),
26-
"fps": ("FLOAT", {"default": 24.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
27-
"crf": ("FLOAT", {"default": 32.0, "min": 0, "max": 63.0, "step": 1, "tooltip": "Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."}),
28-
},
29-
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
30-
}
31-
32-
RETURN_TYPES = ()
33-
FUNCTION = "save_images"
34-
35-
OUTPUT_NODE = True
36-
37-
CATEGORY = "image/video"
38-
39-
EXPERIMENTAL = True
19+
def define_schema(cls):
20+
return io.Schema(
21+
node_id="SaveWEBM",
22+
category="image/video",
23+
is_experimental=True,
24+
inputs=[
25+
io.Image.Input("images"),
26+
io.String.Input("filename_prefix", default="ComfyUI"),
27+
io.Combo.Input("codec", options=["vp9", "av1"]),
28+
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
29+
io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."),
30+
],
31+
outputs=[],
32+
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
33+
is_output_node=True,
34+
)
4035

41-
def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, extra_pnginfo=None):
42-
filename_prefix += self.prefix_append
43-
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
36+
@classmethod
37+
def execute(cls, images, codec, fps, filename_prefix, crf) -> io.NodeOutput:
38+
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
39+
filename_prefix, folder_paths.get_output_directory(), images[0].shape[1], images[0].shape[0]
40+
)
4441

4542
file = f"{filename}_{counter:05}_.webm"
4643
container = av.open(os.path.join(full_output_folder, file), mode="w")
4744

48-
if prompt is not None:
49-
container.metadata["prompt"] = json.dumps(prompt)
45+
if cls.hidden.prompt is not None:
46+
container.metadata["prompt"] = json.dumps(cls.hidden.prompt)
5047

51-
if extra_pnginfo is not None:
52-
for x in extra_pnginfo:
53-
container.metadata[x] = json.dumps(extra_pnginfo[x])
48+
if cls.hidden.extra_pnginfo is not None:
49+
for x in cls.hidden.extra_pnginfo:
50+
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
5451

5552
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
5653
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
@@ -69,172 +66,153 @@ def save_images(self, images, codec, fps, filename_prefix, crf, prompt=None, ext
6966
container.mux(stream.encode())
7067
container.close()
7168

72-
results: list[FileLocator] = [{
73-
"filename": file,
74-
"subfolder": subfolder,
75-
"type": self.type
76-
}]
69+
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
7770

78-
return {"ui": {"images": results, "animated": (True,)}} # TODO: frontend side
79-
80-
class SaveVideo(ComfyNodeABC):
81-
def __init__(self):
82-
self.output_dir = folder_paths.get_output_directory()
83-
self.type: Literal["output"] = "output"
84-
self.prefix_append = ""
71+
class SaveVideo(io.ComfyNode):
72+
@classmethod
73+
def define_schema(cls):
74+
return io.Schema(
75+
node_id="SaveVideo",
76+
display_name="Save Video",
77+
category="image/video",
78+
description="Saves the input images to your ComfyUI output directory.",
79+
inputs=[
80+
io.Video.Input("video", tooltip="The video to save."),
81+
io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."),
82+
io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."),
83+
io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."),
84+
],
85+
outputs=[],
86+
hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo],
87+
is_output_node=True,
88+
)
8589

8690
@classmethod
87-
def INPUT_TYPES(cls):
88-
return {
89-
"required": {
90-
"video": (IO.VIDEO, {"tooltip": "The video to save."}),
91-
"filename_prefix": ("STRING", {"default": "video/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}),
92-
"format": (Types.VideoContainer.as_input(), {"default": "auto", "tooltip": "The format to save the video as."}),
93-
"codec": (Types.VideoCodec.as_input(), {"default": "auto", "tooltip": "The codec to use for the video."}),
94-
},
95-
"hidden": {
96-
"prompt": "PROMPT",
97-
"extra_pnginfo": "EXTRA_PNGINFO"
98-
},
99-
}
100-
101-
RETURN_TYPES = ()
102-
FUNCTION = "save_video"
103-
104-
OUTPUT_NODE = True
105-
106-
CATEGORY = "image/video"
107-
DESCRIPTION = "Saves the input images to your ComfyUI output directory."
108-
109-
def save_video(self, video: Input.Video, filename_prefix, format, codec, prompt=None, extra_pnginfo=None):
110-
filename_prefix += self.prefix_append
91+
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
11192
width, height = video.get_dimensions()
11293
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
11394
filename_prefix,
114-
self.output_dir,
95+
folder_paths.get_output_directory(),
11596
width,
11697
height
11798
)
118-
results: list[FileLocator] = list()
11999
saved_metadata = None
120100
if not args.disable_metadata:
121101
metadata = {}
122-
if extra_pnginfo is not None:
123-
metadata.update(extra_pnginfo)
124-
if prompt is not None:
125-
metadata["prompt"] = prompt
102+
if cls.hidden.extra_pnginfo is not None:
103+
metadata.update(cls.hidden.extra_pnginfo)
104+
if cls.hidden.prompt is not None:
105+
metadata["prompt"] = cls.hidden.prompt
126106
if len(metadata) > 0:
127107
saved_metadata = metadata
128-
file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}"
108+
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
129109
video.save_to(
130110
os.path.join(full_output_folder, file),
131111
format=format,
132112
codec=codec,
133113
metadata=saved_metadata
134114
)
135115

136-
results.append({
137-
"filename": file,
138-
"subfolder": subfolder,
139-
"type": self.type
140-
})
141-
counter += 1
116+
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)]))
142117

143-
return { "ui": { "images": results, "animated": (True,) } }
144118

145-
class CreateVideo(ComfyNodeABC):
119+
class CreateVideo(io.ComfyNode):
146120
@classmethod
147-
def INPUT_TYPES(cls):
148-
return {
149-
"required": {
150-
"images": (IO.IMAGE, {"tooltip": "The images to create a video from."}),
151-
"fps": ("FLOAT", {"default": 30.0, "min": 1.0, "max": 120.0, "step": 1.0}),
152-
},
153-
"optional": {
154-
"audio": (IO.AUDIO, {"tooltip": "The audio to add to the video."}),
155-
}
156-
}
157-
158-
RETURN_TYPES = (IO.VIDEO,)
159-
FUNCTION = "create_video"
160-
161-
CATEGORY = "image/video"
162-
DESCRIPTION = "Create a video from images."
163-
164-
def create_video(self, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None):
165-
return (InputImpl.VideoFromComponents(
166-
Types.VideoComponents(
167-
images=images,
168-
audio=audio,
169-
frame_rate=Fraction(fps),
170-
)
171-
),)
172-
173-
class GetVideoComponents(ComfyNodeABC):
121+
def define_schema(cls):
122+
return io.Schema(
123+
node_id="CreateVideo",
124+
display_name="Create Video",
125+
category="image/video",
126+
description="Create a video from images.",
127+
inputs=[
128+
io.Image.Input("images", tooltip="The images to create a video from."),
129+
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
130+
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
131+
],
132+
outputs=[
133+
io.Video.Output(),
134+
],
135+
)
136+
137+
@classmethod
138+
def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput:
139+
return io.NodeOutput(
140+
VideoFromComponents(VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
141+
)
142+
143+
class GetVideoComponents(io.ComfyNode):
144+
@classmethod
145+
def define_schema(cls):
146+
return io.Schema(
147+
node_id="GetVideoComponents",
148+
display_name="Get Video Components",
149+
category="image/video",
150+
description="Extracts all components from a video: frames, audio, and framerate.",
151+
inputs=[
152+
io.Video.Input("video", tooltip="The video to extract components from."),
153+
],
154+
outputs=[
155+
io.Image.Output(display_name="images"),
156+
io.Audio.Output(display_name="audio"),
157+
io.Float.Output(display_name="fps"),
158+
],
159+
)
160+
174161
@classmethod
175-
def INPUT_TYPES(cls):
176-
return {
177-
"required": {
178-
"video": (IO.VIDEO, {"tooltip": "The video to extract components from."}),
179-
}
180-
}
181-
RETURN_TYPES = (IO.IMAGE, IO.AUDIO, IO.FLOAT)
182-
RETURN_NAMES = ("images", "audio", "fps")
183-
FUNCTION = "get_components"
184-
185-
CATEGORY = "image/video"
186-
DESCRIPTION = "Extracts all components from a video: frames, audio, and framerate."
187-
188-
def get_components(self, video: Input.Video):
162+
def execute(cls, video: VideoInput) -> io.NodeOutput:
189163
components = video.get_components()
190164

191-
return (components.images, components.audio, float(components.frame_rate))
165+
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
192166

193-
class LoadVideo(ComfyNodeABC):
167+
class LoadVideo(io.ComfyNode):
194168
@classmethod
195-
def INPUT_TYPES(cls):
169+
def define_schema(cls):
196170
input_dir = folder_paths.get_input_directory()
197171
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
198172
files = folder_paths.filter_files_content_types(files, ["video"])
199-
return {"required":
200-
{"file": (sorted(files), {"video_upload": True})},
201-
}
202-
203-
CATEGORY = "image/video"
173+
return io.Schema(
174+
node_id="LoadVideo",
175+
display_name="Load Video",
176+
category="image/video",
177+
inputs=[
178+
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
179+
],
180+
outputs=[
181+
io.Video.Output(),
182+
],
183+
)
204184

205-
RETURN_TYPES = (IO.VIDEO,)
206-
FUNCTION = "load_video"
207-
def load_video(self, file):
185+
@classmethod
186+
def execute(cls, file) -> io.NodeOutput:
208187
video_path = folder_paths.get_annotated_filepath(file)
209-
return (InputImpl.VideoFromFile(video_path),)
188+
return io.NodeOutput(VideoFromFile(video_path))
210189

211190
@classmethod
212-
def IS_CHANGED(cls, file):
191+
def fingerprint_inputs(s, file):
213192
video_path = folder_paths.get_annotated_filepath(file)
214193
mod_time = os.path.getmtime(video_path)
215194
# Instead of hashing the file, we can just use the modification time to avoid
216195
# rehashing large files.
217196
return mod_time
218197

219198
@classmethod
220-
def VALIDATE_INPUTS(cls, file):
199+
def validate_inputs(s, file):
221200
if not folder_paths.exists_annotated_filepath(file):
222201
return "Invalid video file: {}".format(file)
223202

224203
return True
225204

226-
NODE_CLASS_MAPPINGS = {
227-
"SaveWEBM": SaveWEBM,
228-
"SaveVideo": SaveVideo,
229-
"CreateVideo": CreateVideo,
230-
"GetVideoComponents": GetVideoComponents,
231-
"LoadVideo": LoadVideo,
232-
}
233-
234-
NODE_DISPLAY_NAME_MAPPINGS = {
235-
"SaveVideo": "Save Video",
236-
"CreateVideo": "Create Video",
237-
"GetVideoComponents": "Get Video Components",
238-
"LoadVideo": "Load Video",
239-
}
240205

206+
class VideoExtension(ComfyExtension):
207+
@override
208+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
209+
return [
210+
SaveWEBM,
211+
SaveVideo,
212+
CreateVideo,
213+
GetVideoComponents,
214+
LoadVideo,
215+
]
216+
217+
async def comfy_entrypoint() -> VideoExtension:
218+
return VideoExtension()

0 commit comments

Comments
 (0)