forked from microsoft/autogen
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Graphrag integration (microsoft#4612)
* add initial global search draft * add graphrag dep * fix local search embedding * linting * add from config constructor * remove draft notebook * update config factory and add docstrings * add graphrag sample * add sample prompts * update readme * update deps * Add API docs * Update python/samples/agentchat_graphrag/requirements.txt * Update python/samples/agentchat_graphrag/requirements.txt * update docstrings with snippet and doc ref * lint * improve set up instructions in docstring * lint * update lock * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * add unit tests * update lock * update uv lock * add docstring newlines * stubs and typing on graphrag tests * fix docstrings * fix mypy error * + linting and type fixes * type fix graphrag sample * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * Update python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_local_search.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * Update python/samples/agentchat_graphrag/requirements.txt Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * update overrides * fix docstring client imports * additional docstring fix * add docstring missing import * use openai and fix db path * use console for displaying messages * add model config and gitignore * update readme * lint * Update python/samples/agentchat_graphrag/README.md * Update python/samples/agentchat_graphrag/README.md * Comment remaining azure config --------- Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
- Loading branch information
1 parent
8efe0c4
commit 95bd514
Showing
22 changed files
with
2,365 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 8 additions & 0 deletions
8
.../packages/autogen-core/docs/src/reference/python/autogen_ext.tools.graphrag.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
autogen\_ext.tools.graphrag | ||
=========================== | ||
|
||
|
||
.. automodule:: autogen_ext.tools.graphrag | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
python/packages/autogen-ext/src/autogen_ext/tools/graphrag/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from ._config import ( | ||
GlobalContextConfig, | ||
GlobalDataConfig, | ||
LocalContextConfig, | ||
LocalDataConfig, | ||
MapReduceConfig, | ||
SearchConfig, | ||
) | ||
from ._global_search import GlobalSearchTool, GlobalSearchToolArgs, GlobalSearchToolReturn | ||
from ._local_search import LocalSearchTool, LocalSearchToolArgs, LocalSearchToolReturn | ||
|
||
__all__ = [ | ||
"GlobalSearchTool", | ||
"LocalSearchTool", | ||
"GlobalDataConfig", | ||
"LocalDataConfig", | ||
"GlobalContextConfig", | ||
"GlobalSearchToolArgs", | ||
"GlobalSearchToolReturn", | ||
"LocalContextConfig", | ||
"LocalSearchToolArgs", | ||
"LocalSearchToolReturn", | ||
"MapReduceConfig", | ||
"SearchConfig", | ||
] |
59 changes: 59 additions & 0 deletions
59
python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_config.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class DataConfig(BaseModel): | ||
input_dir: str | ||
entity_table: str = "create_final_nodes" | ||
entity_embedding_table: str = "create_final_entities" | ||
community_level: int = 2 | ||
|
||
|
||
class GlobalDataConfig(DataConfig): | ||
community_table: str = "create_final_communities" | ||
community_report_table: str = "create_final_community_reports" | ||
|
||
|
||
class LocalDataConfig(DataConfig): | ||
relationship_table: str = "create_final_relationships" | ||
text_unit_table: str = "create_final_text_units" | ||
|
||
|
||
class ContextConfig(BaseModel): | ||
max_data_tokens: int = 8000 | ||
|
||
|
||
class GlobalContextConfig(ContextConfig): | ||
use_community_summary: bool = False | ||
shuffle_data: bool = True | ||
include_community_rank: bool = True | ||
min_community_rank: int = 0 | ||
community_rank_name: str = "rank" | ||
include_community_weight: bool = True | ||
community_weight_name: str = "occurrence weight" | ||
normalize_community_weight: bool = True | ||
max_data_tokens: int = 12000 | ||
|
||
|
||
class LocalContextConfig(ContextConfig): | ||
text_unit_prop: float = 0.5 | ||
community_prop: float = 0.25 | ||
include_entity_rank: bool = True | ||
rank_description: str = "number of relationships" | ||
include_relationship_weight: bool = True | ||
relationship_ranking_attribute: str = "rank" | ||
|
||
|
||
class MapReduceConfig(BaseModel): | ||
map_max_tokens: int = 1000 | ||
map_temperature: float = 0.0 | ||
reduce_max_tokens: int = 2000 | ||
reduce_temperature: float = 0.0 | ||
allow_general_knowledge: bool = False | ||
json_mode: bool = False | ||
response_type: str = "multiple paragraphs" | ||
|
||
|
||
class SearchConfig(BaseModel): | ||
max_tokens: int = 1500 | ||
temperature: float = 0.0 | ||
response_type: str = "multiple paragraphs" |
214 changes: 214 additions & 0 deletions
214
python/packages/autogen-ext/src/autogen_ext/tools/graphrag/_global_search.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
# mypy: disable-error-code="no-any-unimported,misc" | ||
from pathlib import Path | ||
|
||
import pandas as pd | ||
import tiktoken | ||
from autogen_core import CancellationToken | ||
from autogen_core.tools import BaseTool | ||
from graphrag.config.config_file_loader import load_config_from_file | ||
from graphrag.query.indexer_adapters import ( | ||
read_indexer_communities, | ||
read_indexer_entities, | ||
read_indexer_reports, | ||
) | ||
from graphrag.query.llm.base import BaseLLM | ||
from graphrag.query.llm.get_client import get_llm | ||
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext | ||
from graphrag.query.structured_search.global_search.search import GlobalSearch | ||
from pydantic import BaseModel, Field | ||
|
||
from ._config import GlobalContextConfig as ContextConfig | ||
from ._config import GlobalDataConfig as DataConfig | ||
from ._config import MapReduceConfig | ||
|
||
_default_context_config = ContextConfig() | ||
_default_mapreduce_config = MapReduceConfig() | ||
|
||
|
||
class GlobalSearchToolArgs(BaseModel): | ||
query: str = Field(..., description="The user query to perform global search on.") | ||
|
||
|
||
class GlobalSearchToolReturn(BaseModel): | ||
answer: str | ||
|
||
|
||
class GlobalSearchTool(BaseTool[GlobalSearchToolArgs, GlobalSearchToolReturn]): | ||
"""Enables running GraphRAG global search queries as an AutoGen tool. | ||
This tool allows you to perform semantic search over a corpus of documents using the GraphRAG framework. | ||
The search combines graph-based document relationships with semantic embeddings to find relevant information. | ||
.. note:: | ||
This tool requires the :code:`graphrag` extra for the :code:`autogen-ext` package. | ||
To install: | ||
.. code-block:: bash | ||
pip install -U "autogen-agentchat" "autogen-ext[graphrag]" | ||
Before using this tool, you must complete the GraphRAG setup and indexing process: | ||
1. Follow the GraphRAG documentation to initialize your project and settings | ||
2. Configure and tune your prompts for the specific use case | ||
3. Run the indexing process to generate the required data files | ||
4. Ensure you have the settings.yaml file from the setup process | ||
Please refer to the [GraphRAG documentation](https://microsoft.github.io/graphrag/) | ||
for detailed instructions on completing these prerequisite steps. | ||
Example usage with AssistantAgent: | ||
.. code-block:: python | ||
import asyncio | ||
from autogen_ext.models.openai import OpenAIChatCompletionClient | ||
from autogen_agentchat.ui import Console | ||
from autogen_ext.tools.graphrag import GlobalSearchTool | ||
from autogen_agentchat.agents import AssistantAgent | ||
async def main(): | ||
# Initialize the OpenAI client | ||
openai_client = OpenAIChatCompletionClient( | ||
model="gpt-4o-mini", | ||
api_key="<api-key>", | ||
) | ||
# Set up global search tool | ||
global_tool = GlobalSearchTool.from_settings(settings_path="./settings.yaml") | ||
# Create assistant agent with the global search tool | ||
assistant_agent = AssistantAgent( | ||
name="search_assistant", | ||
tools=[global_tool], | ||
model_client=openai_client, | ||
system_message=( | ||
"You are a tool selector AI assistant using the GraphRAG framework. " | ||
"Your primary task is to determine the appropriate search tool to call based on the user's query. " | ||
"For broader, abstract questions requiring a comprehensive understanding of the dataset, call the 'global_search' function." | ||
), | ||
) | ||
# Run a sample query | ||
query = "What is the overall sentiment of the community reports?" | ||
await Console(assistant_agent.run_stream(task=query)) | ||
if __name__ == "__main__": | ||
asyncio.run(main()) | ||
""" | ||
|
||
def __init__( | ||
self, | ||
token_encoder: tiktoken.Encoding, | ||
llm: BaseLLM, | ||
data_config: DataConfig, | ||
context_config: ContextConfig = _default_context_config, | ||
mapreduce_config: MapReduceConfig = _default_mapreduce_config, | ||
): | ||
super().__init__( | ||
args_type=GlobalSearchToolArgs, | ||
return_type=GlobalSearchToolReturn, | ||
name="global_search_tool", | ||
description="Perform a global search with given parameters using graphrag.", | ||
) | ||
# Use the provided LLM | ||
self._llm = llm | ||
|
||
# Load parquet files | ||
community_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.community_table}.parquet") # type: ignore | ||
entity_df: pd.DataFrame = pd.read_parquet(f"{data_config.input_dir}/{data_config.entity_table}.parquet") # type: ignore | ||
report_df: pd.DataFrame = pd.read_parquet( # type: ignore | ||
f"{data_config.input_dir}/{data_config.community_report_table}.parquet" | ||
) | ||
entity_embedding_df: pd.DataFrame = pd.read_parquet( # type: ignore | ||
f"{data_config.input_dir}/{data_config.entity_embedding_table}.parquet" | ||
) | ||
|
||
communities = read_indexer_communities(community_df, entity_df, report_df) | ||
reports = read_indexer_reports(report_df, entity_df, data_config.community_level) | ||
entities = read_indexer_entities(entity_df, entity_embedding_df, data_config.community_level) | ||
|
||
context_builder = GlobalCommunityContext( | ||
community_reports=reports, | ||
communities=communities, | ||
entities=entities, | ||
token_encoder=token_encoder, | ||
) | ||
|
||
context_builder_params = { | ||
"use_community_summary": context_config.use_community_summary, | ||
"shuffle_data": context_config.shuffle_data, | ||
"include_community_rank": context_config.include_community_rank, | ||
"min_community_rank": context_config.min_community_rank, | ||
"community_rank_name": context_config.community_rank_name, | ||
"include_community_weight": context_config.include_community_weight, | ||
"community_weight_name": context_config.community_weight_name, | ||
"normalize_community_weight": context_config.normalize_community_weight, | ||
"max_tokens": context_config.max_data_tokens, | ||
"context_name": "Reports", | ||
} | ||
|
||
map_llm_params = { | ||
"max_tokens": mapreduce_config.map_max_tokens, | ||
"temperature": mapreduce_config.map_temperature, | ||
"response_format": {"type": "json_object"}, | ||
} | ||
|
||
reduce_llm_params = { | ||
"max_tokens": mapreduce_config.reduce_max_tokens, | ||
"temperature": mapreduce_config.reduce_temperature, | ||
} | ||
|
||
self._search_engine = GlobalSearch( | ||
llm=self._llm, | ||
context_builder=context_builder, | ||
token_encoder=token_encoder, | ||
max_data_tokens=context_config.max_data_tokens, | ||
map_llm_params=map_llm_params, | ||
reduce_llm_params=reduce_llm_params, | ||
allow_general_knowledge=mapreduce_config.allow_general_knowledge, | ||
json_mode=mapreduce_config.json_mode, | ||
context_builder_params=context_builder_params, | ||
concurrent_coroutines=32, | ||
response_type=mapreduce_config.response_type, | ||
) | ||
|
||
async def run(self, args: GlobalSearchToolArgs, cancellation_token: CancellationToken) -> GlobalSearchToolReturn: | ||
result = await self._search_engine.asearch(args.query) | ||
assert isinstance(result.response, str), "Expected response to be a string" | ||
return GlobalSearchToolReturn(answer=result.response) | ||
|
||
@classmethod | ||
def from_settings(cls, settings_path: str | Path) -> "GlobalSearchTool": | ||
"""Create a GlobalSearchTool instance from GraphRAG settings file. | ||
Args: | ||
settings_path: Path to the GraphRAG settings.yaml file | ||
Returns: | ||
An initialized GlobalSearchTool instance | ||
""" | ||
# Load GraphRAG config | ||
config = load_config_from_file(settings_path) | ||
|
||
# Initialize token encoder | ||
token_encoder = tiktoken.get_encoding(config.encoding_model) | ||
|
||
# Initialize LLM using graphrag's get_client | ||
llm = get_llm(config) | ||
|
||
# Create data config from storage paths | ||
data_config = DataConfig( | ||
input_dir=str(Path(config.storage.base_dir)), | ||
) | ||
|
||
return cls( | ||
token_encoder=token_encoder, | ||
llm=llm, | ||
data_config=data_config, | ||
context_config=_default_context_config, | ||
mapreduce_config=_default_mapreduce_config, | ||
) |
Oops, something went wrong.