Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 29 additions & 25 deletions mediapy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@
from __future__ import annotations

__docformat__ = 'google'
__version__ = '1.1.9'
__version__ = '1.2.0'
__version_info__ = tuple(int(num) for num in __version__.split('.'))

import base64
Expand All @@ -116,7 +116,6 @@
import itertools
import math
import numbers
import os
import pathlib
import re
import shlex
Expand All @@ -129,12 +128,15 @@
import urllib.request

import IPython.display
import matplotlib
import matplotlib.pyplot
import numpy as np
import numpy.typing as npt
import PIL.Image
import PIL.ImageOps

if typing.TYPE_CHECKING:
import os

if not hasattr(PIL.Image, 'Resampling'): # Allow Pillow<9.0.
PIL.Image.Resampling = PIL.Image

Expand Down Expand Up @@ -187,7 +189,7 @@

_IPYTHON_HTML_SIZE_LIMIT = 20_000_000
_T = typing.TypeVar('_T')
_Path = typing.Union[str, os.PathLike]
_Path = typing.Union[str, 'os.PathLike[str]']

_IMAGE_COMPARISON_HTML = """\
<script
Expand Down Expand Up @@ -504,10 +506,10 @@ def generate_image(image_index: int) -> _NDArray:
"""Returns a video frame image."""
image = color_ramp(shape, dtype=dtype)
yx = np.moveaxis(np.indices(shape), 0, -1)
center = (shape[0] * 0.6, shape[1] * (image_index + 0.5) / num_images)
center = shape[0] * 0.6, shape[1] * (image_index + 0.5) / num_images
radius_squared = (min(shape) * 0.1) ** 2
inside = np.sum((yx - center) ** 2, axis=-1) < radius_squared
white_circle_color = (1.0, 1.0, 1.0)
white_circle_color = 1.0, 1.0, 1.0
if np.issubdtype(dtype, np.unsignedinteger):
white_circle_color = to_type([white_circle_color], dtype)[0]
image[inside] = white_circle_color
Expand Down Expand Up @@ -837,7 +839,7 @@ def to_rgb(
a = (a.astype('float') - vmin) / (vmax - vmin + np.finfo(float).eps)
if isinstance(cmap, str):
if hasattr(matplotlib, 'colormaps'):
rgb_from_scalar = matplotlib.colormaps[cmap] # Newer version.
rgb_from_scalar: Any = matplotlib.colormaps[cmap] # Newer version.
else:
rgb_from_scalar = matplotlib.pyplot.cm.get_cmap(cmap)
else:
Expand Down Expand Up @@ -1234,19 +1236,20 @@ def _get_video_metadata(path: _Path) -> VideoMetadata:
fps = 10
else:
raise RuntimeError(f'Unable to parse video framerate in line {line}')
if (
(match := re.fullmatch(r'\s*rotate\s*:\s*(\d+)', line)) or
(match := re.fullmatch(r'\s*.*rotation of -?(\d+)\s*.*\sdegrees', line))
):
if match := re.fullmatch(r'\s*rotate\s*:\s*(\d+)', line):
rotation = int(match.group(1))
if match := re.fullmatch(r'.*rotation of (-?\d+).*\sdegrees', line):
rotation = int(match.group(1))
if not num_images:
raise RuntimeError(f'Unable to find frames in video: {err}')
if not width:
raise RuntimeError(f'Unable to parse video header: {err}')
# By default, ffmpeg enables "-autorotate"; we just fix the dimensions.
if rotation in (90, 270):
if rotation in (90, 270, -90, -270):
width, height = height, width
shape = (height, width)
assert height is not None and width is not None
shape = height, width
assert fps is not None
return VideoMetadata(num_images, shape, fps, bps)


Expand Down Expand Up @@ -1388,7 +1391,8 @@ def read(self) -> _NDArray | None:
array with 3 color channels, except for format 'gray' which is 2D.
"""
assert self._proc, 'Error: reading from an already closed context.'
assert (stdout := self._proc.stdout) is not None
stdout = self._proc.stdout
assert stdout is not None
data = stdout.read(self._num_bytes_per_image)
if not data: # Due to either end-of-file or subprocess error.
self.close() # Raises exception if subprocess had error.
Expand Down Expand Up @@ -1427,7 +1431,7 @@ def close(self) -> None:
class VideoWriter(_VideoIO):
"""Context to write a compressed video.

>>> shape = (480, 640)
>>> shape = 480, 640
>>> with VideoWriter('/tmp/v.mp4', shape, fps=60) as writer:
... for image in moving_circle(shape, num_images=60):
... writer.add_image(image)
Expand Down Expand Up @@ -1644,7 +1648,8 @@ def add_image(self, image: _NDArray) -> None:
if self.input_format == 'yuv': # Convert from per-pixel YUV to planar YUV.
image = np.moveaxis(image, 2, 0)
data = image.tobytes()
assert (stdin := self._proc.stdin) is not None
stdin = self._proc.stdin
assert stdin is not None
if stdin.write(data) != len(data):
self._proc.wait()
stderr = self._proc.stderr
Expand All @@ -1655,8 +1660,9 @@ def add_image(self, image: _NDArray) -> None:
def close(self) -> None:
"""Finishes writing the video. (Called automatically at end of context.)"""
if self._popen:
assert self._proc and self._proc.stdin and self._proc.stderr
assert (stdin := self._proc.stdin) is not None
assert self._proc, 'Error: closing an already closed context.'
stdin = self._proc.stdin
assert stdin is not None
stdin.close()
if self._proc.wait():
stderr = self._proc.stderr
Expand Down Expand Up @@ -1910,9 +1916,9 @@ def show_videos(
'Cannot have both a video dictionary and a titles parameter.'
)
list_titles = list(videos.keys())
list_videos: list[Iterable[_NDArray]] = list(videos.values())
list_videos = list(videos.values())
else:
list_videos = list(videos)
list_videos = list(typing.cast('Iterable[_NDArray]', videos))
list_titles = [None] * len(list_videos) if titles is None else list(titles)
if len(list_videos) != len(list_titles):
raise ValueError(
Expand All @@ -1926,10 +1932,8 @@ def show_videos(
for video, title in zip(list_videos, list_titles):
metadata: VideoMetadata | None = getattr(video, 'metadata', None)
first_image, video = _peek_first(video)
w, h = _get_width_height(
width, height, first_image.shape[:2] # type: ignore[arg-type]
)
if downsample and (w < first_image.shape[1] or h < first_image.shape[0]): # pytype: disable=attribute-error
w, h = _get_width_height(width, height, first_image.shape[:2])
if downsample and (w < first_image.shape[1] or h < first_image.shape[0]):
# Not resize_video() because each image may have different depth and type.
video = [resize_image(image, (h, w)) for image in video]
first_image = video[0]
Expand All @@ -1942,7 +1946,7 @@ def show_videos(
with _open(path, mode='wb') as f:
f.write(data)
if codec == 'gif':
pixelated = h > first_image.shape[0] or w > first_image.shape[1] # pytype: disable=attribute-error
pixelated = h > first_image.shape[0] or w > first_image.shape[1]
html_string = html_from_compressed_image(
data, w, h, title=title, fmt='gif', pixelated=pixelated, **kwargs
)
Expand Down
16 changes: 9 additions & 7 deletions mediapy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def test_to_uint8(self):
)

def test_color_ramp_float(self):
shape = (2, 3)
shape = 2, 3
image = media.color_ramp(shape=shape)
self.assert_all_equal(image.shape[:2], shape)
self.assert_all_close(
Expand All @@ -242,7 +242,7 @@ def test_color_ramp_float(self):
)

def test_color_ramp_uint8(self):
shape = (1, 3)
shape = 1, 3
image = media.color_ramp(shape=shape, dtype=np.uint8)
self.assert_all_equal(image.shape[:2], shape)
expected = [[
Expand Down Expand Up @@ -565,7 +565,7 @@ def test_compare_images(self):

@parameterized.parameters(False, True)
def test_video_non_streaming_write_read_roundtrip(self, use_generator):
shape = (240, 320)
shape = 240, 320
num_images = 10
fps = 40
qp = 20
Expand All @@ -577,14 +577,15 @@ def test_video_non_streaming_write_read_roundtrip(self, use_generator):
tmp_path = pathlib.Path(directory_name) / 'test.mp4'
media.write_video(tmp_path, video, fps=fps, qp=qp)
new_video = media.read_video(tmp_path)
assert new_video.metadata
self.assertEqual(new_video.metadata.num_images, num_images)
self.assertEqual(new_video.metadata.shape, shape)
self.assertEqual(new_video.metadata.fps, fps)
self.assertGreater(new_video.metadata.bps, 1_000)
self._check_similar(original_video, new_video, 3.0)

def test_video_streaming_write_read_roundtrip(self):
shape = (62, 744)
shape = 62, 744
num_images = 20
fps = 120
bps = 400_000
Expand All @@ -610,7 +611,7 @@ def test_video_streaming_write_read_roundtrip(self):
self._check_similar(images[index], new_image, 7.0, f'index={index}')

def test_video_streaming_read_write(self):
shape = (400, 400)
shape = 400, 400
num_images = 4
fps = 25
bps = 40_000_000
Expand All @@ -632,14 +633,15 @@ def test_video_streaming_read_write(self):
writer.add_image(image)

new_video = media.read_video(path2)
assert new_video.metadata
self.assertEqual(new_video.metadata.num_images, num_images)
self.assertEqual(new_video.metadata.shape, shape)
self.assertEqual(new_video.metadata.fps, fps)
self.assertGreater(new_video.metadata.bps, 1_000)
self._check_similar(video, new_video, 3.0)

def test_video_read_write_10bit(self):
shape = (256, 256)
shape = 256, 256
num_images = 4
fps = 60
bps = 40_000_000
Expand Down Expand Up @@ -690,7 +692,7 @@ def test_compress_decompress_video_roundtrip(self):
self._check_similar(video, new_video, max_rms=8.0)

def test_html_from_compressed_video(self):
shape = (240, 320)
shape = 240, 320
video = media.moving_circle(shape, 10)
text = media.html_from_compressed_video(
media.compress_video(video), shape[1], shape[0]
Expand Down
1 change: 0 additions & 1 deletion pdoc_files/make.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#!/usr/bin/env python3
"""Create HTML documentation from the source code using `pdoc`."""
# Note: Invoke this from the parent directory as "python3 pdoc_files/make.py".

Expand Down
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ dev = [
"pylint>=2.6.0",
"pytest",
"pytest-xdist",
"pytype",
]

[build-system]
Expand Down Expand Up @@ -75,8 +76,8 @@ extend-exclude = "\\.ipynb"
disable = [
"unspecified-encoding", "line-too-long", "too-many-lines",
"too-few-public-methods", "too-many-locals", "too-many-instance-attributes",
"too-many-branches", "too-many-statements", "using-constant-test",
"wrong-import-order", "use-dict-literal",
"too-many-branches", "too-many-statements", "too-many-arguments",
"using-constant-test", "wrong-import-order", "use-dict-literal",
]
reports = false
score = false
Expand All @@ -88,3 +89,9 @@ good-names-rgxs = "^[a-z][a-z0-9]?|[A-Z]([A-Z_]*[A-Z])?$"
[tool.pylint.format]
indent-string = " "
expected-line-ending-format = "LF"

[tool.pytype]
keep_going = true
strict_none_binding = true
use_enum_overlay = true
use_fiddle_overlay = true