Skip to content

Commit f1f9763

Browse files
Add get_duration method to Comfy VIDEO type (Comfy-Org#8122)
* get duration from VIDEO type * video get_duration unit test * fix Windows unit test: can't delete opened temp file
1 parent 08368f8 commit f1f9763

File tree

3 files changed

+281
-0
lines changed

3 files changed

+281
-0
lines changed

comfy_api/input/video_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,13 @@ def get_dimensions(self) -> tuple[int, int]:
4343
components = self.get_components()
4444
return components.images.shape[2], components.images.shape[1]
4545

46+
def get_duration(self) -> float:
47+
"""
48+
Returns the duration of the video in seconds.
49+
50+
Returns:
51+
Duration in seconds
52+
"""
53+
components = self.get_components()
54+
frame_count = components.images.shape[0]
55+
return float(frame_count / components.frame_rate)

comfy_api/input_impl/video_types.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,38 @@ def get_dimensions(self) -> tuple[int, int]:
8080
return stream.width, stream.height
8181
raise ValueError(f"No video stream found in file '{self.__file}'")
8282

83+
def get_duration(self) -> float:
84+
"""
85+
Returns the duration of the video in seconds.
86+
87+
Returns:
88+
Duration in seconds
89+
"""
90+
if isinstance(self.__file, io.BytesIO):
91+
self.__file.seek(0)
92+
with av.open(self.__file, mode="r") as container:
93+
if container.duration is not None:
94+
return float(container.duration / av.time_base)
95+
96+
# Fallback: calculate from frame count and frame rate
97+
video_stream = next(
98+
(s for s in container.streams if s.type == "video"), None
99+
)
100+
if video_stream and video_stream.frames and video_stream.average_rate:
101+
return float(video_stream.frames / video_stream.average_rate)
102+
103+
# Last resort: decode frames to count them
104+
if video_stream and video_stream.average_rate:
105+
frame_count = 0
106+
container.seek(0)
107+
for packet in container.demux(video_stream):
108+
for _ in packet.decode():
109+
frame_count += 1
110+
if frame_count > 0:
111+
return float(frame_count / video_stream.average_rate)
112+
113+
raise ValueError(f"Could not determine duration for file '{self.__file}'")
114+
83115
def get_components_internal(self, container: InputContainer) -> VideoComponents:
84116
# Get video frames
85117
frames = []
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import pytest
2+
import torch
3+
import tempfile
4+
import os
5+
import av
6+
import io
7+
from fractions import Fraction
8+
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
9+
from comfy_api.util.video_types import VideoComponents
10+
from comfy_api.input.basic_types import AudioInput
11+
from av.error import InvalidDataError
12+
13+
EPSILON = 0.0001
14+
15+
16+
@pytest.fixture
17+
def sample_images():
18+
"""3-frame 2x2 RGB video tensor"""
19+
return torch.rand(3, 2, 2, 3)
20+
21+
22+
@pytest.fixture
23+
def sample_audio():
24+
"""Stereo audio with 44.1kHz sample rate"""
25+
return AudioInput(
26+
{
27+
"waveform": torch.rand(1, 2, 1000),
28+
"sample_rate": 44100,
29+
}
30+
)
31+
32+
33+
@pytest.fixture
34+
def video_components(sample_images, sample_audio):
35+
"""VideoComponents with images, audio, and metadata"""
36+
return VideoComponents(
37+
images=sample_images,
38+
audio=sample_audio,
39+
frame_rate=Fraction(30),
40+
metadata={"test": "metadata"},
41+
)
42+
43+
44+
def create_test_video(width=4, height=4, frames=3, fps=30):
45+
"""Helper to create a temporary video file"""
46+
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
47+
with av.open(tmp.name, mode="w") as container:
48+
stream = container.add_stream("h264", rate=fps)
49+
stream.width = width
50+
stream.height = height
51+
stream.pix_fmt = "yuv420p"
52+
53+
for i in range(frames):
54+
frame = av.VideoFrame.from_ndarray(
55+
torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
56+
format="rgb24",
57+
)
58+
frame = frame.reformat(format="yuv420p")
59+
packet = stream.encode(frame)
60+
container.mux(packet)
61+
62+
# Flush
63+
packet = stream.encode(None)
64+
container.mux(packet)
65+
66+
return tmp.name
67+
68+
69+
@pytest.fixture
70+
def simple_video_file():
71+
"""4x4 video with 3 frames at 30fps"""
72+
file_path = create_test_video()
73+
yield file_path
74+
os.unlink(file_path)
75+
76+
77+
def test_video_from_components_get_duration(video_components):
78+
"""Duration calculated correctly from frame count and frame rate"""
79+
video = VideoFromComponents(video_components)
80+
duration = video.get_duration()
81+
82+
expected_duration = 3.0 / 30.0
83+
assert duration == pytest.approx(expected_duration)
84+
85+
86+
def test_video_from_components_get_duration_different_frame_rates(sample_images):
87+
"""Duration correct for different frame rates including fractional"""
88+
# Test with 60 fps
89+
components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
90+
video_60fps = VideoFromComponents(components_60fps)
91+
assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)
92+
93+
# Test with fractional frame rate (23.976fps)
94+
components_frac = VideoComponents(
95+
images=sample_images, frame_rate=Fraction(24000, 1001)
96+
)
97+
video_frac = VideoFromComponents(components_frac)
98+
expected_frac = 3.0 / (24000.0 / 1001.0)
99+
assert video_frac.get_duration() == pytest.approx(expected_frac)
100+
101+
102+
def test_video_from_components_get_duration_empty_video():
103+
"""Duration is zero for empty video"""
104+
empty_components = VideoComponents(
105+
images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
106+
)
107+
video = VideoFromComponents(empty_components)
108+
assert video.get_duration() == 0.0
109+
110+
111+
def test_video_from_components_get_dimensions(video_components):
112+
"""Dimensions returned correctly from image tensor shape"""
113+
video = VideoFromComponents(video_components)
114+
width, height = video.get_dimensions()
115+
assert width == 2
116+
assert height == 2
117+
118+
119+
def test_video_from_file_get_duration(simple_video_file):
120+
"""Duration extracted from file metadata"""
121+
video = VideoFromFile(simple_video_file)
122+
duration = video.get_duration()
123+
assert duration == pytest.approx(0.1, abs=0.01)
124+
125+
126+
def test_video_from_file_get_dimensions(simple_video_file):
127+
"""Dimensions read from stream without decoding frames"""
128+
video = VideoFromFile(simple_video_file)
129+
width, height = video.get_dimensions()
130+
assert width == 4
131+
assert height == 4
132+
133+
134+
def test_video_from_file_bytesio_input():
135+
"""VideoFromFile works with BytesIO input"""
136+
buffer = io.BytesIO()
137+
with av.open(buffer, mode="w", format="mp4") as container:
138+
stream = container.add_stream("h264", rate=30)
139+
stream.width = 2
140+
stream.height = 2
141+
stream.pix_fmt = "yuv420p"
142+
143+
frame = av.VideoFrame.from_ndarray(
144+
torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
145+
)
146+
frame = frame.reformat(format="yuv420p")
147+
packet = stream.encode(frame)
148+
container.mux(packet)
149+
packet = stream.encode(None)
150+
container.mux(packet)
151+
152+
buffer.seek(0)
153+
video = VideoFromFile(buffer)
154+
155+
assert video.get_dimensions() == (2, 2)
156+
assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)
157+
158+
159+
def test_video_from_file_invalid_file_error():
160+
"""InvalidDataError raised for non-video files"""
161+
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
162+
tmp.write(b"not a video file")
163+
tmp.flush()
164+
tmp_name = tmp.name
165+
166+
try:
167+
with pytest.raises(InvalidDataError):
168+
video = VideoFromFile(tmp_name)
169+
video.get_dimensions()
170+
finally:
171+
os.unlink(tmp_name)
172+
173+
174+
def test_video_from_file_audio_only_error():
175+
"""ValueError raised for audio-only files"""
176+
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
177+
tmp_name = tmp.name
178+
179+
try:
180+
with av.open(tmp_name, mode="w") as container:
181+
stream = container.add_stream("aac", rate=44100)
182+
stream.sample_rate = 44100
183+
stream.format = "fltp"
184+
185+
audio_data = torch.zeros(1, 1024).numpy()
186+
audio_frame = av.AudioFrame.from_ndarray(
187+
audio_data, format="fltp", layout="mono"
188+
)
189+
audio_frame.sample_rate = 44100
190+
audio_frame.pts = 0
191+
packet = stream.encode(audio_frame)
192+
container.mux(packet)
193+
194+
for packet in stream.encode(None):
195+
container.mux(packet)
196+
197+
with pytest.raises(ValueError, match="No video stream found"):
198+
video = VideoFromFile(tmp_name)
199+
video.get_dimensions()
200+
finally:
201+
os.unlink(tmp_name)
202+
203+
204+
def test_single_frame_video():
205+
"""Single frame video has correct duration"""
206+
components = VideoComponents(
207+
images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
208+
)
209+
video = VideoFromComponents(components)
210+
assert video.get_duration() == 1.0
211+
212+
213+
@pytest.mark.parametrize(
214+
"frame_rate,expected_fps",
215+
[
216+
(Fraction(24000, 1001), 24000 / 1001),
217+
(Fraction(30000, 1001), 30000 / 1001),
218+
(Fraction(25, 1), 25.0),
219+
(Fraction(50, 2), 25.0),
220+
],
221+
)
222+
def test_fractional_frame_rates(frame_rate, expected_fps):
223+
"""Duration calculated correctly for various fractional frame rates"""
224+
components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
225+
video = VideoFromComponents(components)
226+
duration = video.get_duration()
227+
expected_duration = 100.0 / expected_fps
228+
assert duration == pytest.approx(expected_duration)
229+
230+
231+
def test_duration_consistency(video_components):
232+
"""get_duration() consistent with manual calculation from components"""
233+
video = VideoFromComponents(video_components)
234+
235+
duration = video.get_duration()
236+
components = video.get_components()
237+
manual_duration = float(components.images.shape[0] / components.frame_rate)
238+
239+
assert duration == pytest.approx(manual_duration)

0 commit comments

Comments
 (0)