-
Notifications
You must be signed in to change notification settings - Fork 221
get graph service to work with images #112
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -530,8 +530,12 @@ async def _process_documents_for_entities( | |
for i, _ in enumerate(doc.chunk_ids) | ||
] | ||
|
||
# Batch retrieve chunks | ||
chunks = await document_service.batch_retrieve_chunks(chunk_sources, auth) | ||
# Batch retrieve chunks, including image chunks when available | ||
chunks = await document_service.batch_retrieve_chunks( | ||
chunk_sources, | ||
auth, | ||
retrieve_images=True, | ||
) | ||
logger.info(f"Retrieved {len(chunks)} chunks for processing") | ||
|
||
# Process each chunk individually | ||
|
@@ -545,7 +549,7 @@ async def _process_documents_for_entities( | |
|
||
# Extract entities and relationships from the chunk | ||
chunk_entities, chunk_relationships = await self.extract_entities_from_text( | ||
chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides | ||
chunk.content, chunk.document_id, chunk.chunk_number, extraction_overrides, override_is_image=chunk.metadata.get("is_image", False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should change the method name to extract entities |
||
) | ||
|
||
# Store all initially extracted entities to track their IDs | ||
|
@@ -695,6 +699,7 @@ async def extract_entities_from_text( | |
doc_id: str, | ||
chunk_number: int, | ||
prompt_overrides: Optional[EntityExtractionPromptOverride] = None, | ||
override_is_image: Optional[bool] = None, | ||
) -> Tuple[List[Entity], List[Relationship]]: | ||
""" | ||
Extract entities and relationships from text content using the LLM. | ||
|
@@ -703,17 +708,43 @@ async def extract_entities_from_text( | |
content: Text content to process | ||
doc_id: Document ID | ||
chunk_number: Chunk number within the document | ||
prompt_overrides: Optional EntityExtractionPromptOverride with customizations for prompts | ||
override_is_image: Optional flag to override image detection based on metadata | ||
|
||
Returns: | ||
Tuple of (entities, relationships) | ||
""" | ||
settings = get_settings() | ||
|
||
# Limit text length to avoid token limits | ||
content_limited = content[: min(len(content), 5000)] | ||
|
||
# We'll use the Pydantic model directly when calling litellm | ||
# No need to generate JSON schema separately | ||
# Determine if this chunk is an image based solely on metadata flag | ||
is_image = override_is_image if override_is_image is not None else False | ||
# For images, send full base64 content; for text, truncate to limit | ||
if is_image: | ||
content_limited = content | ||
else: | ||
content_limited = content[: min(len(content), 5000)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could remove this |
||
|
||
# Build system message, differentiating between text and image inputs | ||
if is_image: | ||
# For images, instruct comprehensive visual and text interpretation | ||
system_content = ( | ||
"You are a multi-modal extraction assistant for images. The input is a base64-encoded PNG of a document page. " | ||
"Perform OCR to extract the text, and visually interpret all layout elements: tables, diagrams, charts, form fields, headings, and graphical icons. " | ||
"Identify entities from both text and visuals, infer relationships depicted (e.g., flows, hierarchies, links), and draw logical conclusions from the combined information. " | ||
"For entities, include their label and type (e.g., PERSON, ORGANIZATION, LOCATION, CONCEPT). " | ||
"For relationships, output a JSON list of objects with source, target, and relationship fields. " | ||
"Respond only with valid JSON representing the extracted entities and relationships." | ||
) | ||
else: | ||
# For text, use standard extraction instructions | ||
system_content = ( | ||
"You are an entity extraction and relationship extraction assistant. " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should keep the original system prompt pls! I tested it and works well. Also maybe for images, we could just append saying context is image. That should do the job. |
||
"Extract entities and their relationships from the input text precisely and thoroughly. " | ||
"For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). " | ||
"For relationships, use a simple JSON format with source, target, and relationship fields. " | ||
"Respond directly in valid JSON format, without any additional text or explanations." | ||
) | ||
system_message = {"role": "system", "content": system_content} | ||
|
||
# Get entity extraction overrides if available | ||
extraction_overrides = {} | ||
|
@@ -744,45 +775,47 @@ async def extract_entities_from_text( | |
f"{json.dumps(examples_json, indent=2)}\n```\n" | ||
) | ||
|
||
# Modify the system message to handle properties as a string that will be parsed later | ||
system_message = { | ||
"role": "system", | ||
"content": ( | ||
"You are an entity extraction and relationship extraction assistant. Extract entities and " | ||
"their relationships from text precisely and thoroughly, extract as many entities and " | ||
"relationships as possible. " | ||
"For entities, include entity label and type (some examples: PERSON, ORGANIZATION, LOCATION, " | ||
"CONCEPT, etc.). If the user has given examples, use those, these are just suggestions" | ||
"For relationships, use a simple format with source, target, and relationship fields. " | ||
"Be very through, there are many relationships that are not obvious" | ||
"IMPORTANT: The source and target fields must be simple strings representing " | ||
"entity labels. For example: " | ||
"if you extract entities 'Entity A' and 'Entity B', a relationship would have source: 'Entity A', " | ||
"target: 'Entity B', relationship: 'relates to'. " | ||
"Respond directly in json format, without any additional text or explanations. " | ||
), | ||
} | ||
# Construct user message content as a list of content blocks | ||
user_message_content = [] | ||
|
||
if is_image: | ||
# For images, add the image as a content block | ||
user_message_content.append({"type": "image_url", "image_url": {"url": content_limited}}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Base64 image content is incorrectly passed as a URL instead of a data URI when constructing the message for the LLM |
||
# Add image-specific instructions as a text block | ||
image_instructions = ( | ||
"Extract named entities and their relationships from the following image. " | ||
"Perform OCR and visually interpret all layout elements. " | ||
"Return your response as valid JSON.\n\n" | ||
) | ||
user_message_content.append({"type": "text", "text": image_instructions}) | ||
else: | ||
# For text, add the text content block | ||
text_instructions = ( | ||
"Extract named entities and their relationships from the following text. " | ||
"For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). " | ||
"For relationships, specify the source entity, target entity, and the relationship between them. " | ||
'Sample relationship format: {"source": "Entity A", "target": "Entity B", ' | ||
'"relationship": "works for"}\n\n' | ||
"Return your response as valid JSON:\n\n" | ||
) | ||
user_message_content.append({"type": "text", "text": text_instructions}) | ||
user_message_content.append({"type": "text", "text": content_limited}) | ||
|
||
# Use custom prompt if provided, otherwise use default | ||
# Add examples if provided and not using a custom prompt template that handles them | ||
if examples_str and not custom_prompt: | ||
user_message_content.append({"type": "text", "text": examples_str}) | ||
|
||
# Use custom prompt template if provided | ||
if custom_prompt: | ||
user_message = { | ||
"role": "user", | ||
"content": custom_prompt.format(content=content_limited, examples=examples_str), | ||
} | ||
# If a custom prompt is provided, it takes precedence and formats the content | ||
# We assume the custom prompt handles incorporating text and image content appropriately. | ||
# For simplicity, we'll just pass the original content_limited and examples_str | ||
# to the custom prompt formatter. The user is responsible for formatting in the template. | ||
formatted_user_text = custom_prompt.format(content=content_limited, examples=examples_str) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Custom prompt handling doesn't account for multimodal image content |
||
user_message = {"role": "user", "content": formatted_user_text} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Even for custom prompts, we should add the image!!! |
||
else: | ||
user_message = { | ||
"role": "user", | ||
"content": ( | ||
"Extract named entities and their relationships from the following text. " | ||
"For entities, include entity label and type (PERSON, ORGANIZATION, LOCATION, CONCEPT, etc.). " | ||
"For relationships, specify the source entity, target entity, and the relationship between them. " | ||
"The source and target must be simple strings matching the entity labels, not objects. " | ||
f"{examples_str}" | ||
'Sample relationship format: {"source": "Entity A", "target": "Entity B", ' | ||
'"relationship": "works for"}\n\n' | ||
"Return your response as valid JSON:\n\n" + content_limited | ||
), | ||
} | ||
# Otherwise, use the constructed list of content blocks | ||
user_message = {"role": "user", "content": user_message_content} | ||
|
||
# Get the model configuration from registered_models | ||
model_config = settings.REGISTERED_MODELS.get(settings.GRAPH_MODEL, {}) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import asyncio | ||
import pytest | ||
|
||
from core.services.graph_service import GraphService, ExtractionResult, EntityExtraction, RelationshipExtraction | ||
|
||
# Dummy settings for testing | ||
class DummySettings: | ||
GRAPH_MODEL = "dummy" | ||
REGISTERED_MODELS = {"dummy": {"model_name": "test-model"}} | ||
|
||
# Dummy instructor client to capture messages and return a simple ExtractionResult | ||
class DummyClient: | ||
def __init__(self): | ||
self.captured_messages = None | ||
self.chat = self | ||
self.completions = self | ||
|
||
async def create(self, model, messages, response_model, **kwargs): | ||
self.captured_messages = messages | ||
return ExtractionResult( | ||
entities=[EntityExtraction(label="TestEntity", type="CONCEPT")], | ||
relationships=[RelationshipExtraction(source="TestEntity", target="TestEntity", relationship="related_to")] | ||
) | ||
|
||
@pytest.fixture(autouse=True) | ||
def patch_settings_and_instructor(monkeypatch): | ||
# Patch get_settings to return DummySettings | ||
import core.services.graph_service as gs_mod | ||
monkeypatch.setattr(gs_mod, "get_settings", lambda: DummySettings()) | ||
# Prepare dummy instructor and litellm modules for dynamic import | ||
import sys, types | ||
dummy_client = DummyClient() | ||
dummy_instructor = types.SimpleNamespace( | ||
from_litellm=lambda ac, mode: dummy_client, | ||
Mode=types.SimpleNamespace(JSON=None) | ||
) | ||
dummy_litellm = types.SimpleNamespace(acompletion='dummy') | ||
# Insert into sys.modules so that import instructor/litellm picks up our dummy | ||
monkeypatch.setitem(sys.modules, 'instructor', dummy_instructor) | ||
monkeypatch.setitem(sys.modules, 'litellm', dummy_litellm) | ||
return dummy_client | ||
|
||
@pytest.mark.parametrize("content,expected_system_prefix", [ | ||
('', 'You are an entity extraction and relationship extraction assistant for images.'), | ||
('Plain text content.', 'You are an entity extraction and relationship extraction assistant.'), | ||
]) | ||
def test_system_prompt_for_image_vs_text(patch_settings_and_instructor, content, expected_system_prefix): | ||
service = GraphService(db=None, embedding_model=None, completion_model=None) | ||
entities, relationships = asyncio.run( | ||
service.extract_entities_from_text(content, doc_id="doc1", chunk_number=0) | ||
) | ||
dummy = patch_settings_and_instructor | ||
assert dummy.captured_messages is not None | ||
system_msg, _ = dummy.captured_messages | ||
assert system_msg['content'].startswith(expected_system_prefix) | ||
assert entities and entities[0].label == "TestEntity" | ||
assert relationships and relationships[0].type == "related_to" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Attribute error: accessing 'type' but RelationshipExtraction uses 'relationship' attribute |
||
|
||
@pytest.mark.parametrize("content,expected_user_prefix", [ | ||
('', 'Extract named entities and their relationships from the following base64-encoded image.'), | ||
('Hello world', 'Extract named entities and their relationships from the following text.'), | ||
]) | ||
def test_user_prompt_for_image_vs_text(patch_settings_and_instructor, content, expected_user_prefix): | ||
service = GraphService(db=None, embedding_model=None, completion_model=None) | ||
entities, relationships = asyncio.run( | ||
service.extract_entities_from_text(content, doc_id="doc2", chunk_number=1) | ||
) | ||
dummy = patch_settings_and_instructor | ||
_, user_msg = dummy.captured_messages | ||
assert user_msg['content'].startswith(expected_user_prefix) | ||
# Validate stub relationship | ||
assert relationships and relationships[0].type == "related_to" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Parameter 'retrieve_images' creates inconsistency with similar parameter 'use_colpali' used in other methods