Skip to content

feat: transforms for Knowledge Graphs #1345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions src/ragas/experimental/testset/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import json
import typing as t
import uuid
from dataclasses import dataclass, field
from pathlib import Path

from pydantic import BaseModel, Field


class UUIDEncoder(json.JSONEncoder):
    """JSON encoder that writes ``uuid.UUID`` values as their string form."""

    def default(self, o):
        """
        Serialise objects the stock encoder does not understand.

        Parameters
        ----------
        o : object
            The value ``json`` could not encode on its own.

        Returns
        -------
        str
            ``str(o)`` when ``o`` is a ``uuid.UUID``; otherwise whatever the
            base class returns (it raises ``TypeError`` for unsupported types).
        """
        # Anything that is not a UUID is handed straight to the base class.
        if not isinstance(o, uuid.UUID):
            return super().default(o)
        return str(o)


class Node(BaseModel):
    """
    A vertex of the knowledge graph.

    Equality and hashing are both based solely on ``id``, so two ``Node``
    objects with the same id compare equal regardless of their properties.
    """

    # Unique identifier; a fresh UUID4 is generated when none is supplied.
    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    # Free-form property bag. add_property/get_property lower-case their keys;
    # keys passed directly to the constructor are stored as given.
    properties: dict = Field(default_factory=dict)
    # Free-form type label, empty by default.
    type: str = ""

    def __repr__(self) -> str:
        """Return a compact summary: shortened id, type and property names."""
        short_id = str(self.id)[:6]
        property_names = list(self.properties.keys())
        return f"Node(id: {short_id}, type: {self.type}, properties: {property_names})"

    def __str__(self) -> str:
        """Return the same text as ``__repr__``."""
        return repr(self)

    def add_property(self, key: str, value: t.Any):
        """
        Store ``value`` under the lower-cased ``key``.

        Raises
        ------
        ValueError
            If the lower-cased key is already present.
        """
        normalized = key.lower()
        if normalized in self.properties:
            raise ValueError(f"Property {key} already exists")
        self.properties[normalized] = value

    def get_property(self, key: str) -> t.Optional[t.Any]:
        """Return the value stored under the lower-cased ``key``, or ``None``."""
        return self.properties.get(key.lower())

    def __hash__(self) -> int:
        """Hash on ``id`` only, consistent with ``__eq__``."""
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        """Two nodes are equal when ``other`` is a ``Node`` with the same id."""
        return isinstance(other, Node) and self.id == other.id


class Relationship(BaseModel):
    """
    An edge between two nodes of the knowledge graph.

    Equality and hashing are both based solely on ``id``.
    """

    # Unique identifier; a fresh UUID4 is generated when none is supplied.
    id: uuid.UUID = Field(default_factory=uuid.uuid4)
    # Free-form type label (required).
    type: str
    # Endpoints of the edge.
    source: Node
    target: Node
    # When True the edge may also be followed from target to source.
    bidirectional: bool = False
    # Free-form property bag; get_property lower-cases the lookup key.
    properties: dict = Field(default_factory=dict)

    def get_property(self, key: str) -> t.Optional[t.Any]:
        """Return the value stored under the lower-cased ``key``, or ``None``."""
        return self.properties.get(key.lower())

    def __repr__(self) -> str:
        """Return a compact summary: endpoint ids, direction, type, property names."""
        arrow = "<->" if self.bidirectional else "->"
        source_id = str(self.source.id)[:6]
        target_id = str(self.target.id)[:6]
        property_names = list(self.properties.keys())
        return f"Relationship(Node(id: {source_id}) {arrow} Node(id: {target_id}), type: {self.type}, properties: {property_names})"

    def __str__(self) -> str:
        """Return the same text as ``__repr__``."""
        return repr(self)

    def __hash__(self) -> int:
        """Hash on ``id`` only, consistent with ``__eq__``."""
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        """Two relationships are equal when ``other`` is a ``Relationship`` with the same id."""
        return isinstance(other, Relationship) and self.id == other.id


