
Commit 80149a1

Deduplicating media on Context creation (#1153)
1 parent c9bd88c commit 80149a1

File tree: 8 files changed (+183, −19 lines)

packages/paper-qa-docling/src/paperqa_docling/reader.py

Lines changed: 36 additions & 8 deletions
```diff
@@ -1,10 +1,11 @@
 import collections
 import io
+import json
 import os
 from collections.abc import Mapping
 from importlib.metadata import version
 from pathlib import Path
-from typing import Any
+from typing import Any, cast

 import docling
 from docling.datamodel.base_models import ConversionStatus
@@ -169,6 +170,19 @@ def parse_pdf_to_pages(  # noqa: PLR0912
                     f"Didn't yet handle 2+ picture description annotations {annotations}."
                 )

+            media_metadata["info_hashable"] = json.dumps(
+                {
+                    k: (
+                        v
+                        if k != "bbox"
+                        # Enables bbox deduplication based on whole pixels,
+                        # since <1-px differences are just noise
+                        else tuple(round(x) for x in cast(tuple, v))
+                    )
+                    for k, v in media_metadata.items()
+                },
+                sort_keys=True,
+            )
             content[str(page_num)][1].append(
                 ParsedMedia(
                     index=len(content[str(page_num)][1]),
@@ -193,18 +207,32 @@ def parse_pdf_to_pages(  # noqa: PLR0912
             table_image_data.save(img_bytes, format="PNG")
             img_bytes.seek(0)  # Reset pointer before read to avoid empty data

+            media_metadata = {
+                "type": "table",
+                "width": table_image_data.width,
+                "height": table_image_data.height,
+                "bbox": item.prov[0].bbox.as_tuple(),
+                "images_scale": pipeline_options.images_scale,
+            }
+            media_metadata["info_hashable"] = json.dumps(
+                {
+                    k: (
+                        v
+                        if k != "bbox"
+                        # Enables bbox deduplication based on whole pixels,
+                        # since <1-px differences are just noise
+                        else tuple(round(x) for x in cast(tuple, v))
+                    )
+                    for k, v in media_metadata.items()
+                },
+                sort_keys=True,
+            )
             content[str(page_num)][1].append(
                 ParsedMedia(
                     index=len(content[str(page_num)][1]),
                     data=img_bytes.read(),
                     text=item.export_to_markdown(doc),
-                    info={
-                        "type": "table",
-                        "width": table_image_data.width,
-                        "height": table_image_data.height,
-                        "bbox": item.prov[0].bbox.as_tuple(),
-                        "images_scale": pipeline_options.images_scale,
-                    },
+                    info=media_metadata,
                 )
             )
             count_media += 1
```
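
Both hunks build the same `info_hashable` value: the media metadata serialized to canonical JSON (`sort_keys=True`), with `bbox` coordinates rounded to whole pixels first so sub-pixel differences between otherwise-identical media don't defeat deduplication. A minimal standalone sketch of that canonicalization, with made-up metadata values:

```python
import json


def info_hashable(media_metadata: dict) -> str:
    """Canonical JSON key; bbox is rounded to whole pixels before serializing."""
    return json.dumps(
        {
            k: (v if k != "bbox" else tuple(round(x) for x in v))
            for k, v in media_metadata.items()
        },
        sort_keys=True,
    )


# Two bboxes differing by under a pixel yield the same key
a = {"type": "picture", "width": 400, "height": 200, "bbox": (10.2, 20.7, 410.1, 220.6)}
b = {"type": "picture", "width": 400, "height": 200, "bbox": (10.4, 20.9, 409.8, 220.7)}
assert info_hashable(a) == info_hashable(b)
```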

packages/paper-qa-docling/tests/test_paperqa_docling.py

Lines changed: 19 additions & 0 deletions
```diff
@@ -173,6 +173,25 @@ def test_page_range() -> None:
     assert "page_range=(1,20)" in parsed_text_p1_20.metadata.name


+def test_media_deduplication() -> None:
+    parsed_text = parse_pdf_to_pages(STUB_DATA_DIR / "duplicate_media.pdf")
+    assert isinstance(parsed_text.content, dict)
+    assert len(parsed_text.content) == 5, "Expected full PDF read"
+    all_media = [m for _, media in parsed_text.content.values() for m in media]  # type: ignore[misc]
+
+    all_images = [m for m in all_media if m.info.get("type") == "picture"]
+    assert len(all_images) == 5, "Expected each image to be read"
+    assert (
+        len(set(all_images)) <= 2
+    ), "Expected images on all pages beyond 1 to be deduplicated"
+
+    all_tables = [m for m in all_media if m.info.get("type") == "table"]
+    assert len(all_tables) == 5, "Expected each table to be read"
+    assert (
+        len(set(all_tables)) <= 2
+    ), "Expected tables on all pages beyond 1 to be deduplicated"
+
+
 def test_page_size_limit_denial() -> None:
     with pytest.raises(ImpossibleParsingError, match="char limit"):
         parse_pdf_to_pages(STUB_DATA_DIR / "paper.pdf", page_size_limit=10)  # chars
```

src/paperqa/core.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -231,7 +231,8 @@ async def _map_fxn_summary(  # noqa: PLR0912
     # but not spaces, to preserve text alignment
     cleaned_text = text.text.strip("\n")
     if summary_llm_model and prompt_templates:
-        media_text: list[str] = [m.text for m in text.media if m.text]
+        unique_media = list(dict.fromkeys(text.media))  # Preserve order
+        media_text: list[str] = [m.text for m in unique_media if m.text]
         data = {
             "question": question,
             "citation": citation,
@@ -254,8 +255,8 @@ async def _map_fxn_summary(  # noqa: PLR0912
             Message.create_message(
                 text=message_prompt,
                 images=(
-                    [i.to_image_url() for i in text.media]
-                    if text.media
+                    [i.to_image_url() for i in unique_media]
+                    if unique_media
                     else None
                 ),
             ),
```
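
`dict.fromkeys` is the usual order-preserving deduplication idiom: dict keys are unique and insertion-ordered, so the first occurrence of each media item survives in position. It depends on `ParsedMedia` hashing duplicates identically, which the `types.py` change below provides. A quick illustration with strings standing in for media objects:

```python
# First occurrence of each key wins; order of first appearance is kept
media = ["fig-p1", "fig-p2", "fig-p1", "table-p1", "fig-p2"]
assert list(dict.fromkeys(media)) == ["fig-p1", "fig-p2", "table-p1"]
```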

src/paperqa/types.py

Lines changed: 9 additions & 6 deletions
```diff
@@ -8,7 +8,7 @@
 import os
 import re
 import warnings
-from collections.abc import Collection, Iterable, Mapping, Sequence
+from collections.abc import Collection, Hashable, Iterable, Mapping, Sequence
 from copy import deepcopy
 from datetime import UTC, datetime
 from enum import StrEnum
@@ -516,10 +516,14 @@ class ParsedMedia(BaseModel):
         ),
     )

+    def _get_info_hashable(self) -> Hashable:
+        if info_hashable := self.info.get("info_hashable"):
+            return cast(Hashable, info_hashable)
+        # We know info_hashable key isn't present, so no need to filter it
+        return json.dumps(self.info, sort_keys=True)
+
     def __hash__(self) -> int:
-        return hash(
-            (self.index, self.data, self.text, json.dumps(self.info, sort_keys=True))
-        )
+        return hash((self.index, self.data, self.text, self._get_info_hashable()))

     def to_id(self) -> UUID:
         """Convert this media to a UUID4 suitable for a database ID."""
@@ -547,8 +551,7 @@ def __eq__(self, other) -> bool:
             self.index == other.index
             and self.data == other.data
             and self.text == other.text
-            and json.dumps(self.info, sort_keys=True)
-            == json.dumps(other.info, sort_keys=True)
+            and self._get_info_hashable() == other._get_info_hashable()
         )

     def to_image_url(self) -> str:
```
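
Equality and hashing now prefer the parser-supplied `info_hashable` string and fall back to canonical JSON of the whole `info` dict. A rough standalone illustration using a stripped-down stand-in class (hypothetical, not the real Pydantic `ParsedMedia`):

```python
import json
from collections.abc import Hashable


class MediaStub:
    """Hypothetical stand-in mirroring ParsedMedia's hash/eq logic."""

    def __init__(self, index: int, data: bytes, text: str, info: dict):
        self.index, self.data, self.text, self.info = index, data, text, info

    def _get_info_hashable(self) -> Hashable:
        # Prefer the parser's precomputed key, else canonicalize all of info
        if info_hashable := self.info.get("info_hashable"):
            return info_hashable
        return json.dumps(self.info, sort_keys=True)

    def __hash__(self) -> int:
        return hash((self.index, self.data, self.text, self._get_info_hashable()))

    def __eq__(self, other) -> bool:
        return (
            self.index == other.index
            and self.data == other.data
            and self.text == other.text
            and self._get_info_hashable() == other._get_info_hashable()
        )


# Identical payloads with matching precomputed keys collapse in a set,
# which is what list(dict.fromkeys(...)) in core.py relies on
m1 = MediaStub(0, b"png-bytes", "caption", {"info_hashable": '{"bbox": [10, 21]}'})
m2 = MediaStub(0, b"png-bytes", "caption", {"info_hashable": '{"bbox": [10, 21]}'})
assert len({m1, m2}) == 1
```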

tests/duplicate_media_template.md

Lines changed: 71 additions & 0 deletions
```diff
@@ -0,0 +1,71 @@
+# SF Districts in the style of Andy Warhol
+
+<!-- pyml disable-num-lines 5 line-length -->
+
+[//]: # "To generate `stub_data/duplicate_media.pdf` from this:"
+[//]: # "1. `pandoc duplicate_media_template.md --standalone --self-contained -t html -o temp.html`"
+[//]: # "2. `Chromium --headless --disable-gpu --print-to-pdf=stub_data/duplicate_media.pdf --no-pdf-header-footer temp.html`"
+[//]: # "3. `rm temp.html`"
+
+<img src="stub_data/sf_districts.png" alt="Map of SF districts" height="200"/>
+
+Text under image 1.
+
+| Col1  | Col2  |
+| ----- | ----- |
+| Val11 | Val12 |
+| Val21 | Val11 |
+
+Text under table 1.
+
+<div style="page-break-after: always;"></div>
+
+<img src="stub_data/sf_districts.png" alt="Map of SF districts" height="200"/>
+
+Text under image 2.
+
+| Col1  | Col2  |
+| ----- | ----- |
+| Val11 | Val12 |
+| Val21 | Val11 |
+
+Text under table 2.
+
+<div style="page-break-after: always;"></div>
+
+<img src="stub_data/sf_districts.png" alt="Map of SF districts" height="200"/>
+
+Text under image 3.
+
+| Col1  | Col2  |
+| ----- | ----- |
+| Val11 | Val12 |
+| Val21 | Val11 |
+
+Text under table 3.
+
+<div style="page-break-after: always;"></div>
+
+<img src="stub_data/sf_districts.png" alt="Map of SF districts" height="200"/>
+
+Text under image 4.
+
+| Col1  | Col2  |
+| ----- | ----- |
+| Val11 | Val12 |
+| Val21 | Val11 |
+
+Text under table 4.
+
+<div style="page-break-after: always;"></div>
+
+<img src="stub_data/sf_districts.png" alt="Map of SF districts" height="200"/>
+
+Text under image 5.
+
+| Col1  | Col2  |
+| ----- | ----- |
+| Val11 | Val12 |
+| Val21 | Val11 |
+
+Text under table 5.
```
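
Each of the template's five pages embeds the same `sf_districts.png` image and the same two-column table, with the `page-break-after` divs forcing one image/table pair per PDF page, so the deduplication tests get five near-identical copies of each media item to collapse.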
tests/stub_data/duplicate_media.pdf

168 KB; binary file not shown.

tests/test_agents.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -90,11 +90,12 @@ async def test_get_directory_index(
         "year",
     ], "Incorrect fields in index"
     assert not index.changed, "Expected index to not have changes at this point"
-    # bates.txt + empty.txt + flag_day.html + gravity_hill.md + influence.pdf + obama.txt + paper.pdf + pasa.pdf,
+    # bates.txt + empty.txt + flag_day.html + gravity_hill.md + influence.pdf
+    # + obama.txt + paper.pdf + pasa.pdf + duplicate_media.pdf,
     # but empty.txt fails to be added
     path_to_id = await index.index_files
     assert (
-        sum(id_ != FAILED_DOCUMENT_ADD_ID for id_ in path_to_id.values()) == 7
+        sum(id_ != FAILED_DOCUMENT_ADD_ID for id_ in path_to_id.values()) == 8
     ), "Incorrect number of parsed index files"

     with subtests.test(msg="check-txt-query"):
@@ -252,6 +253,7 @@ async def test_getting_manifest(

 EXPECTED_STUB_DATA_FILES = {
     "bates.txt",
+    "duplicate_media.pdf",
     "empty.txt",
     "flag_day.html",
     "gravity_hill.md",
```

tests/test_paperqa.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -2,6 +2,7 @@
 import contextlib
 import csv
 import io
+import json
 import os
 import pathlib
 import pickle
@@ -35,6 +36,7 @@
 )
 from lmi.llms import rate_limited
 from lmi.utils import VCR_DEFAULT_MATCH_ON, validate_image
+from paperqa_docling import parse_pdf_to_pages as docling_parse_pdf_to_pages
 from paperqa_pymupdf import parse_pdf_to_pages as pymupdf_parse_pdf_to_pages
 from paperqa_pypdf import parse_pdf_to_pages as pypdf_parse_pdf_to_pages
 from pytest_subtests import SubTests
@@ -1693,6 +1695,44 @@ async def test_images(stub_data_dir: Path) -> None:
     assert all(bool(c.used_images) for c in contexts_used)  # type: ignore[attr-defined]


+@pytest.mark.asyncio
+async def test_duplicate_media_context_creation(stub_data_dir: Path) -> None:
+    settings = Settings(
+        prompts={"summary_json_system": summary_json_multimodal_system_prompt},
+        parsing={"parse_pdf": docling_parse_pdf_to_pages},
+    )
+
+    docs = Docs()
+    assert await docs.aadd(
+        stub_data_dir / "duplicate_media.pdf",
+        citation="FutureHouse, 2025, Accessed now",  # Skip citation inference
+        title="SF Districts in the style of Andy Warhol",  # Skip title inference
+        settings=settings,
+    )
+    with patch.object(
+        LLMModel, "call_single", side_effect=LLMModel.call_single, autospec=True
+    ) as mock_call_single:
+        session = await docs.aquery(
+            "What districts neighbor the Western Addition?", settings=settings
+        )
+    context_user_msg = mock_call_single.await_args_list[0][1]["messages"][1]
+    assert isinstance(context_user_msg, Message)
+    assert context_user_msg.content
+    content_list = json.loads(context_user_msg.content)
+    assert isinstance(content_list, list)
+    assert (
+        sum("image_url" in x for x in content_list) < 5
+    ), "Expected some deduplication to take place during context creation"
+    assert (
+        sum(
+            district in session.answer
+            for district in ("The Avenues", "Golden Gate", "Civic Center", "Haight")
+        )
+        >= 2
+    ), "Expected at least two neighbors to be matched"
+    assert session.cost > 0
+
+
 @pytest.mark.asyncio
 async def test_images_corrupt(stub_data_dir: Path, caplog) -> None:
     settings = Settings.from_name("fast")
```
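
The `patch.object(..., side_effect=LLMModel.call_single, autospec=True)` construction is a spy rather than a stub: the real method still runs while the mock records every call. A generic sketch of the pattern with a hypothetical class, not paper-qa's API:

```python
from unittest.mock import patch


class Greeter:
    """Hypothetical class standing in for LLMModel."""

    def greet(self, name: str) -> str:
        return f"hello {name}"


with patch.object(Greeter, "greet", side_effect=Greeter.greet, autospec=True) as spy:
    assert Greeter().greet("world") == "hello world"  # Real method still ran

# autospec forwards `self`, so the recorded args are (instance, "world")
assert spy.call_args_list[0].args[1] == "world"
```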
