Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 165 additions & 90 deletions src/transformers/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from contextlib import redirect_stdout
from dataclasses import dataclass
from io import BytesIO
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import requests
Expand Down Expand Up @@ -126,6 +126,14 @@ class AnnotionFormat(ExplicitEnum):
COCO_PANOPTIC = AnnotationFormat.COCO_PANOPTIC.value


@dataclass
class VideoMetadata:
total_num_frames: int
fps: float
duration: float
video_backend: str


AnnotationType = Dict[str, Union[int, str, List[Dict]]]


Expand Down Expand Up @@ -541,133 +549,165 @@ def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] =
return image


def get_uniform_frame_indices(total_num_frames: int, num_frames: Optional[int] = None):
def default_sample_indices_fn(metadata: VideoMetadata, num_frames=None, fps=None, **kwargs):
"""
Creates a numpy array for uniform sampling of `num_frame` frames from `total_num_frames`
when loading a video.
A default sampling function that replicates the logic used in get_uniform_frame_indices,
while optionally handling `fps` if `num_frames` is not provided.

Args:
total_num_frames (`int`):
Total number of frames that a video has.
metadata (`VideoMetadata`):
`VideoMetadata` object containing metadat about the video, such as "total_num_frames" or "fps".
num_frames (`int`, *optional*):
Number of frames to sample uniformly. If not specified, all frames are sampled.
Number of frames to sample uniformly.
fps (`int`, *optional*):
Desired frames per second. Takes priority over num_frames if both are provided.

Returns:
np.ndarray: np array of frame indices that will be sampled.
`np.ndarray`: Array of frame indices to sample.
"""
total_num_frames = metadata.total_num_frames
video_fps = metadata.fps

# If num_frames is not given but fps is, calculate num_frames from fps
if num_frames is None and fps is not None:
num_frames = int(total_num_frames / video_fps * fps)
if num_frames > total_num_frames:
raise ValueError(
f"When loading the video with fps={fps}, we computed num_frames={num_frames} "
f"which exceeds total_num_frames={total_num_frames}. Check fps or video metadata."
)

if num_frames is not None:
indices = np.arange(0, total_num_frames, total_num_frames / num_frames).astype(int)
indices = np.arange(0, total_num_frames, total_num_frames / num_frames, dtype=int)
else:
indices = np.arange(0, total_num_frames).astype(int)
indices = np.arange(0, total_num_frames, dtype=int)
return indices


