Skip to content
This repository was archived by the owner on Dec 16, 2024. It is now read-only.

Commit 1c5771f

Browse files
jerryjliu and Jerry Liu authored
Add ability to retrieve indices from a graph (run-llama#367)
Co-authored-by: Jerry Liu <jerry@robustintelligence.com>
1 parent 11c96de commit 1c5771f

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

gpt_index/composability/graph.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,16 @@ def _get_default_index_registry() -> IndexRegistry:
5454
return index_registry
5555

5656

57+
def _safe_get_index_struct(
    docstore: DocumentStore, index_struct_id: str
) -> IndexStruct:
    """Fetch an index struct from a docstore, verifying its type.

    Looks up ``index_struct_id`` via ``docstore.get_document`` and returns
    the result only when it is an ``IndexStruct`` instance.

    Args:
        docstore: Document store to query.
        index_struct_id: Id of the document expected to be an index struct.

    Returns:
        The document stored under ``index_struct_id``.

    Raises:
        ValueError: If the retrieved document is not an ``IndexStruct``.

    """
    candidate = docstore.get_document(index_struct_id)
    if isinstance(candidate, IndexStruct):
        return candidate
    # Anything else under that id (e.g. a plain document) is rejected.
    raise ValueError("Invalid `index_struct_id` - must be an IndexStruct")
65+
66+
5767
class ComposableGraph:
5868
"""Composable graph."""
5969

@@ -112,6 +122,18 @@ def query(
112122
)
113123
return query_runner.query(query_str, self._index_struct)
114124

125+
def get_index(
    self, index_struct_id: str, index_cls: Type[BaseGPTIndex], **kwargs: Any
) -> BaseGPTIndex:
    """Build an index object for an index struct held in this graph.

    The struct is looked up in the graph's docstore by ``index_struct_id``
    and passed to ``index_cls`` together with the graph's docstore and
    index registry.

    Args:
        index_struct_id: Id of the index struct in the graph's docstore.
        index_cls: Index class to instantiate around the struct.
        **kwargs: Extra keyword arguments forwarded to ``index_cls``.

    Returns:
        The newly constructed ``index_cls`` instance.

    Raises:
        ValueError: If the document stored under ``index_struct_id`` is
            not an ``IndexStruct`` (raised by ``_safe_get_index_struct``).

    """
    struct = _safe_get_index_struct(self._docstore, index_struct_id)
    # Reuse the graph's docstore and registry for the returned index.
    index = index_cls(
        index_struct=struct,
        docstore=self._docstore,
        index_registry=self._index_registry,
        **kwargs
    )
    return index
136+
115137
@classmethod
116138
def load_from_disk(cls, save_path: str, **kwargs: Any) -> "ComposableGraph":
117139
"""Load index from disk.
@@ -135,9 +157,9 @@ def load_from_disk(cls, save_path: str, **kwargs: Any) -> "ComposableGraph":
135157
docstore = DocumentStore.load_from_dict(
136158
result_dict["docstore"], index_registry.type_to_struct
137159
)
138-
index_struct = docstore.get_document(result_dict["index_struct_id"])
139-
if not isinstance(index_struct, IndexStruct):
140-
raise ValueError("Invalid `index_struct_id` - must be an IndexStruct")
160+
index_struct = _safe_get_index_struct(
161+
docstore, result_dict["index_struct_id"]
162+
)
141163
return cls(docstore, index_registry, index_struct, **kwargs)
142164

143165
def save_to_disk(self, save_path: str, **save_kwargs: Any) -> None:

gpt_index/indices/base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def __init__(
9393
self._index_registry = index_registry or IndexRegistry()
9494

9595
if index_struct is not None:
96+
if not isinstance(index_struct, self.index_struct_cls):
97+
raise ValueError(
98+
f"index_struct must be of type {self.index_struct_cls}"
99+
)
96100
self._index_struct = index_struct
97101
else:
98102
documents = cast(Sequence[DOCUMENTS_INPUT], documents)
@@ -230,7 +234,11 @@ def set_doc_id(self, doc_id: str) -> None:
230234
If you wish to delete the index struct, you can use this doc_id.
231235
232236
"""
237+
old_doc_id = self._index_struct.get_doc_id()
233238
self._index_struct.doc_id = doc_id
239+
# Note: we also need to delete old doc_id, and update docstore
240+
self._docstore.delete_document(old_doc_id)
241+
self._docstore.add_documents([self._index_struct])
234242

235243
def get_doc_id(self) -> str:
236244
"""Get doc_id for index struct.

tests/indices/query/test_recursive.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ def test_recursive_query_table_list(
210210
table2 = GPTSimpleKeywordTableIndex(documents[2:3], **table_kwargs)
211211
table1.set_text("table_summary1")
212212
table2.set_text("table_summary2")
213+
table1.set_doc_id("table1")
214+
table2.set_doc_id("table2")
213215

214216
list_index = GPTListIndex([table1, table2], **list_kwargs)
215217
query_str = "World?"
@@ -232,6 +234,11 @@ def test_recursive_query_table_list(
232234
response = graph.query(query_str, query_configs=query_configs)
233235
assert str(response) == ("Test?:This is a test.")
234236

237+
# test graph.get_index
238+
test_table1 = graph.get_index("table1", GPTSimpleKeywordTableIndex)
239+
response = test_table1.query("Hello")
240+
assert str(response) == ("Hello:Hello world.")
241+
235242

236243
@patch.object(TokenTextSplitter, "split_text", side_effect=mock_token_splitter_newline)
237244
@patch.object(LLMPredictor, "predict", side_effect=mock_llmpredictor_predict)

0 commit comments

Comments
 (0)