@dataclass
class KnowledgeGraph:
    """
    In-memory container for nodes and the relationships between them.

    Attributes
    ----------
    nodes : t.List[Node]
        Nodes in insertion order.
    relationships : t.List[Relationship]
        Relationships in insertion order.
    """

    nodes: t.List[Node] = field(default_factory=list)
    relationships: t.List[Relationship] = field(default_factory=list)

    def add(self, item: t.Union[Node, Relationship]):
        """
        Append a node or a relationship to the graph.

        Raises
        ------
        ValueError
            If ``item`` is neither a ``Node`` nor a ``Relationship``.
        """
        if isinstance(item, Node):
            self._add_node(item)
        elif isinstance(item, Relationship):
            self._add_relationship(item)
        else:
            raise ValueError(f"Invalid item type: {type(item)}")

    def _add_node(self, node: Node):
        """Append ``node`` to ``self.nodes`` (no duplicate check)."""
        self.nodes.append(node)

    def _add_relationship(self, relationship: Relationship):
        """Append ``relationship`` to ``self.relationships`` (no duplicate check)."""
        self.relationships.append(relationship)

    def save(self, path: t.Union[str, Path]):
        """
        Write the graph to ``path`` as indented JSON.

        Each node and relationship is dumped with ``model_dump()``; UUIDs are
        written as strings through ``UUIDEncoder``.
        """
        path = Path(path)

        data = {
            "nodes": [node.model_dump() for node in self.nodes],
            "relationships": [rel.model_dump() for rel in self.relationships],
        }
        # Explicit encoding so the file does not depend on the platform locale.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, cls=UUIDEncoder, indent=2)

    @classmethod
    def load(cls, path: t.Union[str, Path]) -> "KnowledgeGraph":
        """
        Read a graph from a JSON file shaped like the output of ``save``.

        The file must contain top-level ``"nodes"`` and ``"relationships"``
        lists whose entries are keyword arguments for ``Node`` and
        ``Relationship`` respectively.
        """
        path = Path(path)

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        nodes = [Node(**node_data) for node_data in data["nodes"]]
        relationships = [Relationship(**rel_data) for rel_data in data["relationships"]]

        kg = cls()
        kg.nodes.extend(nodes)
        kg.relationships.extend(relationships)
        return kg

    def __repr__(self) -> str:
        """Return a summary with the node and relationship counts."""
        return f"KnowledgeGraph(nodes: {len(self.nodes)}, relationships: {len(self.relationships)})"

    def __str__(self) -> str:
        """Return the same text as ``__repr__``."""
        return self.__repr__()

    def find_clusters(
        self, relationship_condition: t.Callable[[Relationship], bool] = lambda _: True
    ) -> t.List[t.Set[Node]]:
        """
        Group nodes that are reachable from one another via relationships.

        Starting from each not-yet-visited node in ``self.nodes`` (in order),
        relationships are followed from source to target, and additionally
        from target to source when the relationship is bidirectional. Every
        node reached this way joins the starting node's cluster.

        Because one-directional relationships are only followed forward, the
        grouping can depend on the order of ``self.nodes``.

        Parameters
        ----------
        relationship_condition : Callable[[Relationship], bool]
            Predicate selecting which relationships may be traversed.
            Defaults to accepting every relationship.

        Returns
        -------
        t.List[t.Set[Node]]
            Clusters containing more than one node; isolated nodes are omitted.
        """
        # Build the adjacency map once instead of rescanning every
        # relationship for every visited node.
        neighbours: t.Dict[Node, t.List[Node]] = {}
        for rel in self.relationships:
            if not relationship_condition(rel):
                continue
            neighbours.setdefault(rel.source, []).append(rel.target)
            if rel.bidirectional:
                neighbours.setdefault(rel.target, []).append(rel.source)

        clusters: t.List[t.Set[Node]] = []
        visited: t.Set[Node] = set()

        for start in self.nodes:
            if start in visited:
                continue

            # Explicit stack rather than recursion, so long chains of nodes
            # cannot exceed the interpreter's recursion limit.
            visited.add(start)
            cluster: t.Set[Node] = {start}
            stack = [start]
            while stack:
                current = stack.pop()
                for neighbour in neighbours.get(current, ()):
                    if neighbour not in visited:
                        visited.add(neighbour)
                        cluster.add(neighbour)
                        stack.append(neighbour)

            if len(cluster) > 1:
                clusters.append(cluster)

        return clusters
212 changes: 212 additions & 0 deletions src/ragas/experimental/testset/transforms/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

from ragas.experimental.testset.graph import KnowledgeGraph, Node, Relationship
from ragas.llms import BaseRagasLLM, llm_factory