def read_video_opencv(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
def read_video_opencv(
video_path: str,
sample_indices_fn: Callable,
**kwargs,
):
"""
Decode the video with open-cv decoder.
Decode a video using the OpenCV backend.

Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`.
If not specified and `fps==None`, all frames are sampled.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
sample_indices_fn (`Callable`):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed.
Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
video = cv2.VideoCapture(video_path)
total_num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
video_fps = video.get(cv2.CAP_PROP_FPS)
if num_frames is None and fps is not None:
num_frames = int(total_num_frames / video_fps * fps)
if num_frames > total_num_frames:
raise ValueError(
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
)
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
duration = total_num_frames / video_fps if video_fps else 0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

random thought, what does it mean to have a duration 0 video?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An error occured within video decoder and it couldn't give us back the duration. Rarely that can happen

metadata = VideoMetadata(
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="opencv"
)
indices = sample_indices_fn(metadata=metadata, **kwargs)

index = 0
frames = []
while video.isOpened():
success, frame = video.read()
if not success:
break
if index in indices:
height, width, channel = frame.shape
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frames.append(frame[0:height, 0:width, 0:channel])
if success:
index += 1
if index >= total_num_frames:
break

video.release()
return np.stack(frames)
metadata.frames_indices = indices
return np.stack(frames), metadata


def read_video_decord(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
def read_video_decord(
video_path: str,
sample_indices_fn: Optional[Callable] = None,
**kwargs,
):
"""
Decode the video with Decord decoder.
Decode a video using the Decord backend.

Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`.
If not specified and `fps==None`, all frames are sampled.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
sample_indices_fn (`Callable`, *optional*):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed.
Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
vr = VideoReader(uri=video_path, ctx=cpu(0)) # decord has problems with gpu
video_fps = vr.get_avg_fps()
total_num_frames = len(vr)
if num_frames is None and fps is not None:
num_frames = int(total_num_frames / video_fps * fps)
if num_frames > total_num_frames:
raise ValueError(
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
)
indices = get_uniform_frame_indices(total_num_frames=total_num_frames, num_frames=num_frames)
duration = total_num_frames / video_fps if video_fps else 0
metadata = VideoMetadata(
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="decord"
)

indices = sample_indices_fn(metadata=metadata, **kwargs)

frames = vr.get_batch(indices).asnumpy()
return frames
metadata.frames_indices = indices
return frames, metadata


def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
def read_video_pyav(
video_path: str,
sample_indices_fn: Callable,
**kwargs,
):
"""
Decode the video with PyAV decoder.

Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`.
If not specified and `fps==None`, all frames are sampled.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
sample_indices_fn (`Callable`, *optional*):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed.
Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
container = av.open(video_path)

total_num_frames = container.streams.video[0].frames
video_fps = container.streams.video[0].average_rate # should we better use `av_guess_frame_rate`?
if num_frames is None and fps is not None:
num_frames = int(total_num_frames / video_fps * fps)
if num_frames > total_num_frames:
raise ValueError(
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
)
indices = get_uniform_frame_indices(total_num_frames, num_frames=num_frames)
duration = total_num_frames / video_fps if video_fps else 0
metadata = VideoMetadata(
total_num_frames=int(total_num_frames), fps=float(video_fps), duration=float(duration), video_backend="pyav"
)
indices = sample_indices_fn(metadata=metadata, **kwargs)

frames = []
container.seek(0)
Expand All @@ -677,48 +717,58 @@ def read_video_pyav(video_path: str, num_frames: Optional[int] = None, fps: Opti
break
if i >= 0 and i in indices:
frames.append(frame)
return np.stack([x.to_ndarray(format="rgb24") for x in frames])

video = np.stack([x.to_ndarray(format="rgb24") for x in frames])
metadata.frames_indices = indices
return video, metadata


def read_video_torchvision(video_path: str, num_frames: Optional[int] = None, fps: Optional[int] = None):
def read_video_torchvision(
video_path: str,
sample_indices_fn: Callable,
**kwargs,
):
"""
Decode the video with torchvision decoder.

Args:
video_path (`str`):
Path to the video file.
num_frames (`int`, *optional*):
Number of frames to sample uniformly. Should be passed only when `fps=None`.
If not specified and `fps==None`, all frames are sampled.
fps (`int`, *optional*):
Number of frames to sample per second. Should be passed only when `num_frames=None`.
If not specified and `num_frames==None`, all frames are sampled.
sample_indices_fn (`Callable`, *optional*):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniform sampling with fps is performed.
Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

Returns:
np.ndarray: np array of decoded frames of shape (num_frames, height, width, 3).
Tuple[`np.array`, `VideoMetadata`]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- `VideoMetadata` object.
"""
video, _, info = torchvision_io.read_video(
video_path,
start_pts=0.0,
end_pts=None,
pts_unit="sec",
output_format="TCHW",
output_format="THWC",
)
video_fps = info["video_fps"]
total_num_frames = video.size(0) - 1
if num_frames is None and fps is not None:
num_frames = int(total_num_frames / video_fps * fps)
if num_frames > total_num_frames:
raise ValueError(
f"When loading the video with fps={fps}, we identified that num_frames ({num_frames}) > total_frames ({total_num_frames}) ."
f"Make sure that fps of a video is less than the requested fps for loading. Detected video_fps={video_fps}"
)
total_num_frames = video.size(0)
duration = total_num_frames / video_fps if video_fps else 0
metadata = VideoMetadata(
total_num_frames=int(total_num_frames),
fps=float(video_fps),
duration=float(duration),
video_backend="torchvision",
)

if num_frames is not None:
idx = torch.linspace(0, video.size(0) - 1, num_frames, dtype=torch.int64)
return video[idx]
indices = sample_indices_fn(metadata=metadata, **kwargs)

return video
video = video[indices].contiguous().numpy()
metadata.frames_indices = indices
return video, metadata


VIDEO_DECODERS = {
Expand All @@ -734,6 +784,8 @@ def load_video(
num_frames: Optional[int] = None,
fps: Optional[int] = None,
backend: str = "opencv",
sample_indices_fn: Optional[Callable] = None,
**kwargs,
) -> np.array:
"""
Loads `video` to a numpy array.
Expand All @@ -748,13 +800,36 @@ def load_video(
If not specified and `num_frames==None`, all frames are sampled.
backend (`str`, *optional*, defaults to `"opencv"`):
The backend to use when loading the video. Can be any of ["decord", "pyav", "opencv", "torchvision"]. Defaults to "opencv".
sample_indices_fn (`Callable`, *optional*):
A callable function that will return indices at which the video should be sampled. If the video has to be loaded using
by a different sampling technique than provided by `num_frames` or `fps` arguments, one should provide their own `sample_indices_fn`.
If not provided, simple uniformt sampling with fps is performed, otherwise `sample_indices_fn` has priority over other args.
The function expects at input the all args along with all kwargs passed to `load_video` and should output valid
indices at which the video should be sampled. For example:

Example:
def sample_indices_fn(metadata, **kwargs):
return np.linspace(0, metadata.total_num_frames - 1, num_frames, dtype=int)

Returns:
`np.array`: A numpy array of shape (num_frames, channels, height, width).
Tuple[`np.array`, Dict]: A tuple containing:
- Numpy array of frames in RGB (shape: [num_frames, height, width, 3]).
- Metadata dictionary.
"""

if fps is not None and num_frames is not None:
raise ValueError("`num_frames` and `fps` are mutually exclusive arguments, please use only one!")
# If `sample_indices_fn` is given, we can accept any args as those might be needed by custom `sample_indices_fn`
if fps is not None and num_frames is not None and sample_indices_fn is None:
raise ValueError(
"`num_frames`, `fps`, and `sample_indices_fn` are mutually exclusive arguments, please use only one!"
)

# If user didn't pass a sampling function, create one on the fly with default logic
if sample_indices_fn is None:

def sample_indices_fn_func(metadata, **fn_kwargs):
return default_sample_indices_fn(metadata, num_frames=num_frames, fps=fps, **fn_kwargs)

sample_indices_fn = sample_indices_fn_func

if video.startswith("https://www.youtube.com") or video.startswith("http://www.youtube.com"):
if not is_yt_dlp_available():
Expand Down Expand Up @@ -796,8 +871,8 @@ def load_video(
)

video_decoder = VIDEO_DECODERS[backend]
video = video_decoder(file_obj, num_frames=num_frames, fps=fps)
return video
video, metadata = video_decoder(file_obj, sample_indices_fn, **kwargs)
return video, metadata


def load_images(
Expand Down
Loading