-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Python: Graphrag demo #10064
Draft
eavanvalkenburg
wants to merge
4
commits into
microsoft:main
Choose a base branch
from
eavanvalkenburg:graphrag_demo
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+452
−0
Draft
Python: Graphrag demo #10064
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Graphrag Sample | ||
|
||
## Setup | ||
To setup this demo, make sure you are rooted in this root folder. | ||
|
||
Then run `setup-part-0` for the appropriate platform. | ||
This installs uv, creates a venv with python 3.12 in .venv, activates the venv, and installs the dependencies. | ||
|
||
### Linux or MacOS | ||
```bash | ||
./setup-part-0-linux.sh | ||
``` | ||
### Windows | ||
```powershell | ||
setup-part-0-windows.ps1 | ||
``` | ||
|
||
Next, run `setup-part-1.sh`, this will create the setup directory, downloads a book into it, and then runs the init script, which creates a settings.yaml and .env file. | ||
|
||
Next update the .env file with your OpenAI API key as the GRAPHRAG_API_KEY variable, if you want to use Azure OpenAI, then you need to update the settings.yaml accordingly, see the GraphRag docs for more info [here](https://github.com/microsoft/graphrag/blob/main/docs/get_started.md) | ||
|
||
Finally, run the `setup-part-2.sh` script, this will run the indexer, this will take a couple of minutes. | ||
|
||
Then run `python graphrag_chat.py` to chat with the book, inside that script are some options so feel free to change them to your liking. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
from .graphrag_chat_completion import GraphRagChatCompletion as GraphRagChatCompletion | ||
from .graphrag_prompt_execution_settings import GraphRagPromptExecutionSettings as GraphRagPromptExecutionSettings |
283 changes: 283 additions & 0 deletions
283
python/samples/demos/graphrag/SKGraphRag/graphrag_chat_completion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,283 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
import logging | ||
from collections.abc import AsyncGenerator | ||
from typing import Any, Literal, override | ||
|
||
import graphrag.api as api | ||
import pandas as pd | ||
import yaml | ||
from graphrag.config.create_graphrag_config import GraphRagConfig, create_graphrag_config | ||
from graphrag.index.typing import PipelineRunResult | ||
|
||
from semantic_kernel.connectors.ai import PromptExecutionSettings | ||
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase | ||
from semantic_kernel.contents import AuthorRole, ChatHistory, ChatMessageContent, StreamingChatMessageContent | ||
|
||
from .graphrag_prompt_execution_settings import GraphRagPromptExecutionSettings | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class GraphRagChatCompletion(ChatCompletionClientBase): | ||
"""GraphRagChatCompletion is a class that extends ChatCompletionClientBase to provide | ||
chat completion functionalities using a GraphRag setup. | ||
|
||
Attributes: | ||
project_directory (str): The directory where the project files are located. | ||
graphrag_config (GraphRagConfig): Configuration for the GraphRag service. | ||
final_nodes (pd.DataFrame | None): DataFrame containing the final nodes. | ||
final_entities (pd.DataFrame | None): DataFrame containing the final entities. | ||
final_communities (pd.DataFrame | None): DataFrame containing the final communities. | ||
final_community_reports (pd.DataFrame | None): DataFrame containing the final community reports. | ||
final_documents (pd.DataFrame | None): DataFrame containing the final documents. | ||
final_relationships (pd.DataFrame | None): DataFrame containing the final relationships. | ||
final_text_units (pd.DataFrame | None): DataFrame containing the final text units. | ||
""" | ||
|
||
project_directory: str | ||
graphrag_config: GraphRagConfig | ||
final_nodes: pd.DataFrame | None = None | ||
final_entities: pd.DataFrame | None = None | ||
final_communities: pd.DataFrame | None = None | ||
final_community_reports: pd.DataFrame | None = None | ||
final_documents: pd.DataFrame | None = None | ||
final_relationships: pd.DataFrame | None = None | ||
final_text_units: pd.DataFrame | None = None | ||
|
||
def __init__( | ||
self, project_directory: str, service_id: str = "graph_rag", graphrag_config: GraphRagConfig | None = None | ||
): | ||
""" | ||
Initializes the GraphRagChatCompletion instance. | ||
|
||
Args: | ||
project_directory (str): The directory where the project files are located. | ||
service_id (str): The service identifier. Defaults to "graph_rag". | ||
graphrag_config (GraphRagConfig | None): Configuration for the GraphRag service. | ||
If None, it will be loaded from settings.yaml. | ||
""" | ||
if not graphrag_config: | ||
with open(f"{project_directory}/settings.yaml") as file: | ||
graphrag_config = create_graphrag_config(values=yaml.safe_load(file), root_dir=project_directory) | ||
super().__init__( | ||
service_id=service_id, | ||
ai_model_id=service_id, | ||
project_directory=project_directory, | ||
graphrag_config=graphrag_config, | ||
) | ||
|
||
def get_prompt_execution_settings_class(self) -> type[PromptExecutionSettings]: | ||
""" | ||
Returns the class type for prompt execution settings. | ||
|
||
Returns: | ||
type[PromptExecutionSettings]: The class type for prompt execution settings. | ||
""" | ||
return GraphRagPromptExecutionSettings | ||
|
||
async def setup(self): | ||
"""Sets up the GraphRagChatCompletion instance by building the index and loading the necessary data.""" | ||
|
||
index_result: list[PipelineRunResult] = await api.build_index(config=self.graphrag_config) | ||
|
||
# index_result is a list of workflows that make up the indexing pipeline that was run | ||
for workflow_result in index_result: | ||
status = f"error\n{workflow_result.errors}" if workflow_result.errors else "success" | ||
print(f"Workflow Name: {workflow_result.workflow}\tStatus: {status}") | ||
self.load() | ||
|
||
def has_loaded(self, search_type: Literal["local", "global", "drift"] | None = None) -> bool: | ||
"""Checks if the necessary data has been loaded based on the search type. | ||
|
||
Args: | ||
search_type (Literal["local", "global", "drift"] | None): The type of search to check for. | ||
|
||
Returns: | ||
bool: True if the necessary data has been loaded, False otherwise. | ||
""" | ||
if search_type == "local": | ||
return all([ | ||
self.final_nodes is not None, | ||
self.final_entities is not None, | ||
self.final_communities is not None, | ||
self.final_community_reports is not None, | ||
self.final_text_units is not None, | ||
self.final_relationships is not None, | ||
]) | ||
if search_type == "global": | ||
return all([ | ||
self.final_nodes is not None, | ||
self.final_entities is not None, | ||
self.final_communities is not None, | ||
self.final_community_reports is not None, | ||
]) | ||
if search_type == "drift": | ||
return all([ | ||
self.final_nodes is not None, | ||
self.final_entities is not None, | ||
self.final_communities is not None, | ||
self.final_community_reports is not None, | ||
self.final_text_units is not None, | ||
self.final_relationships is not None, | ||
]) | ||
return all([ | ||
self.final_nodes is not None, | ||
self.final_entities is not None, | ||
self.final_communities is not None, | ||
self.final_community_reports is not None, | ||
self.final_text_units is not None, | ||
self.final_relationships is not None, | ||
self.final_documents is not None, | ||
]) | ||
|
||
def post_model_init(self, *args, **kwargs): | ||
"""Post-initialization method to load the necessary data after the model has been initialized.""" | ||
try: | ||
self.load() | ||
except FileNotFoundError: | ||
logger.warning( | ||
"Could not load the final nodes, entities, communities, and community reports. Please run setup first." | ||
) | ||
|
||
def load(self): | ||
"""Loads the parquet files. | ||
|
||
Includes final nodes, entities, communities, community reports, text units, relationships, and documents.""" | ||
|
||
self.final_nodes = pd.read_parquet(f"{self.project_directory}/output/create_final_nodes.parquet") | ||
self.final_entities = pd.read_parquet(f"{self.project_directory}/output/create_final_entities.parquet") | ||
self.final_communities = pd.read_parquet(f"{self.project_directory}/output/create_final_communities.parquet") | ||
self.final_community_reports = pd.read_parquet( | ||
f"{self.project_directory}/output/create_final_community_reports.parquet" | ||
) | ||
self.final_text_units = pd.read_parquet(f"{self.project_directory}/output/create_final_text_units.parquet") | ||
self.final_relationships = pd.read_parquet( | ||
f"{self.project_directory}/output/create_final_relationships.parquet" | ||
) | ||
self.final_documents = pd.read_parquet(f"{self.project_directory}/output/create_final_documents.parquet") | ||
|
||
@override | ||
async def _inner_get_chat_completion_contents( | ||
self, | ||
chat_history: "ChatHistory", | ||
settings: "PromptExecutionSettings", | ||
) -> list["ChatMessageContent"]: | ||
# Make sure the settings is of type GraphRagPromptExecutionSettings | ||
if not isinstance(settings, GraphRagPromptExecutionSettings): | ||
settings = self.get_prompt_execution_settings_from_settings(settings) | ||
# Check if the necessary data has been loaded | ||
if not self.has_loaded(search_type=settings.search_type): | ||
raise ValueError("The required assets have not been loaded, please run setup first.") | ||
if settings.search_type == "global": | ||
# Call the global search | ||
response, context = await api.global_search( | ||
config=self.graphrag_config, | ||
nodes=self.final_nodes, | ||
entities=self.final_entities, | ||
communities=self.final_communities, | ||
community_reports=self.final_community_reports, | ||
community_level=2, | ||
dynamic_community_selection=False, | ||
response_type=settings.response_type, | ||
query=chat_history.messages[-1].content, | ||
) | ||
# since the response is a string, we can wrap it into a ChatMessageContent | ||
# we store the context in the metadata of the message. | ||
if isinstance(response, str): | ||
cmc = ChatMessageContent(role=AuthorRole.ASSISTANT, content=response, metadata={"context": context}) | ||
return [cmc] | ||
raise ValueError("Unknown response type.") | ||
if settings.search_type == "local": | ||
# Call the local search | ||
response, context = await api.local_search( | ||
config=self.graphrag_config, | ||
nodes=self.final_nodes, | ||
entities=self.final_entities, | ||
community_reports=self.final_community_reports, | ||
text_units=self.final_text_units, | ||
relationships=self.final_relationships, | ||
covariates=None, | ||
community_level=2, | ||
response_type=settings.response_type, | ||
query=chat_history.messages[-1].content, | ||
) | ||
# since the response is a string, we can wrap it into a ChatMessageContent | ||
# we store the context in the metadata of the message. | ||
if isinstance(response, str): | ||
cmc = ChatMessageContent(role=AuthorRole.ASSISTANT, content=response, metadata={"context": context}) | ||
return [cmc] | ||
raise ValueError("Unknown response type.") | ||
# Call the drift search | ||
response, context = await api.drift_search( | ||
config=self.graphrag_config, | ||
nodes=self.final_nodes, | ||
entities=self.final_entities, | ||
community_reports=self.final_community_reports, | ||
text_units=self.final_text_units, | ||
relationships=self.final_relationships, | ||
community_level=2, | ||
query=chat_history.messages[-1].content, | ||
) | ||
# since the response is a string, we can wrap it into a ChatMessageContent | ||
# we store the context in the metadata of the message. | ||
if isinstance(response, str): | ||
cmc = ChatMessageContent(role=AuthorRole.ASSISTANT, content=response, metadata={"context": context}) | ||
return [cmc] | ||
raise ValueError("Unknown response type.") | ||
|
||
@override | ||
async def _inner_get_streaming_chat_message_contents( | ||
self, | ||
chat_history: "ChatHistory", | ||
settings: "PromptExecutionSettings", | ||
function_invoke_attempt: int = 0, | ||
) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]: | ||
# Make sure the settings is of type GraphRagPromptExecutionSettings | ||
if not isinstance(settings, GraphRagPromptExecutionSettings): | ||
settings = self.get_prompt_execution_settings_from_settings(settings) | ||
# Check if the necessary data has been loaded | ||
if not self.has_loaded(search_type=settings.search_type): | ||
raise ValueError("The required assets have not been loaded, please run setup first.") | ||
if settings.search_type == "drift": | ||
# Drift search is not available when streaming | ||
raise NotImplementedError("Drift search is not available when streaming.") | ||
if settings.search_type == "global": | ||
# Call the global search | ||
responses = api.global_search_streaming( | ||
config=self.graphrag_config, | ||
nodes=self.final_nodes, | ||
entities=self.final_entities, | ||
communities=self.final_communities, | ||
community_reports=self.final_community_reports, | ||
community_level=2, | ||
dynamic_community_selection=False, | ||
response_type=settings.response_type, | ||
query=chat_history.messages[-1].content, | ||
) | ||
else: | ||
# Call the local search | ||
responses = api.local_search_streaming( | ||
config=self.graphrag_config, | ||
nodes=self.final_nodes, | ||
entities=self.final_entities, | ||
community_reports=self.final_community_reports, | ||
text_units=self.final_text_units, | ||
relationships=self.final_relationships, | ||
covariates=None, | ||
community_level=2, | ||
response_type=settings.response_type, | ||
query=chat_history.messages[-1].content, | ||
) | ||
async for response in responses: | ||
# the response is either a string (the response) or a dict (the context) | ||
if isinstance(response, str): | ||
# the response is a string, we can wrap it into a StreamingChatMessageContent | ||
cmc = StreamingChatMessageContent(choice_index=0, role=AuthorRole.ASSISTANT, content=response) | ||
yield [cmc] | ||
if isinstance(response, dict): | ||
# the response is a dict, we can add it to a metadata field of a StreamingChatMessageContent | ||
cmc = StreamingChatMessageContent( | ||
choice_index=0, content="", role=AuthorRole.ASSISTANT, metadata={"context": response} | ||
) | ||
yield [cmc] |
29 changes: 29 additions & 0 deletions
29
python/samples/demos/graphrag/SKGraphRag/graphrag_prompt_execution_settings.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Copyright (c) Microsoft. All rights reserved. | ||
|
||
import logging | ||
from typing import Literal | ||
|
||
from semantic_kernel.connectors.ai import PromptExecutionSettings | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class GraphRagPromptExecutionSettings(PromptExecutionSettings): | ||
""" | ||
GraphRagPromptExecutionSettings is a class that inherits from PromptExecutionSettings | ||
and is used to configure the execution settings for a GraphRag prompt. | ||
|
||
Attributes: | ||
response_type: Specifies the type of response expected from the prompt. Default is "Multiple Paragraphs". | ||
Valid values for this can be found in the GraphRag doc. | ||
This value is used in a prompt to determine the format of the response, so has no fixed values. | ||
search_type: Specifies the type of search to be performed. | ||
- "global": see https://github.com/microsoft/graphrag/blob/main/docs/query/global_search.md | ||
- "local": see https://github.com/microsoft/graphrag/blob/main/docs/query/local_search.md | ||
- "drift": see https://github.com/microsoft/graphrag/blob/main/docs/query/drift_search.md | ||
Default is "global". | ||
Drift is not available for streaming completions. | ||
""" | ||
|
||
response_type: str = "Multiple Paragraphs" | ||
search_type: Literal["local", "global", "drift"] = "global" |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question: This embeds the memory into the connector. Is this a recommended approach of doing rag with SK?