-
Notifications
You must be signed in to change notification settings - Fork 2.5k
/
Copy pathvector_store_config.py
94 lines (74 loc) · 3.54 KB
/
vector_store_config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
"""Parameterization settings for the default configuration."""
from pydantic import BaseModel, Field, model_validator
import graphrag.config.defaults as defs
from graphrag.vector_stores.factory import VectorStoreType
class VectorStoreConfig(BaseModel):
    """The default configuration section for Vector Store.

    After the fields are populated, the model validator checks that the
    connection settings match the selected ``type``: ``db_uri`` is only
    accepted for LanceDB (and is defaulted when left blank), while ``url``
    is mandatory for the Azure AI Search, CosmosDB and DocumentDB types and
    rejected for LanceDB.
    """

    type: str = Field(
        description="The vector store type to use.",
        default=defs.VECTOR_STORE_TYPE,
    )

    db_uri: str | None = Field(description="The database URI to use.", default=None)

    def _validate_db_uri(self) -> None:
        """Validate the database URI.

        For the LanceDB type a missing or blank ``db_uri`` is replaced with
        ``defs.VECTOR_STORE_DB_URI``.

        Raises
        ------
        ValueError
            If a non-blank ``db_uri`` is supplied for any non-LanceDB type.
        """
        is_lancedb = self.type == VectorStoreType.LanceDB.value
        db_uri_is_blank = self.db_uri is None or self.db_uri.strip() == ""

        if is_lancedb and db_uri_is_blank:
            self.db_uri = defs.VECTOR_STORE_DB_URI
            return

        if not is_lancedb and not db_uri_is_blank:
            msg = "vector_store.db_uri is only used when vector_store.type == lancedb. Please rerun `graphrag init` and select the correct vector store type."
            raise ValueError(msg)

    url: str | None = Field(
        description="The database URL when type == azure_ai_search.",
        default=None,
    )

    def _validate_url(self) -> None:
        """Validate the database URL.

        Raises
        ------
        ValueError
            If ``url`` is missing or blank for a type that requires it, or if
            a non-blank ``url`` is supplied for the LanceDB type.
        """
        url_is_blank = self.url is None or self.url.strip() == ""

        # `type` is a plain string field, so compare against the enum *values*
        # (as `_validate_db_uri` does) instead of the enum members themselves.
        # Each entry pairs the stored type value with the label used in the
        # error message.
        url_required_types = (
            (VectorStoreType.AzureAISearch.value, "azure_ai_search"),
            (VectorStoreType.CosmosDB.value, "cosmos_db"),
            (VectorStoreType.DocumentDB.value, "document_db"),
        )
        for type_value, label in url_required_types:
            if self.type == type_value and url_is_blank:
                msg = f"vector_store.url is required when vector_store.type == {label}. Please rerun `graphrag init` and select the correct vector store type."
                raise ValueError(msg)

        if self.type == VectorStoreType.LanceDB.value and not url_is_blank:
            msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db or vector_store.type == document_db. Please rerun `graphrag init` and select the correct vector store type."
            raise ValueError(msg)

    api_key: str | None = Field(
        description="The database API key when type == azure_ai_search.",
        default=None,
    )

    audience: str | None = Field(
        description="The database audience when type == azure_ai_search.",
        default=None,
    )

    container_name: str = Field(
        description="The container name to use.",
        default=defs.VECTOR_STORE_CONTAINER_NAME,
    )

    database_name: str | None = Field(
        description="The database name to use when type == cosmos_db or document_db.",
        default=None,
    )

    overwrite: bool = Field(
        description="Overwrite the existing data.", default=defs.VECTOR_STORE_OVERWRITE
    )

    @model_validator(mode="after")
    def _validate_model(self) -> "VectorStoreConfig":
        """Validate the model.

        Runs the ``db_uri`` check first (which may fill in the default URI)
        and then the ``url`` check, returning the validated instance.
        """
        self._validate_db_uri()
        self._validate_url()
        return self