Skip to content

Commit

Permalink
Merge pull request vanna-ai#411 from zyclove/main
Browse files Browse the repository at this point in the history
【feat】add auto create_index_if_not_exists
  • Loading branch information
zainhoda authored May 6, 2024
2 parents 7225854 + 0e6689f commit 7d04d3e
Showing 1 changed file with 98 additions and 4 deletions.
102 changes: 98 additions & 4 deletions src/vanna/opensearch/opensearch_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,77 @@ def __init__(self, config=None):
self.document_index = document_index
self.ddl_index = ddl_index
self.question_sql_index = question_sql_index
print("OpenSearch_VectorStore initialized with document_index: ", document_index, " ddl_index: ", ddl_index, " question_sql_index: ", question_sql_index)
print("OpenSearch_VectorStore initialized with document_index: ",
document_index, " ddl_index: ", ddl_index, " question_sql_index: ",
question_sql_index)

document_index_settings = {
"settings": {
"index": {
"number_of_shards": 6,
"number_of_replicas": 2
}
},
"mappings": {
"properties": {
"question": {
"type": "text",
},
"doc": {
"type": "text",
}
}
}
}

ddl_index_settings = {
"settings": {
"index": {
"number_of_shards": 6,
"number_of_replicas": 2
}
},
"mappings": {
"properties": {
"ddl": {
"type": "text",
},
"doc": {
"type": "text",
}
}
}
}

question_sql_index_settings = {
"settings": {
"index": {
"number_of_shards": 6,
"number_of_replicas": 2
}
},
"mappings": {
"properties": {
"question": {
"type": "text",
},
"sql": {
"type": "text",
}
}
}
}

if config is not None and "es_document_index_settings" in config:
document_index_settings = config["es_document_index_settings"]
if config is not None and "es_ddl_index_settings" in config:
ddl_index_settings = config["es_ddl_index_settings"]
if config is not None and "es_question_sql_index_settings" in config:
question_sql_index_settings = config["es_question_sql_index_settings"]

self.document_index_settings = document_index_settings
self.ddl_index_settings = ddl_index_settings
self.question_sql_index_settings = question_sql_index_settings

es_urls = None
if config is not None and "es_urls" in config:
Expand Down Expand Up @@ -85,6 +155,9 @@ def __init__(self, config=None):
else:
max_retries = 10

print("OpenSearch_VectorStore initialized with es_urls: ", es_urls,
" host: ", host, " port: ", port, " ssl: ", ssl, " verify_certs: ",
verify_certs, " timeout: ", timeout, " max_retries: ", max_retries)
if es_urls is not None:
# Initialize the OpenSearch client by passing a list of URLs
self.client = OpenSearch(
Expand Down Expand Up @@ -112,25 +185,47 @@ def __init__(self, config=None):
headers=headers
)

print("OpenSearch_VectorStore initialized with client over ")

# Run a simple query to check the connection
try:
print('Connected to OpenSearch cluster:')
info = self.client.info()
print('OpenSearch cluster info:', info)
except Exception as e:
print('Error connecting to OpenSearch cluster:', e)

# Create the indices if they don't exist
# self.create_index()
self.create_index_if_not_exists(self.document_index,
self.document_index_settings)
self.create_index_if_not_exists(self.ddl_index, self.ddl_index_settings)
self.create_index_if_not_exists(self.question_sql_index,
self.question_sql_index_settings)

def create_index(self):
    """Create the document, DDL and question-SQL indices with default settings.

    Each index is handled independently: an index that already exists is
    skipped, and a failure on one index is printed and does not stop the
    remaining indices from being processed.

    Returns:
        None
    """
    for index in (self.document_index, self.ddl_index,
                  self.question_sql_index):
        try:
            # Check explicitly instead of treating every exception from
            # create() as "already exists" -- connection or permission
            # errors must not be reported as an existing index.
            if self.client.indices.exists(index=index):
                print(f"opensearch index {index} already exists")
                continue
            self.client.indices.create(index=index)
        except Exception as e:
            # Best-effort: report the failure and move on to the next index.
            print("Error creating index: ", e)

def create_index_if_not_exists(self, index_name: str,
                               index_settings: dict) -> bool:
    """Create ``index_name`` with ``index_settings`` unless it already exists.

    Args:
        index_name: Name of the index to look up and, if missing, create.
        index_settings: Request body passed to the index-creation call.

    Returns:
        True if the index was created by this call. False if it already
        existed, or if the existence check or the creation raised (the
        error is printed rather than propagated).
    """
    # Keep the existence check in its own try block so that a failure here
    # is not reported as a creation failure.
    try:
        if self.client.indices.exists(index=index_name):
            print(f"Index {index_name} already exists.")
            return False
    except Exception as e:
        print(f"Error checking index: {index_name} ", e)
        return False

    try:
        print(f"Index {index_name} does not exist. Creating...")
        self.client.indices.create(index=index_name, body=index_settings)
    except Exception as e:
        print(f"Error creating index: {index_name} ", e)
        return False
    return True

def add_ddl(self, ddl: str, **kwargs) -> str:
# Assuming that you have a DDL index in your OpenSearch
id = str(uuid.uuid4()) + "-ddl"
Expand Down Expand Up @@ -278,7 +373,6 @@ def generate_embedding(self, data: str, **kwargs) -> list[float]:
# opensearch doesn't need to generate embeddings
pass


# OpenSearch_VectorStore.__init__(self, config={'es_urls':
# "https://opensearch-node.test.com:9200", 'es_encoded_base64': True, 'es_user':
# "admin", 'es_password': "admin", 'es_verify_certs': True})
Expand Down

0 comments on commit 7d04d3e

Please sign in to comment.