-
Notifications
You must be signed in to change notification settings - Fork 149
/
Copy path_models_test.py
126 lines (99 loc) · 4.41 KB
/
_models_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# type: ignore
import unittest
from pydantic import ValidationError
from fast_graphrag._models import (
TEditRelation,
TEditRelationList,
TQueryEntities,
dump_to_csv,
dump_to_reference_list,
)
from fast_graphrag._types import TEntity
class TestModels(unittest.TestCase):
    """Tests for the pydantic models and the CSV dump helper in ``fast_graphrag._models``."""

    def test_tqueryentities(self):
        """Named entities are upper-cased, generic ones are kept verbatim, bad input is rejected."""
        query_entities = TQueryEntities(named=["Entity1", "Entity2"], generic=["Generic1", "Generic2"])
        self.assertEqual(query_entities.named, ["ENTITY1", "ENTITY2"])
        self.assertEqual(query_entities.generic, ["Generic1", "Generic2"])

        # Exercise the failure modes with the model's real field names (``named`` / ``generic``),
        # so each check fails for the reason it states rather than incidentally.
        with self.assertRaises(ValidationError):
            TQueryEntities()  # no fields supplied at all
        with self.assertRaises(ValidationError):
            TQueryEntities(named=["Entity1"], generic="Generic1")  # a bare str where a list is expected

    def test_teditrelationship(self):
        """``TEditRelation`` stores the merged fact ids and the combined description as given."""
        edit_relationship = TEditRelation(ids=[1, 2], description="Combined relationship description")
        self.assertEqual(edit_relationship.ids, [1, 2])
        self.assertEqual(edit_relationship.description, "Combined relationship description")

    def test_teditrelationshiplist(self):
        """The ``grouped_facts`` constructor keyword populates the ``groups`` attribute."""
        edit_relationship = TEditRelation(ids=[1, 2], description="Combined relationship description")
        edit_relationship_list = TEditRelationList(grouped_facts=[edit_relationship])
        self.assertEqual(edit_relationship_list.groups, [edit_relationship])

    def test_dump_to_csv(self):
        """Extra keyword columns are appended after the selected fields, header included."""
        data = [TEntity(name="Sample name", type="SAMPLE TYPE", description="Sample description")]
        fields = ["name", "type"]
        values = {"score": [0.9]}
        csv_output = dump_to_csv(data, fields, with_header=True, **values)
        expected_output = ["name\ttype\tscore", "Sample name\tSAMPLE TYPE\t0.9"]
        self.assertEqual(csv_output, expected_output)


class TestDumpToReferenceList(unittest.TestCase):
    """Tests for ``dump_to_reference_list``: ``[i] item`` entries numbered from 1, each followed by a separator."""

    def test_empty_list(self):
        """An empty input yields an empty list of references."""
        result = dump_to_reference_list([])
        self.assertEqual(result, [])

    def test_single_element(self):
        """A single item is numbered 1 and terminated by the default separator."""
        result = dump_to_reference_list(["item"])
        self.assertEqual(result, ["[1] item\n=====\n\n"])

    def test_multiple_elements(self):
        """Items keep their input order and are numbered consecutively."""
        items = ["item1", "item2", "item3"]
        want = ["[1] item1\n=====\n\n", "[2] item2\n=====\n\n", "[3] item3\n=====\n\n"]
        self.assertEqual(dump_to_reference_list(items), want)

    def test_custom_separator(self):
        """A separator given as the second positional argument replaces the default."""
        items = ["item1", "item2"]
        sep = " | "
        want = ["[1] item1 | ", "[2] item2 | "]
        self.assertEqual(dump_to_reference_list(items, sep), want)


class TestDumpToCsv(unittest.TestCase):
    """Tests for ``dump_to_csv`` driven by a minimal two-attribute record type."""

    class _Record:
        """Plain object exposing ``field1`` / ``field2`` attributes for ``dump_to_csv`` to read.

        Defined once here instead of being redeclared inside every test method.
        """

        def __init__(self, field1, field2):
            self.field1 = field1
            self.field2 = field2

    def test_empty_data(self):
        """No records produce no rows (and no header by default)."""
        self.assertEqual(dump_to_csv([], ["field1", "field2"]), [])

    def test_single_element(self):
        """One record becomes one tab-separated row."""
        data = [self._Record("value1", "value2")]
        expected = ["value1\tvalue2"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"]), expected)

    def test_multiple_elements(self):
        """Each record becomes its own row, in input order."""
        data = [self._Record("value1", "value2"), self._Record("value3", "value4")]
        expected = ["value1\tvalue2", "value3\tvalue4"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"]), expected)

    def test_with_header(self):
        """``with_header=True`` prepends a row of the field names."""
        data = [self._Record("value1", "value2")]
        expected = ["field1\tfield2", "value1\tvalue2"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"], with_header=True), expected)

    def test_custom_separator(self):
        """The ``separator`` keyword replaces the default tab between columns."""
        data = [self._Record("value1", "value2")]
        expected = ["value1 | value2"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"], separator=" | "), expected)

    def test_additional_values(self):
        """Extra keyword lists are appended as columns, aligned with the records by position."""
        data = [self._Record("value1", "value2")]
        expected = ["value1\tvalue2\tvalue3"]
        self.assertEqual(dump_to_csv(data, ["field1", "field2"], value3=["value3"]), expected)


if __name__ == "__main__":
    # Allow running this module directly (``python _models_test.py``); verbosity=1 is
    # unittest's default and is still overridden by ``-v`` on the command line.
    unittest.main(verbosity=1)