|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +import tempfile |
| 4 | +import os |
| 5 | +import av |
| 6 | +import io |
| 7 | +from fractions import Fraction |
| 8 | +from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents |
| 9 | +from comfy_api.util.video_types import VideoComponents |
| 10 | +from comfy_api.input.basic_types import AudioInput |
| 11 | +from av.error import InvalidDataError |
| 12 | + |
| 13 | +EPSILON = 0.0001 |
| 14 | + |
| 15 | + |
| 16 | +@pytest.fixture |
| 17 | +def sample_images(): |
| 18 | + """3-frame 2x2 RGB video tensor""" |
| 19 | + return torch.rand(3, 2, 2, 3) |
| 20 | + |
| 21 | + |
| 22 | +@pytest.fixture |
| 23 | +def sample_audio(): |
| 24 | + """Stereo audio with 44.1kHz sample rate""" |
| 25 | + return AudioInput( |
| 26 | + { |
| 27 | + "waveform": torch.rand(1, 2, 1000), |
| 28 | + "sample_rate": 44100, |
| 29 | + } |
| 30 | + ) |
| 31 | + |
| 32 | + |
| 33 | +@pytest.fixture |
| 34 | +def video_components(sample_images, sample_audio): |
| 35 | + """VideoComponents with images, audio, and metadata""" |
| 36 | + return VideoComponents( |
| 37 | + images=sample_images, |
| 38 | + audio=sample_audio, |
| 39 | + frame_rate=Fraction(30), |
| 40 | + metadata={"test": "metadata"}, |
| 41 | + ) |
| 42 | + |
| 43 | + |
| 44 | +def create_test_video(width=4, height=4, frames=3, fps=30): |
| 45 | + """Helper to create a temporary video file""" |
| 46 | + tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) |
| 47 | + with av.open(tmp.name, mode="w") as container: |
| 48 | + stream = container.add_stream("h264", rate=fps) |
| 49 | + stream.width = width |
| 50 | + stream.height = height |
| 51 | + stream.pix_fmt = "yuv420p" |
| 52 | + |
| 53 | + for i in range(frames): |
| 54 | + frame = av.VideoFrame.from_ndarray( |
| 55 | + torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85), |
| 56 | + format="rgb24", |
| 57 | + ) |
| 58 | + frame = frame.reformat(format="yuv420p") |
| 59 | + packet = stream.encode(frame) |
| 60 | + container.mux(packet) |
| 61 | + |
| 62 | + # Flush |
| 63 | + packet = stream.encode(None) |
| 64 | + container.mux(packet) |
| 65 | + |
| 66 | + return tmp.name |
| 67 | + |
| 68 | + |
| 69 | +@pytest.fixture |
| 70 | +def simple_video_file(): |
| 71 | + """4x4 video with 3 frames at 30fps""" |
| 72 | + file_path = create_test_video() |
| 73 | + yield file_path |
| 74 | + os.unlink(file_path) |
| 75 | + |
| 76 | + |
| 77 | +def test_video_from_components_get_duration(video_components): |
| 78 | + """Duration calculated correctly from frame count and frame rate""" |
| 79 | + video = VideoFromComponents(video_components) |
| 80 | + duration = video.get_duration() |
| 81 | + |
| 82 | + expected_duration = 3.0 / 30.0 |
| 83 | + assert duration == pytest.approx(expected_duration) |
| 84 | + |
| 85 | + |
| 86 | +def test_video_from_components_get_duration_different_frame_rates(sample_images): |
| 87 | + """Duration correct for different frame rates including fractional""" |
| 88 | + # Test with 60 fps |
| 89 | + components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60)) |
| 90 | + video_60fps = VideoFromComponents(components_60fps) |
| 91 | + assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0) |
| 92 | + |
| 93 | + # Test with fractional frame rate (23.976fps) |
| 94 | + components_frac = VideoComponents( |
| 95 | + images=sample_images, frame_rate=Fraction(24000, 1001) |
| 96 | + ) |
| 97 | + video_frac = VideoFromComponents(components_frac) |
| 98 | + expected_frac = 3.0 / (24000.0 / 1001.0) |
| 99 | + assert video_frac.get_duration() == pytest.approx(expected_frac) |
| 100 | + |
| 101 | + |
| 102 | +def test_video_from_components_get_duration_empty_video(): |
| 103 | + """Duration is zero for empty video""" |
| 104 | + empty_components = VideoComponents( |
| 105 | + images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30) |
| 106 | + ) |
| 107 | + video = VideoFromComponents(empty_components) |
| 108 | + assert video.get_duration() == 0.0 |
| 109 | + |
| 110 | + |
| 111 | +def test_video_from_components_get_dimensions(video_components): |
| 112 | + """Dimensions returned correctly from image tensor shape""" |
| 113 | + video = VideoFromComponents(video_components) |
| 114 | + width, height = video.get_dimensions() |
| 115 | + assert width == 2 |
| 116 | + assert height == 2 |
| 117 | + |
| 118 | + |
| 119 | +def test_video_from_file_get_duration(simple_video_file): |
| 120 | + """Duration extracted from file metadata""" |
| 121 | + video = VideoFromFile(simple_video_file) |
| 122 | + duration = video.get_duration() |
| 123 | + assert duration == pytest.approx(0.1, abs=0.01) |
| 124 | + |
| 125 | + |
| 126 | +def test_video_from_file_get_dimensions(simple_video_file): |
| 127 | + """Dimensions read from stream without decoding frames""" |
| 128 | + video = VideoFromFile(simple_video_file) |
| 129 | + width, height = video.get_dimensions() |
| 130 | + assert width == 4 |
| 131 | + assert height == 4 |
| 132 | + |
| 133 | + |
| 134 | +def test_video_from_file_bytesio_input(): |
| 135 | + """VideoFromFile works with BytesIO input""" |
| 136 | + buffer = io.BytesIO() |
| 137 | + with av.open(buffer, mode="w", format="mp4") as container: |
| 138 | + stream = container.add_stream("h264", rate=30) |
| 139 | + stream.width = 2 |
| 140 | + stream.height = 2 |
| 141 | + stream.pix_fmt = "yuv420p" |
| 142 | + |
| 143 | + frame = av.VideoFrame.from_ndarray( |
| 144 | + torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24" |
| 145 | + ) |
| 146 | + frame = frame.reformat(format="yuv420p") |
| 147 | + packet = stream.encode(frame) |
| 148 | + container.mux(packet) |
| 149 | + packet = stream.encode(None) |
| 150 | + container.mux(packet) |
| 151 | + |
| 152 | + buffer.seek(0) |
| 153 | + video = VideoFromFile(buffer) |
| 154 | + |
| 155 | + assert video.get_dimensions() == (2, 2) |
| 156 | + assert video.get_duration() == pytest.approx(1 / 30, abs=0.01) |
| 157 | + |
| 158 | + |
| 159 | +def test_video_from_file_invalid_file_error(): |
| 160 | + """InvalidDataError raised for non-video files""" |
| 161 | + with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp: |
| 162 | + tmp.write(b"not a video file") |
| 163 | + tmp.flush() |
| 164 | + tmp_name = tmp.name |
| 165 | + |
| 166 | + try: |
| 167 | + with pytest.raises(InvalidDataError): |
| 168 | + video = VideoFromFile(tmp_name) |
| 169 | + video.get_dimensions() |
| 170 | + finally: |
| 171 | + os.unlink(tmp_name) |
| 172 | + |
| 173 | + |
| 174 | +def test_video_from_file_audio_only_error(): |
| 175 | + """ValueError raised for audio-only files""" |
| 176 | + with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp: |
| 177 | + tmp_name = tmp.name |
| 178 | + |
| 179 | + try: |
| 180 | + with av.open(tmp_name, mode="w") as container: |
| 181 | + stream = container.add_stream("aac", rate=44100) |
| 182 | + stream.sample_rate = 44100 |
| 183 | + stream.format = "fltp" |
| 184 | + |
| 185 | + audio_data = torch.zeros(1, 1024).numpy() |
| 186 | + audio_frame = av.AudioFrame.from_ndarray( |
| 187 | + audio_data, format="fltp", layout="mono" |
| 188 | + ) |
| 189 | + audio_frame.sample_rate = 44100 |
| 190 | + audio_frame.pts = 0 |
| 191 | + packet = stream.encode(audio_frame) |
| 192 | + container.mux(packet) |
| 193 | + |
| 194 | + for packet in stream.encode(None): |
| 195 | + container.mux(packet) |
| 196 | + |
| 197 | + with pytest.raises(ValueError, match="No video stream found"): |
| 198 | + video = VideoFromFile(tmp_name) |
| 199 | + video.get_dimensions() |
| 200 | + finally: |
| 201 | + os.unlink(tmp_name) |
| 202 | + |
| 203 | + |
| 204 | +def test_single_frame_video(): |
| 205 | + """Single frame video has correct duration""" |
| 206 | + components = VideoComponents( |
| 207 | + images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1) |
| 208 | + ) |
| 209 | + video = VideoFromComponents(components) |
| 210 | + assert video.get_duration() == 1.0 |
| 211 | + |
| 212 | + |
| 213 | +@pytest.mark.parametrize( |
| 214 | + "frame_rate,expected_fps", |
| 215 | + [ |
| 216 | + (Fraction(24000, 1001), 24000 / 1001), |
| 217 | + (Fraction(30000, 1001), 30000 / 1001), |
| 218 | + (Fraction(25, 1), 25.0), |
| 219 | + (Fraction(50, 2), 25.0), |
| 220 | + ], |
| 221 | +) |
| 222 | +def test_fractional_frame_rates(frame_rate, expected_fps): |
| 223 | + """Duration calculated correctly for various fractional frame rates""" |
| 224 | + components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate) |
| 225 | + video = VideoFromComponents(components) |
| 226 | + duration = video.get_duration() |
| 227 | + expected_duration = 100.0 / expected_fps |
| 228 | + assert duration == pytest.approx(expected_duration) |
| 229 | + |
| 230 | + |
| 231 | +def test_duration_consistency(video_components): |
| 232 | + """get_duration() consistent with manual calculation from components""" |
| 233 | + video = VideoFromComponents(video_components) |
| 234 | + |
| 235 | + duration = video.get_duration() |
| 236 | + components = video.get_components() |
| 237 | + manual_duration = float(components.images.shape[0] / components.frame_rate) |
| 238 | + |
| 239 | + assert duration == pytest.approx(manual_duration) |
0 commit comments