Skip to content

Commit

Permalink
add property extraction for KGs (#14707)
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored Jul 23, 2024
1 parent 1e361b5 commit a961dfb
Show file tree
Hide file tree
Showing 6 changed files with 497 additions and 138 deletions.
159 changes: 87 additions & 72 deletions docs/docs/examples/property_graph/Dynamic_KG_Extraction.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions llama-index-core/llama_index/core/graph_stores/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class EntityNode(LabelledNode):

def __str__(self) -> str:
    """
    Render the node as text.

    Returns:
        str: ``"<name> (<properties>)"`` when ``self.properties`` is truthy,
        otherwise just ``self.name``.
    """
    props = self.properties
    # Empty/absent properties fall back to the bare name.
    return f"{self.name} ({props})" if props else self.name

@property
Expand Down Expand Up @@ -91,6 +93,8 @@ class Relation(BaseModel):

def __str__(self) -> str:
    """
    Render the relation as text.

    Returns:
        str: ``self.label`` on its own when ``self.properties`` is falsy,
        otherwise ``"<label> (<properties>)"``.
    """
    # Guard clause: nothing to append when there are no properties.
    if not self.properties:
        return self.label
    return f"{self.label} ({self.properties})"

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from llama_index.core.prompts.default_prompts import (
DEFAULT_DYNAMIC_EXTRACT_PROMPT,
DEFAULT_DYNAMIC_EXTRACT_PROPS_PROMPT,
)


Expand Down Expand Up @@ -51,6 +52,56 @@ def default_parse_dynamic_triplets(
return triplets


def _parse_dynamic_props(raw_props: str) -> dict:
    """
    Convert the inside of a ``{...}`` property literal into a dictionary.

    The text is first evaluated as a Python dict literal so that non-string
    values (numbers, booleans, ``None``) are kept. If that fails, only
    ``'key': 'value'`` string pairs are collected.

    Args:
        raw_props (str): Text captured between the braces,
            e.g. ``'age': 30, 'city': 'NY'``. May be empty.

    Returns:
        dict: Property names mapped to their values; empty if nothing parses.
    """
    # Function-scope import keeps this block independent of the module's imports.
    import ast

    try:
        parsed = ast.literal_eval("{" + raw_props + "}")
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None

    if isinstance(parsed, dict):
        return {str(key): value for key, value in parsed.items()}

    # Fallback: `[^']*` (rather than `.*?`) stops a key from swallowing an
    # unquoted value and the following key, e.g. "age': 30, 'city".
    return dict(re.findall(r"'([^']*)': '([^']*)'", raw_props))


def default_parse_dynamic_triplets_with_props(
    llm_output: str,
) -> List[Tuple[EntityNode, Relation, EntityNode]]:
    """
    Parse LLM output into entity-relation-entity triplets that carry properties.

    Only dictionaries written in this exact single-quoted form and key order
    are recognised; anything else in the output is ignored:

        {'head': '...', 'head_type': '...', 'head_props': {...},
         'relation': '...', 'relation_props': {...},
         'tail': '...', 'tail_type': '...', 'tail_props': {...}}

    Args:
        llm_output (str): The raw text returned by the LLM. It may contain
            other text around the dictionaries.

    Returns:
        List[Tuple[EntityNode, Relation, EntityNode]]: One (head, relation, tail)
        triplet per matched dictionary, in order of appearance. Empty when
        nothing matches.
    """
    # Same pattern as before; the literal braces are now escaped explicitly.
    pattern = (
        r"\{'head': '(.*?)', 'head_type': '(.*?)', 'head_props': \{(.*?)\}, "
        r"'relation': '(.*?)', 'relation_props': \{(.*?)\}, "
        r"'tail': '(.*?)', 'tail_type': '(.*?)', 'tail_props': \{(.*?)\}\}"
    )

    triplets = []
    for match in re.finditer(pattern, llm_output):
        (
            head,
            head_type,
            head_props_raw,
            relation,
            relation_props_raw,
            tail,
            tail_type,
            tail_props_raw,
        ) = match.groups()

        head_node = EntityNode(
            name=head,
            label=head_type,
            properties=_parse_dynamic_props(head_props_raw),
        )
        tail_node = EntityNode(
            name=tail,
            label=tail_type,
            properties=_parse_dynamic_props(tail_props_raw),
        )
        relation_node = Relation(
            source_id=head_node.id,
            target_id=tail_node.id,
            label=relation,
            properties=_parse_dynamic_props(relation_props_raw),
        )
        triplets.append((head_node, relation_node, tail_node))

    return triplets


class DynamicLLMPathExtractor(TransformComponent):
"""
DynamicLLMPathExtractor is a component for extracting structured information from text
Expand All @@ -74,7 +125,13 @@ class DynamicLLMPathExtractor(TransformComponent):
num_workers (int): Number of workers for parallel processing.
max_triplets_per_chunk (int): Maximum number of triplets to extract per text chunk.
allowed_entity_types (List[str]): List of initial entity types for the ontology.
allowed_entity_props (Optional[Union[List[str], List[Tuple[str, str]]]]):
List of initial entity properties for the ontology.
Can be either property names or tuples of (name, description).
allowed_relation_types (List[str]): List of initial relation types for the ontology.
allowed_relation_props (Optional[Union[List[str], List[Tuple[str, str]]]]):
List of initial relation properties for the ontology.
Can be either property names or tuples of (name, description).
"""

llm: LLM
Expand All @@ -83,17 +140,23 @@ class DynamicLLMPathExtractor(TransformComponent):
num_workers: int
max_triplets_per_chunk: int
allowed_entity_types: List[str]
allowed_relation_types: List[str]
allowed_entity_props: List[str]
allowed_relation_types: Optional[List[str]]
allowed_relation_props: Optional[List[str]]

def __init__(
self,
llm: Optional[LLM] = None,
extract_prompt: Optional[Union[str, PromptTemplate]] = None,
parse_fn: Callable = default_parse_dynamic_triplets,
parse_fn: Optional[Callable] = None,
max_triplets_per_chunk: int = 10,
num_workers: int = 4,
allowed_entity_types: Optional[List[str]] = None,
allowed_entity_props: Optional[Union[List[str], List[Tuple[str, str]]]] = None,
allowed_relation_types: Optional[List[str]] = None,
allowed_relation_props: Optional[
Union[List[str], List[Tuple[str, str]]]
] = None,
) -> None:
"""
Initialize the DynamicLLMPathExtractor.
Expand All @@ -113,7 +176,29 @@ def __init__(
extract_prompt = PromptTemplate(extract_prompt)

if extract_prompt is None:
extract_prompt = DEFAULT_DYNAMIC_EXTRACT_PROMPT
if allowed_entity_props is not None or allowed_relation_props is not None:
extract_prompt = DEFAULT_DYNAMIC_EXTRACT_PROPS_PROMPT
else:
extract_prompt = DEFAULT_DYNAMIC_EXTRACT_PROMPT

if parse_fn is None:
if allowed_entity_props is not None or allowed_relation_props is not None:
parse_fn = default_parse_dynamic_triplets_with_props
else:
parse_fn = default_parse_dynamic_triplets

# convert props to name -> description format if needed
if allowed_entity_props and isinstance(allowed_entity_props[0], tuple):
allowed_entity_props = [
f"Property `{k}` with description ({v})"
for k, v in allowed_entity_props
]

if allowed_relation_props and isinstance(allowed_relation_props[0], tuple):
allowed_relation_props = [
f"Property `{k}` with description ({v})"
for k, v in allowed_relation_props
]

super().__init__(
llm=llm or Settings.llm,
Expand All @@ -122,7 +207,9 @@ def __init__(
num_workers=num_workers,
max_triplets_per_chunk=max_triplets_per_chunk,
allowed_entity_types=allowed_entity_types or [],
allowed_entity_props=allowed_entity_props,
allowed_relation_types=allowed_relation_types or [],
allowed_relation_props=allowed_relation_props,
)

@classmethod
Expand All @@ -146,6 +233,56 @@ def __call__(
"""
return asyncio.run(self.acall(nodes, show_progress=show_progress, **kwargs))

async def _apredict_without_props(self, text: str) -> str:
    """
    Ask the LLM for triples using only the entity/relation type constraints.

    Args:
        text (str): The text to extract triples from.

    Returns:
        str: Whatever ``self.llm.apredict`` returns for the extraction prompt.
    """
    entity_types = self.allowed_entity_types
    relation_types = self.allowed_relation_types

    # An empty list becomes an instruction letting the LLM pick its own types.
    if len(entity_types) > 0:
        entity_types_text = ", ".join(entity_types)
    else:
        entity_types_text = "No entity types provided, You are free to define them."

    if len(relation_types) > 0:
        relation_types_text = ", ".join(relation_types)
    else:
        relation_types_text = (
            "No relation types provided, You are free to define them."
        )

    return await self.llm.apredict(
        self.extract_prompt,
        text=text,
        max_knowledge_triplets=self.max_triplets_per_chunk,
        allowed_entity_types=entity_types_text,
        allowed_relation_types=relation_types_text,
    )

async def _apredict_with_props(self, text: str) -> str:
    """
    Ask the LLM for triples, passing type and property constraints to the prompt.

    Args:
        text (str): The text to extract triples from.

    Returns:
        str: Whatever ``self.llm.apredict`` returns for the extraction prompt.
    """
    entity_types = self.allowed_entity_types
    relation_types = self.allowed_relation_types
    entity_props = self.allowed_entity_props
    relation_props = self.allowed_relation_props

    # Types are checked by length; properties by truthiness (so None is accepted).
    if len(entity_types) > 0:
        entity_types_text = ", ".join(entity_types)
    else:
        entity_types_text = "No entity types provided, You are free to define them."

    if len(relation_types) > 0:
        relation_types_text = ", ".join(relation_types)
    else:
        relation_types_text = (
            "No relation types provided, You are free to define them."
        )

    if entity_props:
        entity_props_text = ", ".join(entity_props)
    else:
        entity_props_text = (
            "No entity properties provided, You are free to define them."
        )

    if relation_props:
        relation_props_text = ", ".join(relation_props)
    else:
        relation_props_text = (
            "No relation properties provided, You are free to define them."
        )

    return await self.llm.apredict(
        self.extract_prompt,
        text=text,
        max_knowledge_triplets=self.max_triplets_per_chunk,
        allowed_entity_types=entity_types_text,
        allowed_relation_types=relation_types_text,
        allowed_entity_properties=entity_props_text,
        allowed_relation_properties=relation_props_text,
    )

async def _aextract(self, node: BaseNode) -> BaseNode:
"""
Asynchronously extract triples from a single node.
Expand All @@ -158,17 +295,14 @@ async def _aextract(self, node: BaseNode) -> BaseNode:
"""
text = node.get_content(metadata_mode="llm")
try:
llm_response = await self.llm.apredict(
self.extract_prompt,
text=text,
max_knowledge_triplets=self.max_triplets_per_chunk,
allowed_entity_types=", ".join(self.allowed_entity_types)
if len(self.allowed_entity_types) > 0
else "No entity types provided, You are free to define them.",
allowed_relation_types=", ".join(self.allowed_relation_types)
if len(self.allowed_relation_types) > 0
else "No relation types provided, You are free to define them.",
)
if (
self.allowed_entity_props is not None
and self.allowed_relation_props is not None
):
llm_response = await self._apredict_with_props(text)
else:
llm_response = await self._apredict_without_props(text)

triplets = self.parse_fn(llm_response)
except Exception as e:
print(f"Error during extraction: {e!s}")
Expand All @@ -179,9 +313,9 @@ async def _aextract(self, node: BaseNode) -> BaseNode:

metadata = node.metadata.copy()
for subj, rel, obj in triplets:
subj.properties = metadata
obj.properties = metadata
rel.properties = metadata
subj.properties.update(metadata)
obj.properties.update(metadata)
rel.properties.update(metadata)

existing_nodes.extend([subj, obj])
existing_relations.append(rel)
Expand Down
Loading

0 comments on commit a961dfb

Please sign in to comment.