get graph service to work with images #112


Open
wants to merge 1 commit into base: main
6 changes: 3 additions & 3 deletions core/services/document_service.py
@@ -366,7 +366,7 @@ async def batch_retrieve_chunks(
auth: AuthContext,
folder_name: Optional[str] = None,
end_user_id: Optional[str] = None,
use_colpali: Optional[bool] = None,
retrieve_images: Optional[bool] = None,
Review comment:
Parameter 'retrieve_images' creates an inconsistency with the similar parameter 'use_colpali' used in other methods.
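A possible resolution, as a sketch only (the deprecation shim is my assumption, not part of this PR): keep use_colpali as a deprecated alias while call sites migrate.

# Sketch: accept the old keyword as a deprecated alias of retrieve_images.
import warnings

async def batch_retrieve_chunks(
    self,
    chunk_ids,
    auth,
    folder_name=None,
    end_user_id=None,
    retrieve_images=None,
    use_colpali=None,  # deprecated alias, kept for not-yet-migrated callers
):
    if use_colpali is not None:
        warnings.warn(
            "use_colpali is deprecated; pass retrieve_images instead",
            DeprecationWarning,
        )
        if retrieve_images is None:
            retrieve_images = use_colpali
    ...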

) -> List[ChunkResult]:
"""
Retrieve specific chunks by their document ID and chunk number in a single batch operation.
@@ -376,7 +376,7 @@ async def batch_retrieve_chunks(
auth: Authentication context
folder_name: Optional folder to scope the operation to
end_user_id: Optional end-user ID to scope the operation to
use_colpali: Whether to use colpali multimodal features for image chunks
retrieve_images: Whether to use colpali multimodal features for image chunks

Returns:
List of ChunkResult objects
@@ -404,7 +404,7 @@ async def batch_retrieve_chunks(
retrieval_tasks = [self.vector_store.get_chunks_by_id(chunk_identifiers)]

# Add colpali vector store task if needed
if use_colpali and self.colpali_vector_store:
if retrieve_images and self.colpali_vector_store:
logger.info("Preparing to retrieve chunks from both regular and colpali vector stores")
retrieval_tasks.append(self.colpali_vector_store.get_chunks_by_id(chunk_identifiers))

121 changes: 77 additions & 44 deletions core/services/graph_service.py
@@ -530,8 +530,12 @@ async def _process_documents_for_entities(
for i, _ in enumerate(doc.chunk_ids)
]

# Batch retrieve chunks
chunks = await document_service.batch_retrieve_chunks(chunk_sources, auth)
# Batch retrieve chunks, including image chunks when available
chunks = await document_service.batch_retrieve_chunks(
chunk_sources,
auth,
retrieve_images=True,
)
logger.info(f"Retrieved {len(chunks)} chunks for processing")

# Process each chunk individually
@@ -545,7 +549,7 @@ async def _process_documents_for_entities(

# Extract entities and relationships from the chunk
chunk_entities, chunk_relationships = await self.extract_entities_from_text(
chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides
chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides, override_is_image=chunk.metadata.get("is_image", False)
Review comment (Collaborator):
We should change the method name to extract_entities, since it no longer handles only text.
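A minimal sketch of that rename (the backward-compatible alias is my assumption; the PR does not include it):

# Sketch: rename the method, keeping a thin alias so existing call sites don't break.
async def extract_entities(self, content, doc_id, chunk_number,
                           prompt_overrides=None, override_is_image=None):
    ...  # body unchanged from extract_entities_from_text

async def extract_entities_from_text(self, *args, **kwargs):
    # Deprecated alias for extract_entities (assumption, not in the PR).
    return await self.extract_entities(*args, **kwargs)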

)

# Store all initially extracted entities to track their IDs
@@ -695,6 +699,7 @@ async def extract_entities_from_text(
doc_id: str,
chunk_number: int,
prompt_overrides: Optional[EntityExtractionPromptOverride] = None,
override_is_image: Optional[bool] = None,
) -> Tuple[List[Entity], List[Relationship]]:
"""
Extract entities and relationships from text content using the LLM.
@@ -703,17 +708,43 @@ async def extract_entities_from_text(
content: Text content to process
doc_id: Document ID
chunk_number: Chunk number within the document
prompt_overrides: Optional EntityExtractionPromptOverride with customizations for prompts
override_is_image: Optional flag to override image detection based on metadata

Returns:
Tuple of (entities, relationships)
"""
settings = get_settings()

# Limit text length to avoid token limits
content_limited = content[: min(len(content), 5000)]

# We'll use the Pydantic model directly when calling litellm
# No need to generate JSON schema separately
# Determine if this chunk is an image based solely on metadata flag
is_image = override_is_image if override_is_image is not None else False
# For images, send full base64 content; for text, truncate to limit
if is_image:
content_limited = content
else:
content_limited = content[: min(len(content), 5000)]
Review comment (Collaborator):
We could remove this truncation.


# Build system message, differentiating between text and image inputs
if is_image:
# For images, instruct comprehensive visual and text interpretation
system_content = (
"You are a multi-modal extraction assistant for images. The input is a base64-encoded PNG of a document page. "
"Perform OCR to extract the text, and visually interpret all layout elements: tables, diagrams, charts, form fields, headings, and graphical icons. "
"Identify entities from both text and visuals, infer relationships depicted (e.g., flows, hierarchies, links), and draw logical conclusions from the combined information. "
"For entities, include their label and type (e.g., PERSON, ORGANIZATION, LOCATION, CONCEPT). "
"For relationships, output a JSON list of objects with source, target, and relationship fields. "
"Respond only with valid JSON representing the extracted entities and relationships."
)
else:
# For text, use standard extraction instructions
system_content = (
"You are an entity extraction and relationship extraction assistant. "
Review comment (Collaborator):
We should keep the original system prompt pls! I tested it and it works well. Also, maybe for images we could just append a note saying the context is an image. That should do the job.
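A sketch of that suggestion (ORIGINAL_SYSTEM_PROMPT stands for the prompt this PR replaced; the wording of the appended note is my assumption):

# Sketch: reuse the original system prompt verbatim and, for image chunks,
# append a short note that the context is an image.
system_content = ORIGINAL_SYSTEM_PROMPT  # hypothetical constant holding the old prompt
if is_image:
    system_content += (
        " The input is a base64-encoded image of a document page; "
        "extract entities and relationships from its text and visuals."
    )
system_message = {"role": "system", "content": system_content}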

"Extract entities and their relationships from the input text precisely and thoroughly. "
"For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
"For relationships, use a simple JSON format with source, target, and relationship fields. "
"Respond directly in valid JSON format, without any additional text or explanations."
)
system_message = {"role": "system", "content": system_content}

# Get entity extraction overrides if available
extraction_overrides = {}
@@ -744,45 +775,47 @@ async def extract_entities_from_text(
f"{json.dumps(examples_json, indent=2)}\n```\n"
)

# Modify the system message to handle properties as a string that will be parsed later
system_message = {
"role": "system",
"content": (
"You are an entity extraction and relationship extraction assistant. Extract entities and "
"their relationships from text precisely and thoroughly, extract as many entities and "
"relationships as possible. "
"For entities, include entity label and type (some examples: PERSON, ORGANIZATION, LOCATION, "
"CONCEPT, etc.). If the user has given examples, use those, these are just suggestions"
"For relationships, use a simple format with source, target, and relationship fields. "
"Be very through, there are many relationships that are not obvious"
"IMPORTANT: The source and target fields must be simple strings representing "
"entity labels. For example: "
"if you extract entities 'Entity A' and 'Entity B', a relationship would have source: 'Entity A', "
"target: 'Entity B', relationship: 'relates to'. "
"Respond directly in json format, without any additional text or explanations. "
),
}
# Construct user message content as a list of content blocks
user_message_content = []

if is_image:
# For images, add the image as a content block
user_message_content.append({"type": "image_url", "image_url": {"url": content_limited}})
Review comment:
Base64 image content is incorrectly passed as a URL instead of a data URI when constructing the message for the LLM
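If the stored chunk content is raw base64 rather than a full data URI (an assumption about the storage format), the fix might look like this:

# Sketch: prefix raw base64 with a data-URI scheme before sending it to the LLM.
# Assumes PNG content; the real MIME type would presumably come from chunk metadata.
image_url = (
    content_limited
    if content_limited.startswith("data:")
    else f"data:image/png;base64,{content_limited}"
)
user_message_content.append({"type": "image_url", "image_url": {"url": image_url}})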

# Add image-specific instructions as a text block
image_instructions = (
"Extract named entities and their relationships from the following image. "
"Perform OCR and visually interpret all layout elements. "
"Return your response as valid JSON.\n\n"
)
user_message_content.append({"type": "text", "text": image_instructions})
else:
# For text, add the text content block
text_instructions = (
"Extract named entities and their relationships from the following text. "
"For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
"For relationships, specify the source entity, target entity, and the relationship between them. "
'Sample relationship format: {"source": "Entity A", "target": "Entity B", '
'"relationship": "works for"}\n\n'
"Return your response as valid JSON:\n\n"
)
user_message_content.append({"type": "text", "text": text_instructions})
user_message_content.append({"type": "text", "text": content_limited})

# Use custom prompt if provided, otherwise use default
# Add examples if provided and not using a custom prompt template that handles them
if examples_str and not custom_prompt:
user_message_content.append({"type": "text", "text": examples_str})

# Use custom prompt template if provided
if custom_prompt:
user_message = {
"role": "user",
"content": custom_prompt.format(content=content_limited, examples=examples_str),
}
# If a custom prompt is provided, it takes precedence and formats the content
# We assume the custom prompt handles incorporating text and image content appropriately.
# For simplicity, we'll just pass the original content_limited and examples_str
# to the custom prompt formatter. The user is responsible for formatting in the template.
formatted_user_text = custom_prompt.format(content=content_limited, examples=examples_str)
Review comment:
Custom prompt handling doesn't account for multimodal image content

user_message = {"role": "user", "content": formatted_user_text}
Review comment (Collaborator):
Even for custom prompts, we should add the image!!!
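Both of these comments point at the same gap. A sketch of sending the formatted custom-prompt text together with the image block (reusing the image_url variable from the data-URI sketch above, itself an assumption):

# Sketch: keep the custom prompt's formatted text, but still attach the image
# block for image chunks so the model actually sees the page.
formatted_user_text = custom_prompt.format(content=content_limited, examples=examples_str)
if is_image:
    user_message = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": image_url}},
            {"type": "text", "text": formatted_user_text},
        ],
    }
else:
    user_message = {"role": "user", "content": formatted_user_text}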

else:
user_message = {
"role": "user",
"content": (
"Extract named entities and their relationships from the following text. "
"For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). "
"For relationships, specify the source entity, target entity, and the relationship between them. "
"The source and target must be simple strings matching the entity labels, not objects. "
f"{examples_str}"
'Sample relationship format: {"source": "Entity A", "target": "Entity B", '
'"relationship": "works for"}\n\n'
"Return your response as valid JSON:\n\n" + content_limited
),
}
# Otherwise, use the constructed list of content blocks
user_message = {"role": "user", "content": user_message_content}

# Get the model configuration from registered_models
model_config = settings.REGISTERED_MODELS.get(settings.GRAPH_MODEL, {})
72 changes: 72 additions & 0 deletions core/tests/unit/test_graph_service_image_extraction.py
@@ -0,0 +1,72 @@
import asyncio
import pytest

from core.services.graph_service import GraphService, ExtractionResult, EntityExtraction, RelationshipExtraction

# Dummy settings for testing
class DummySettings:
GRAPH_MODEL = "dummy"
REGISTERED_MODELS = {"dummy": {"model_name": "test-model"}}

# Dummy instructor client to capture messages and return a simple ExtractionResult
class DummyClient:
def __init__(self):
self.captured_messages = None
self.chat = self
self.completions = self

async def create(self, model, messages, response_model, **kwargs):
self.captured_messages = messages
return ExtractionResult(
entities=[EntityExtraction(label="TestEntity", type="CONCEPT")],
relationships=[RelationshipExtraction(source="TestEntity", target="TestEntity", relationship="related_to")]
)

@pytest.fixture(autouse=True)
def patch_settings_and_instructor(monkeypatch):
# Patch get_settings to return DummySettings
import core.services.graph_service as gs_mod
monkeypatch.setattr(gs_mod, "get_settings", lambda: DummySettings())
# Prepare dummy instructor and litellm modules for dynamic import
import sys, types
dummy_client = DummyClient()
dummy_instructor = types.SimpleNamespace(
from_litellm=lambda ac, mode: dummy_client,
Mode=types.SimpleNamespace(JSON=None)
)
dummy_litellm = types.SimpleNamespace(acompletion='dummy')
# Insert into sys.modules so that import instructor/litellm picks up our dummy
monkeypatch.setitem(sys.modules, 'instructor', dummy_instructor)
monkeypatch.setitem(sys.modules, 'litellm', dummy_litellm)
return dummy_client

@pytest.mark.parametrize("content,expected_system_prefix", [
('data:image/png;base64,AAA', 'You are an entity extraction and relationship extraction assistant for images.'),
('Plain text content.', 'You are an entity extraction and relationship extraction assistant.'),
])
def test_system_prompt_for_image_vs_text(patch_settings_and_instructor, content, expected_system_prefix):
service = GraphService(db=None, embedding_model=None, completion_model=None)
entities, relationships = asyncio.run(
service.extract_entities_from_text(content, doc_id="doc1", chunk_number=0)
)
dummy = patch_settings_and_instructor
assert dummy.captured_messages is not None
system_msg, _ = dummy.captured_messages
assert system_msg['content'].startswith(expected_system_prefix)
assert entities and entities[0].label == "TestEntity"
assert relationships and relationships[0].type == "related_to"
Review comment:
Attribute error: the test accesses 'type', but RelationshipExtraction uses the 'relationship' attribute.
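The fix would presumably be to assert on the field the model actually defines (the same change applies to the final assertion of the second test):

assert relationships and relationships[0].relationship == "related_to"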


@pytest.mark.parametrize("content,expected_user_prefix", [
('data:image/png;base64,BBB', 'Extract named entities and their relationships from the following base64-encoded image.'),
('Hello world', 'Extract named entities and their relationships from the following text.'),
])
def test_user_prompt_for_image_vs_text(patch_settings_and_instructor, content, expected_user_prefix):
service = GraphService(db=None, embedding_model=None, completion_model=None)
entities, relationships = asyncio.run(
service.extract_entities_from_text(content, doc_id="doc2", chunk_number=1)
)
dummy = patch_settings_and_instructor
_, user_msg = dummy.captured_messages
assert user_msg['content'].startswith(expected_user_prefix)
# Validate stub relationship
assert relationships and relationships[0].type == "related_to"