class BaseGraphTransformations(ABC):
    """
    Common interface for every transformation applied to a KnowledgeGraph.

    Subclasses must implement ``transform``; ``filter`` may be overridden to
    restrict which part of the graph a transformation looks at.
    """

    @abstractmethod
    async def transform(self, kg: KnowledgeGraph) -> t.Any:
        """
        Compute this transformation's output for ``kg``.

        Implementations are expected to be idempotent: running the same
        transformation several times should give the same result as running
        it once.

        Parameters
        ----------
        kg : KnowledgeGraph
            The knowledge graph to be transformed.

        Returns
        -------
        t.Any
            The result of the transformation; its shape is defined by the
            concrete subclass.
        """

    def filter(self, kg: KnowledgeGraph) -> KnowledgeGraph:
        """
        Select the part of ``kg`` the transformation should operate on.

        The base implementation performs no filtering and hands back the
        very same graph object.

        Parameters
        ----------
        kg : KnowledgeGraph
            The knowledge graph to be filtered.

        Returns
        -------
        KnowledgeGraph
            The filtered knowledge graph (``kg`` itself by default).
        """
        return kg


class Extractor(BaseGraphTransformations):
    """
    Base class for transformations that derive one property per node.

    Methods
    -------
    transform(kg: KnowledgeGraph) -> t.List[t.Tuple[Node, t.Tuple[str, t.Any]]]
        Runs ``extract`` on every node of the filtered graph.

    extract(node: Node) -> t.Tuple[str, t.Any]
        Abstract hook producing the ``(property_name, value)`` pair for a node.
    """

    async def transform(
        self, kg: KnowledgeGraph
    ) -> t.List[t.Tuple[Node, t.Tuple[str, t.Any]]]:
        """
        Extract a property from every node of the filtered graph.

        The graph is first passed through ``filter``; ``extract`` is then
        awaited for each remaining node, one node after another.

        Parameters
        ----------
        kg : KnowledgeGraph
            The knowledge graph to be transformed.

        Returns
        -------
        t.List[t.Tuple[Node, t.Tuple[str, t.Any]]]
            One ``(node, (property_name, value))`` entry per node, in the
            order of the filtered graph's ``nodes`` list.

        Examples
        --------
        >>> kg = KnowledgeGraph(nodes=[Node(properties={"name": "Node1"})])
        >>> extractor = SomeConcreteExtractor()
        >>> await extractor.transform(kg)  # doctest: +SKIP
        [(Node(id: ..., type: , properties: ['name']), ("property_name", "extracted_value"))]
        """
        results = []
        for node in self.filter(kg).nodes:
            extracted = await self.extract(node)
            results.append((node, extracted))
        return results

    @abstractmethod
    async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
        """
        Produce the property to attach to ``node``.

        Parameters
        ----------
        node : Node
            The node from which to extract the property.

        Returns
        -------
        t.Tuple[str, t.Any]
            The property name followed by the extracted value.
        """


@dataclass
class LLMBasedExtractor(Extractor):
    """
    Extractor base that carries an LLM handle as dataclass state.

    ``extract`` is still abstract here, so concrete subclasses must provide it.
    """

    # Built by calling llm_factory() once per instance when not supplied.
    llm: BaseRagasLLM = field(default_factory=llm_factory)
    # NOTE(review): not read anywhere in this file; presumably consumed by
    # whatever applies the extracted properties - confirm against the caller.
    merge_if_possible: bool = True


class Splitter(BaseGraphTransformations):
    """
    Base class for transformations that break nodes into smaller nodes.

    Methods
    -------
    transform(kg: KnowledgeGraph) -> t.Tuple[t.List[Node], t.List[Relationship]]
        Runs ``split`` on every node of the filtered graph and pools the results.

    split(node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]
        Abstract hook that splits a single node.
    """

    async def transform(
        self, kg: KnowledgeGraph
    ) -> t.Tuple[t.List[Node], t.List[Relationship]]:
        """
        Split every node of the filtered graph and pool the results.

        The graph is first passed through ``filter``; ``split`` is then
        awaited for each remaining node, one node after another.

        Parameters
        ----------
        kg : KnowledgeGraph
            The knowledge graph to be transformed.

        Returns
        -------
        t.Tuple[t.List[Node], t.List[Relationship]]
            All nodes and all relationships returned by the individual
            ``split`` calls, concatenated in node order.
        """
        collected_nodes = []
        collected_relationships = []
        for node in self.filter(kg).nodes:
            pieces, links = await self.split(node)
            collected_nodes += pieces
            collected_relationships += links

        return collected_nodes, collected_relationships

    @abstractmethod
    async def split(self, node: Node) -> t.Tuple[t.List[Node], t.List[Relationship]]:
        """
        Split a single node into smaller chunks.

        Parameters
        ----------
        node : Node
            The node to be split.

        Returns
        -------
        t.Tuple[t.List[Node], t.List[Relationship]]
            The new nodes followed by the new relationships.
        """


