Skip to content

Commit

Permalink
ES与OpenSearch重构数据库列表及表列表-v0.8-beta (#2780)
Browse files Browse the repository at this point in the history
  • Loading branch information
feiazifeiazi authored Sep 19, 2024
1 parent 5a6871d commit b4c4827
Show file tree
Hide file tree
Showing 3 changed files with 350 additions and 69 deletions.
214 changes: 152 additions & 62 deletions sql/engines/elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class ElasticsearchEngineBase(EngineBase):
"""

def __init__(self, instance=None):
self.conn = None # type: Elasticsearch # 使用类型注释来显式提示类型
self.db_separator = "__" # 设置分隔符
# 限制只能2种支持的子类
self.search_name = ["Elasticsearch", "OpenSearch"]
Expand All @@ -105,16 +106,11 @@ def get_all_databases(self):
# 获取所有的别名,没有别名就是本身。
indices = self.conn.indices.get_alias(index=self.db_name)
database_names = set()
if self.db_name == "*":
database_names.add("system") # 系统表名使用的库名
database_names.add("system") # 系统表名使用的库名
for index_name in indices.keys():
if self.db_separator in index_name:
db_name = index_name.split(self.db_separator)[0]
database_names.add(db_name)
elif index_name.startswith(".kibana_"):
database_names.add("system_kibana")
elif index_name.startswith(".internal."):
database_names.add("system_internal")
database_names.add("other") # 表名没有__时,使用的库名
database_names_sorted = sorted(database_names)
return ResultSet(rows=database_names_sorted)
Expand All @@ -123,17 +119,17 @@ def get_all_databases(self):
raise Exception(f"获取数据库时出错: {str(e)}")

def get_all_tables(self, db_name, **kwargs):
"""根据给定的数据库名获取所有相关的表名"""
"""根据给定的数据库名获取所有相关的表名
以点开头的表名,不返回。此为系统表,官方不让查询了。
"""
try:
self.get_connection()
indices = self.conn.indices.get_alias(index=self.db_name)
tables = set()

db_mapping = {
"system_kibana": ".kibana_",
"system_internal": ".internal.",
"system": ".",
"other": "other",
"system": "",
"other": "",
}
# 根据分隔符分隔的库名
if db_name not in db_mapping:
Expand All @@ -142,27 +138,28 @@ def get_all_tables(self, db_name, **kwargs):
index for index in indices.keys() if index.startswith(index_prefix)
]
else:
# 处理系统表,和other,循环db_mapping.items() 很难实现。
# 处理系统表,和other
if db_name == "system":
# 将系统的API作为表名
tables.add("/_cat/indices/" + self.db_name)
tables.add("/_cat/nodes")
tables.add("/_security/role")
tables.add("/_security/user")

for index_name in indices.keys():
if index_name.startswith(".kibana_") | index_name.startswith(
".kibana-"
):
if db_name == "system_kibana":
tables.add(index_name)
continue
elif index_name.startswith(".internal."):
if db_name == "system_internal":
tables.add(index_name)
continue
elif index_name.startswith("."):
if db_name == "system":
tables.add(index_name)
if index_name.startswith("."):
# if db_name == "system":
# tables.add(index_name)
continue
elif index_name.startswith(db_name):
tables.add(index_name)
if db_name == "system":
tables.add("/_cat/indices/" + db_name)
continue
elif self.db_separator in index_name:
continue
separator_db_name = index_name.split(self.db_separator)[0]
if db_name == "system":
tables.add("/_cat/indices/" + separator_db_name)
else:
if db_name == "other":
tables.add(index_name)
Expand All @@ -174,47 +171,53 @@ def get_all_tables(self, db_name, **kwargs):
def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
"""获取所有字段"""
result_set = ResultSet(full_sql=f"{tb_name}/_mapping")
try:
self.get_connection()
mapping = self.conn.indices.get_mapping(index=tb_name)
properties = (
mapping.get(tb_name, {}).get("mappings", {}).get("properties", None)
)
# 返回字段名
result_set.column_list = ["column_name"]
if properties is None:
result_set.rows = [("无")]
else:
result_set.rows = list(properties.keys())
if tb_name.startswith(("/", "_")):
return result_set
except Exception as e:
raise Exception(f"获取字段时出错: {str(e)}")
else:
try:
self.get_connection()
mapping = self.conn.indices.get_mapping(index=tb_name)
properties = (
mapping.get(tb_name, {}).get("mappings", {}).get("properties", None)
)
# 返回字段名
result_set.column_list = ["column_name"]
if properties is None:
result_set.rows = [("无")]
else:
result_set.rows = list(properties.keys())
return result_set
except Exception as e:
raise Exception(f"获取字段时出错: {str(e)}")

def describe_table(self, db_name, tb_name, **kwargs):
"""表结构"""
result_set = ResultSet(full_sql=f"{tb_name}/_mapping")
try:
self.get_connection()
mapping = self.conn.indices.get_mapping(index=tb_name)
properties = (
mapping.get(tb_name, {}).get("mappings", {}).get("properties", None)
)
# 创建包含字段名、类型和其他信息的列表结构
result_set.column_list = ["column_name", "type", "fields"]
if properties is None:
result_set.rows = [("无", "无", "无")]
else:
result_set.rows = [
(
column,
details.get("type"),
json.dumps(details.get("fields", {})),
)
for column, details in properties.items()
]
if tb_name.startswith(("/", "_")):
return result_set
except Exception as e:
raise Exception(f"获取字段时出错: {str(e)}")
else:
try:
self.get_connection()
mapping = self.conn.indices.get_mapping(index=tb_name)
properties = (
mapping.get(tb_name, {}).get("mappings", {}).get("properties", None)
)
# 创建包含字段名、类型和其他信息的列表结构
result_set.column_list = ["column_name", "type", "fields"]
if properties is None:
result_set.rows = [("无", "无", "无")]
else:
result_set.rows = [
(
column,
details.get("type"),
json.dumps(details.get("fields", {})),
)
for column, details in properties.items()
]
return result_set
except Exception as e:
raise Exception(f"获取字段时出错: {str(e)}")

def query_check(self, db_name=None, sql=""):
"""语句检查"""
Expand Down Expand Up @@ -298,7 +301,7 @@ def query(
self.get_connection()
# 管理查询处理
if query_params.path.startswith("/_cat/indices"):
# v这个参数用显示标题,需要加上。
# v这个参数用显示标题,需要加上。 opensearch 需要字符串的true
if "v" not in query_params.params:
query_params.params["v"] = "true"
response = self.conn.cat.indices(
Expand All @@ -318,6 +321,10 @@ def query(
result_set.column_list = []
result_set.rows = []
result_set.affected_rows = 0
elif query_params.path.startswith("/_security/role"):
result_set = self._security_role(sql, query_params)
elif query_params.path.startswith("/_security/user"):
result_set = self._security_user(sql, query_params)
elif query_params.sql and self.name == "Elasticsearch":
query_body = {"query": query_params.sql}
response = self.conn.sql.query(body=query_body)
Expand Down Expand Up @@ -409,6 +416,12 @@ def query(
except Exception as e:
raise Exception(f"执行查询时出错: {str(e)}")

def _security_role(self, sql, query_params: QueryParamsSearch):
    """Role query hook; subclasses are expected to provide the implementation.

    NOTE(review): this base version contains nothing but the docstring, so it
    implicitly returns None when a subclass does not override it.
    """

def _security_user(self, sql, query_params: QueryParamsSearch):
    """User query hook; subclasses are expected to provide the implementation.

    NOTE(review): this base version contains nothing but the docstring, so it
    implicitly returns None when a subclass does not override it.
    """

def parse_cat_indices_response(self, response_text):
"""解析cat indices结果"""
# 将响应文本按行分割
Expand Down Expand Up @@ -479,6 +492,12 @@ def parse_es_select_query_to_query_params(
index_pattern = path_parts[3]
if not index_pattern:
index_pattern = "*"
elif path.startswith("/_security/role"):
path_parts = path.split("/")
index_pattern = "*"
elif path.startswith("/_security/user"):
path_parts = path.split("/")
index_pattern = "*"
elif "/_search" in path:
# 默认情况,处理常规索引路径
# 提取索引名称
Expand Down Expand Up @@ -1053,11 +1072,20 @@ def get_connection(self, db_name=None):
raise Exception("Elasticsearch 连接无法建立。")
return self.conn

def _security_role(self, sql, query_params: QueryParamsSearch):
    """Role query method (TODO: not implemented yet).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("此方法暂未实现。")

def _security_user(self, sql, query_params: QueryParamsSearch):
    """User query method (TODO: not implemented yet).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("此方法暂未实现。")


class OpenSearchEngine(ElasticsearchEngineBase):
"""OpenSearch 引擎实现"""

def __init__(self, instance=None):
self.conn = None # type: OpenSearch # 使用类型注释来显式提示类型
super().__init__(instance=instance)

name: str = "OpenSearch"
Expand Down Expand Up @@ -1093,3 +1121,65 @@ def get_connection(self, db_name=None):
if not self.conn:
raise Exception("OpenSearch 连接无法建立。")
return self.conn

def _security_role(self, sql, query_params: QueryParamsSearch):
    """Return the security roles as a ResultSet, one row per role.

    Args:
        sql: The original statement; stored on the ResultSet as ``full_sql``.
        query_params: Parsed query parameters. Not used here; accepted so the
            signature matches the other ``_security_*`` handlers.

    Returns:
        ResultSet whose first column is ``role_name`` followed by the role
        attributes. Left without rows/columns when the response is empty or
        is not a dict.

    Raises:
        Exception: if the request or the response formatting fails.
    """
    return self._security_entries(
        sql, "/_opendistro/_security/api/roles", "role_name"
    )


def _security_user(self, sql, query_params: QueryParamsSearch):
    """Return the security users as a ResultSet, one row per user.

    Args:
        sql: The original statement; stored on the ResultSet as ``full_sql``.
        query_params: Parsed query parameters. Not used here; accepted so the
            signature matches the other ``_security_*`` handlers.

    Returns:
        ResultSet whose first column is ``user_name`` followed by the user
        attributes. Left without rows/columns when the response is empty or
        is not a dict.

    Raises:
        Exception: if the request or the response formatting fails.
    """
    return self._security_entries(
        sql, "/_opendistro/_security/api/user", "user_name"
    )


def _security_entries(self, sql, url, name_column):
    """Fetch a ``{name: {attribute: value}}`` security listing and tabulate it.

    Shared implementation behind ``_security_role`` and ``_security_user``.

    Args:
        sql: The original statement; stored on the ResultSet as ``full_sql``.
        url: Security API path to GET.
        name_column: Header used for the first column (the entry name).

    Returns:
        ResultSet with ``column_list`` and ``rows`` filled in when the response
        is a non-empty dict; otherwise the freshly created ResultSet.

    Raises:
        Exception: wraps any error raised while requesting or formatting,
            chained to the original exception.
    """
    result_set = ResultSet(full_sql=sql)
    try:
        response = self.conn.transport.perform_request("GET", url, body={})
        if response and isinstance(response, dict):
            # Build the header from the union of all attribute names, in
            # first-seen order. Taking only the first entry's keys would drop
            # optional attributes that appear on later entries only.
            attribute_columns = list(
                dict.fromkeys(
                    key for info in response.values() for key in info.keys()
                )
            )
            formatted_rows = []
            for entry_name, info in response.items():
                row = [entry_name]
                for column in attribute_columns:
                    value = info.get(column, None)
                    # Nested lists/dicts are rendered as JSON text so that
                    # every cell is a scalar.
                    if isinstance(value, (list, dict)):
                        value = json.dumps(value)
                    row.append(value)
                formatted_rows.append(row)
            result_set.rows = formatted_rows
            result_set.column_list = [name_column] + attribute_columns
    except Exception as e:
        # Keep the original exception as __cause__ for easier debugging.
        raise Exception(f"执行查询时出错: {str(e)}") from e
    return result_set
30 changes: 25 additions & 5 deletions sql/engines/test_elasticsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def test_get_all_databases(self, mockElasticsearch):
expected_result = [
"other",
"system",
"system_internal",
"system_kibana",
"test",
]
self.assertEqual(result.rows, expected_result)
Expand All @@ -57,9 +55,31 @@ def test_get_all_tables(self, mockElasticsearch):
result = self.engine.get_all_tables(db_name="test")
self.assertEqual(result.rows, ["test__index1", "test__index2"])

# Test system_kibana
result = self.engine.get_all_tables(db_name="system_kibana")
self.assertEqual(result.rows, [".kibana_1"])
@patch("sql.engines.elasticsearch.Elasticsearch")
def test_get_all_tables_system(self, mockElasticsearch):
    """Listing tables for the special 'system' database.

    The mocked alias listing holds two dot-prefixed indices and one
    ``test__``-prefixed index; the expected table list consists solely of
    API paths and contains no dot-prefixed index name.
    """
    # Alias listing returned by the fake client: index name -> empty alias info.
    alias_listing = {
        index_name: {} for index_name in (".kibana_1", ".security", "test__index")
    }

    fake_client = Mock()
    fake_client.indices.get_alias.return_value = alias_listing
    mockElasticsearch.return_value = fake_client

    system_tables = self.engine.get_all_tables(db_name="system").rows

    # API endpoints exposed as table names of the 'system' database.
    self.assertEqual(
        system_tables,
        [
            "/_cat/indices/*",
            "/_cat/indices/test",
            "/_cat/nodes",
            "/_security/role",
            "/_security/user",
        ],
    )

@patch("sql.engines.elasticsearch.Elasticsearch")
def test_query(self, mockElasticsearch):
Expand Down
Loading

0 comments on commit b4c4827

Please sign in to comment.