-
Notifications
You must be signed in to change notification settings - Fork 151
/
Copy path_graphrag_test.py
64 lines (51 loc) · 2.78 KB
/
_graphrag_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# type: ignore
import unittest
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch
from fast_graphrag._graphrag import BaseGraphRAG
from fast_graphrag._models import TAnswer
from fast_graphrag._types import TContext, TQueryResponse
class TestBaseGraphRAG(unittest.IsolatedAsyncioTestCase):
    """Unit tests for ``BaseGraphRAG`` with every collaborating service replaced by a mock."""

    def setUp(self):
        """Create mock services and a ``BaseGraphRAG`` instance wired to use them."""
        self.llm_service = AsyncMock()
        self.chunking_service = AsyncMock()

        extraction = MagicMock()
        # Only the query-time entity extraction is awaited, so it alone needs to be an AsyncMock.
        extraction.extract_entities_from_query = AsyncMock()
        self.information_extraction_service = extraction

        state = AsyncMock()
        # Give both embedding dimensions the same concrete value.
        # NOTE(review): presumably this keeps any embedding-dimension consistency
        # check in BaseGraphRAG satisfied; confirm.
        state.embedding_service.embedding_dim = 1
        state.entity_storage.embedding_dim = 1
        self.state_manager = state

        # Local subclass whose __post_init__ does nothing, so constructing it runs
        # none of the base class's post-init logic (the name suggests that logic
        # validates embeddings).
        @dataclass
        class BaseGraphRAGNoEmbeddingValidation(BaseGraphRAG):
            def __post_init__(self):
                pass  # intentionally skip the base-class post-init

        self.graph_rag = BaseGraphRAGNoEmbeddingValidation(
            working_dir="test_dir",
            domain="test_domain",
            example_queries="test_query",
            entity_types=["type1", "type2"],
        )

        # Point the instance at the mocks created above (same attribute names on both sides).
        for service_name in (
            "llm_service",
            "chunking_service",
            "information_extraction_service",
            "state_manager",
        ):
            setattr(self.graph_rag, service_name, getattr(self, service_name))

    async def test_async_insert(self):
        # async_insert must invoke each pipeline stage exactly once:
        # chunking, new-chunk filtering, extraction, then upsert.
        chunker = self.chunking_service
        extractor = self.information_extraction_service
        state = self.state_manager

        chunker.extract = AsyncMock(return_value=["chunked_data"])
        state.filter_new_chunks = AsyncMock(return_value=["new_chunks"])
        # Extraction is stubbed with a plain (non-awaitable) MagicMock.
        extractor.extract = MagicMock(return_value=["subgraph"])
        state.upsert = AsyncMock()

        await self.graph_rag.async_insert("test_content", {"meta": "data"})

        for stage in (
            chunker.extract,
            state.filter_new_chunks,
            extractor.extract,
            state.upsert,
        ):
            stage.assert_called_once()

    @patch("fast_graphrag._graphrag.format_and_send_prompt", new_callable=AsyncMock)
    async def test_async_query(self, mock_prompt):
        # async_query must look up query entities and context once, prompt once,
        # and return a TQueryResponse. The prompt helper is patched where
        # fast_graphrag._graphrag looks it up.
        extractor = self.information_extraction_service
        state = self.state_manager

        extractor.extract_entities_from_query = AsyncMock(return_value=["entities"])
        state.get_context = AsyncMock(return_value=TContext([], [], []))
        mock_prompt.return_value = (TAnswer(answer="response"), None)

        result = await self.graph_rag.async_query("test_query")

        extractor.extract_entities_from_query.assert_called_once()
        state.get_context.assert_called_once()
        mock_prompt.assert_called_once()
        self.assertIsInstance(result, TQueryResponse)
if __name__ == "__main__":
    # Run the suite when this file is executed directly rather than via test discovery.
    unittest.main()