class RelationshipBuilder(BaseGraphTransformations):
    """
    Base class for transformations that produce new relationships.

    Methods
    -------
    transform(kg: KnowledgeGraph) -> t.List[Relationship]
        Abstract hook returning the relationships to add.
    """

    @abstractmethod
    async def transform(self, kg: KnowledgeGraph) -> t.List[Relationship]:
        """
        Build relationships for the given graph.

        Parameters
        ----------
        kg : KnowledgeGraph
            The knowledge graph to be transformed.

        Returns
        -------
        t.List[Relationship]
            The newly built relationships.
        """


class Parallel:
    """
    Plain container holding a group of transformations.

    NOTE(review): the name suggests the members are meant to be run
    concurrently, but nothing in this class executes them - confirm against
    the code that consumes ``transformations``.
    """

    def __init__(self, *transformations: BaseGraphTransformations):
        # Copy the positional arguments into a mutable list.
        self.transformations = [*transformations]


class Sequences:
    """
    Plain container holding an ordered list of transformations or
    ``Parallel`` groups.

    NOTE(review): the name suggests the members are meant to be applied one
    after another, but nothing in this class executes them - confirm against
    the code that consumes ``transformations``.
    """

    def __init__(self, *transformations: t.Union[BaseGraphTransformations, Parallel]):
        # Copy the positional arguments into a mutable list, keeping their order.
        self.transformations = [*transformations]
Empty file.
25 changes: 25 additions & 0 deletions src/ragas/experimental/testset/transforms/extractors/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import typing as t
from dataclasses import dataclass

from ragas.embeddings import BaseRagasEmbeddings, embedding_factory
from ragas.experimental.testset.graph import Node
from ragas.experimental.testset.transforms.base import Extractor


@dataclass
class EmbeddingExtractor(Extractor):
    """
    Extractor that embeds a node's ``page_content`` text.

    Attributes
    ----------
    model : str
        Name passed to ``embedding_factory`` when no ``embedding_model`` is
        supplied.
    property_name : str
        Name under which the embedding is reported by ``extract``.
    embedding_model : BaseRagasEmbeddings
        Embedding backend. When omitted, one is built from ``model`` for this
        instance.
    """

    model: str = "text-embedding-3-small"
    property_name: str = "embedding"
    # None is only a sentinel: a default written as `embedding_factory(model=model)`
    # in the class body is evaluated once at class creation, so it would ignore
    # the `model` passed to the constructor and be shared by every instance.
    embedding_model: BaseRagasEmbeddings = None  # type: ignore[assignment]

    def __post_init__(self):
        """Build the embedding backend from ``self.model`` when none was given."""
        if self.embedding_model is None:
            self.embedding_model = embedding_factory(model=self.model)

    async def extract(self, node: Node) -> t.Tuple[str, t.Any]:
        """
        Embed the node's ``page_content`` property.

        Parameters
        ----------
        node : Node
            The node whose ``page_content`` property is embedded.

        Returns
        -------
        t.Tuple[str, t.Any]
            ``property_name`` followed by the value returned by
            ``embedding_model.embed_query``.

        Raises
        ------
        ValueError
            If ``page_content`` is missing or is not a string.
        """
        text = node.get_property("page_content")
        if not isinstance(text, str):
            raise ValueError(
                f"node.property('page_content') must be a string, found '{type(text)}'"
            )
        # NOTE(review): embed_query is called without await, so it looks
        # synchronous and would block the event loop - confirm whether an
        # async variant should be used here.
        embedding = self.embedding_model.embed_query(text)
        return self.property_name, embedding


# Ready-made extractor using all defaults. This is a module-level statement,
# so the instance is constructed when the module is imported.
embedding_extractor = EmbeddingExtractor()
Loading
Loading