Skip to content

Commit ee543e4

Browse files
authored
feat(ocr): add async batch annotation for table extraction from PDFs (#301)
1 parent 9370820 commit ee543e4

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#!/usr/bin/env python
2+
import asyncio
3+
import json
4+
import os
5+
from typing import List
6+
7+
import httpx
8+
from pydantic import BaseModel, Field
9+
10+
from mistralai import Mistral
11+
from mistralai.extra import response_format_from_pydantic_model
12+
from mistralai.models import File
13+
14+
SAMPLE_PDF_URL = "https://arxiv.org/pdf/2401.04088"
15+
16+
17+
class Table(BaseModel):
18+
name: str = Field(description="The name or title of the table")
19+
20+
21+
class TableExtraction(BaseModel):
22+
tables: List[Table] = Field(description="List of tables found in the document")
23+
24+
25+
def create_ocr_batch_request(custom_id: str, document_url: str) -> dict:
26+
"""Batch requests require custom_id and body wrapper."""
27+
response_format = response_format_from_pydantic_model(TableExtraction)
28+
return {
29+
"custom_id": custom_id,
30+
"body": {
31+
"document": {"type": "document_url", "document_url": document_url},
32+
"document_annotation_format": response_format.model_dump(
33+
by_alias=True, exclude_none=True
34+
),
35+
"pages": [0, 1, 2, 3, 4, 5, 6, 7],
36+
"include_image_base64": False,
37+
},
38+
}
39+
40+
41+
async def main():
42+
client = Mistral(api_key=os.environ["MISTRAL_API_KEY"])
43+
44+
document_urls = [SAMPLE_PDF_URL]
45+
46+
batch_requests = [
47+
json.dumps(create_ocr_batch_request(custom_id=str(i), document_url=url))
48+
for i, url in enumerate(document_urls)
49+
]
50+
batch_content = "\n".join(batch_requests)
51+
52+
print("Uploading batch file...")
53+
batch_file = await client.files.upload_async(
54+
file=File(file_name="ocr_batch.jsonl", content=batch_content.encode()),
55+
purpose="batch",
56+
)
57+
print(f"Batch file uploaded: {batch_file.id}")
58+
59+
print("Creating batch job...")
60+
created_job = await client.batch.jobs.create_async(
61+
model="mistral-ocr-latest",
62+
input_files=[batch_file.id],
63+
endpoint="/v1/ocr",
64+
)
65+
print(f"Batch job created: {created_job.id}")
66+
67+
print("Waiting for job completion...")
68+
job = await client.batch.jobs.get_async(job_id=created_job.id)
69+
while job.status not in ["SUCCESS", "FAILED", "CANCELLED"]:
70+
print(f"Status: {job.status}")
71+
await asyncio.sleep(5)
72+
job = await client.batch.jobs.get_async(job_id=created_job.id)
73+
74+
print(f"Job status: {job.status}")
75+
76+
async with httpx.AsyncClient() as http_client:
77+
if job.output_file:
78+
signed_url = await client.files.get_signed_url_async(
79+
file_id=job.output_file
80+
)
81+
response = await http_client.get(signed_url.url)
82+
for line in response.content.decode().strip().split("\n"):
83+
result = json.loads(line)
84+
annotation = result["response"]["body"].get("document_annotation")
85+
if annotation:
86+
tables = TableExtraction.model_validate_json(annotation)
87+
for table in tables.tables:
88+
print(table.name)
89+
90+
if job.error_file:
91+
signed_url = await client.files.get_signed_url_async(file_id=job.error_file)
92+
response = await http_client.get(signed_url.url)
93+
print("Errors:", response.content.decode())
94+
95+
print("\nCleaning up...")
96+
await client.files.delete_async(file_id=batch_file.id)
97+
print("Done!")
98+
99+
100+
if __name__ == "__main__":
101+
asyncio.run(main())

0 commit comments

Comments
 (0)