Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions Tests/test_file_mpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,18 @@ def test_save_all() -> None:
assert_image_similar(im, im_reloaded, 30)

im = Image.new("RGB", (1, 1))
im2 = Image.new("RGB", (1, 1), "#f00")
im_reloaded = roundtrip(im, save_all=True, append_images=[im2])

assert_image_equal(im, im_reloaded)
assert isinstance(im_reloaded, MpoImagePlugin.MpoImageFile)
assert im_reloaded.mpinfo is not None
assert im_reloaded.mpinfo[45056] == b"0100"

im_reloaded.seek(1)
assert_image_similar(im2, im_reloaded, 1)
for colors in (("#f00",), ("#f00", "#0f0")):
append_images = [Image.new("RGB", (1, 1), color) for color in colors]
im_reloaded = roundtrip(im, save_all=True, append_images=append_images)

assert_image_equal(im, im_reloaded)
assert isinstance(im_reloaded, MpoImagePlugin.MpoImageFile)
assert im_reloaded.mpinfo is not None
assert im_reloaded.mpinfo[45056] == b"0100"

for im_expected in append_images:
im_reloaded.seek(im_reloaded.tell() + 1)
assert_image_similar(im_reloaded, im_expected, 1)

# Test that a single frame image will not be saved as an MPO
jpg = roundtrip(im, save_all=True)
Expand Down
13 changes: 9 additions & 4 deletions src/PIL/MpoImagePlugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#
from __future__ import annotations

import itertools
import os
import struct
from typing import IO, Any, cast
Expand Down Expand Up @@ -47,12 +46,18 @@ def _save_all(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None:

mpf_offset = 28
offsets: list[int] = []
for imSequence in itertools.chain([im], append_images):
for im_frame in ImageSequence.Iterator(imSequence):
im_sequences = [im, *append_images]
total = sum(getattr(seq, "n_frames", 1) for seq in im_sequences)
for im_sequence in im_sequences:
for im_frame in ImageSequence.Iterator(im_sequence):
if not offsets:
# APP2 marker
ifd_length = 66 + 16 * total
im_frame.encoderinfo["extra"] = (
b"\xff\xe2" + struct.pack(">H", 6 + 82) + b"MPF\0" + b" " * 82
b"\xff\xe2"
+ struct.pack(">H", 6 + ifd_length)
+ b"MPF\0"
+ b" " * ifd_length
)
exif = im_frame.encoderinfo.get("exif")
if isinstance(exif, Image.Exif):
Expand Down
Loading