diff --git a/archery/settings.py b/archery/settings.py index de45f1dd41..949b0640fe 100644 --- a/archery/settings.py +++ b/archery/settings.py @@ -53,6 +53,7 @@ "odps", "cassandra", "doris", + "elasticsearch", ], ), ENABLED_NOTIFIERS=( @@ -101,6 +102,7 @@ "phoenix": {"path": "sql.engines.phoenix:PhoenixEngine"}, "odps": {"path": "sql.engines.odps:ODPSEngine"}, "doris": {"path": "sql.engines.doris:DorisEngine"}, + "elasticsearch": {"path": "sql.engines.elasticsearch:ElasticsearchEngine"}, } ENABLED_NOTIFIERS = env("ENABLED_NOTIFIERS") diff --git a/requirements.txt b/requirements.txt index 70a447bfab..9555e008b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,3 +44,5 @@ django-cas-ng==4.3.0 cassandra-driver httpx OpenAI +elasticsearch==8.14.0 + diff --git a/sql/engines/elasticsearch.py b/sql/engines/elasticsearch.py new file mode 100644 index 0000000000..7a08cf40f5 --- /dev/null +++ b/sql/engines/elasticsearch.py @@ -0,0 +1,409 @@ +# -*- coding: UTF-8 -*- +import logging +import re +import traceback +import simplejson as json +from . import EngineBase +from .models import ResultSet, ReviewSet, ReviewResult +from common.config import SysConfig +import logging + +from elasticsearch import Elasticsearch +from elasticsearch.exceptions import TransportError + + +logger = logging.getLogger("default") + + +class QueryParamsSearch: + def __init__( + self, + index: str, + path: str, + params: str, + method: str, + size: int, + query_body: dict = None, + ): + self.index = index + self.path = path + self.params = params + self.method = method + self.size = size + # 确保 query_body 不为 None + self.query_body = query_body if query_body is not None else {} + + +class ElasticsearchEngineBase(EngineBase): + """Elasticsearch、OpenSearch等Search父类实现""" + + def __init__(self, instance=None): + self.db_separator = "__" # 设置分隔符 + # 限制只能2种支持的子类 + self.search_name = ["Elasticsearch", "OpenSearch"] + if self.name not in self.search_name: + raise ValueError( + f"Invalid name: {self.name}. Must be one of {self.search_name}." + ) + super().__init__(instance=instance) + + def get_connection(self, db_name=None): + """返回一个conn实例""" + + def test_connection(self): + """测试实例链接是否正常""" + return self.get_all_databases() + + name: str = "SearchBase" + info: str = "SearchBase 引擎" + + def get_all_databases(self): + """获取所有“数据库”名(从索引名提取),默认提取 __ 前的部分作为数据库名""" + try: + self.get_connection() + # 获取所有的别名,没有别名就是本身。 + indices = self.conn.indices.get_alias(index=self.db_name) + database_names = set() + if self.db_name == "*": + database_names.add("system") # 系统表名使用的库名 + for index_name in indices.keys(): + if self.db_separator in index_name: + db_name = index_name.split(self.db_separator)[0] + database_names.add(db_name) + elif index_name.startswith(".kibana_"): + database_names.add("system_kibana") + elif index_name.startswith(".internal."): + database_names.add("system_internal") + database_names.add("other") # 表名没有__时,使用的库名 + database_names_sorted = sorted(database_names) + return ResultSet(rows=database_names_sorted) + except Exception as e: + logger.error(f"获取数据库时出错:{e}{traceback.format_exc()}") + raise Exception(f"获取数据库时出错: {str(e)}") + + def get_all_tables(self, db_name, **kwargs): + """根据给定的数据库名获取所有相关的表名""" + try: + self.get_connection() + indices = self.conn.indices.get_alias(index=self.db_name) + tables = set() + + db_mapping = { + "system_kibana": ".kibana_", + "system_internal": ".internal.", + "system": ".", + "other": "other", + } + # 根据分隔符分隔的库名 + if db_name not in db_mapping: + index_prefix = db_name.rstrip(self.db_separator) + self.db_separator + tables = [ + index for index in indices.keys() if index.startswith(index_prefix) + ] + else: + # 处理系统表,和other,循环db_mapping.items() 很难实现。 + for index_name in indices.keys(): + if index_name.startswith(".kibana_") | index_name.startswith( + ".kibana-" + ): + if db_name == "system_kibana": + tables.add(index_name) + continue + elif index_name.startswith(".internal."): + if db_name == "system_internal": + tables.add(index_name) + continue + elif index_name.startswith("."): + if db_name == "system": + tables.add(index_name) + continue + elif index_name.startswith(db_name): + tables.add(index_name) + continue + elif self.db_separator in index_name: + continue + else: + if db_name == "other": + tables.add(index_name) + tables_sorted = sorted(tables) + return ResultSet(rows=tables_sorted) + except Exception as e: + raise Exception(f"获取表列表时出错: {str(e)}") + + def get_all_columns_by_tb(self, db_name, tb_name, **kwargs): + """获取所有字段""" + result_set = ResultSet(full_sql=f"{tb_name}/_mapping") + try: + self.get_connection() + mapping = self.conn.indices.get_mapping(index=tb_name) + properties = ( + mapping.get(tb_name, {}).get("mappings", {}).get("properties", None) + ) + # 返回字段名 + result_set.column_list = ["column_name"] + if properties is None: + result_set.rows = [("无")] + else: + result_set.rows = list(properties.keys()) + return result_set + except Exception as e: + raise Exception(f"获取字段时出错: {str(e)}") + + def describe_table(self, db_name, tb_name, **kwargs): + """表结构""" + result_set = ResultSet(full_sql=f"{tb_name}/_mapping") + try: + self.get_connection() + mapping = self.conn.indices.get_mapping(index=tb_name) + properties = ( + mapping.get(tb_name, {}).get("mappings", {}).get("properties", None) + ) + # 创建包含字段名、类型和其他信息的列表结构 + result_set.column_list = ["column_name", "type", "fields"] + if properties is None: + result_set.rows = [("无", "无", "无")] + else: + result_set.rows = [ + ( + column, + details.get("type"), + json.dumps(details.get("fields", {})), + ) + for column, details in properties.items() + ] + return result_set + except Exception as e: + raise Exception(f"获取字段时出错: {str(e)}") + + def query_check(self, db_name=None, sql=""): + """语句检查""" + result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False} + # 使用正则表达式去除开头的空白字符和换行符 + tripped_sql = re.sub(r"^\s+", "", sql) + result["filtered_sql"] = tripped_sql + lower_sql = tripped_sql.lower() + # 检查是否以 'get' 或 'select' 开头 + if lower_sql.startswith("get ") or lower_sql.startswith("select "): + result["msg"] = "语句检查通过。" + result["bad_query"] = False + else: + result["msg"] = ( + "语句检查失败:语句必须以 'get' 或 'select' 开头。示例查询:GET /dmp__iv/_search、select * from dmp__iv limit 10;" + ) + result["bad_query"] = True + return result + + def filter_sql(self, sql="", limit_num=0): + """过滤 SQL 语句""" + return sql.strip() + + def query( + self, + db_name=None, + sql="", + limit_num=0, + close_conn=True, + parameters=None, + **kwargs, + ): + """执行查询""" + try: + result_set = ResultSet(full_sql=sql) + + # 解析查询字符串 + query_params = self.parse_es_select_query_to_query_params(sql, limit_num) + self.get_connection() + # 管理查询处理 + if query_params.path.startswith("/_cat/indices"): + # v这个参数用显示标题,需要加上。 + if "v" not in query_params.params: + query_params.params["v"] = "true" + response = self.conn.cat.indices( + index=query_params.index, params=query_params.params + ) + response_body = "" + if isinstance(response, str): + response_body = response + else: + response_body = response.body + response_data = self.parse_cat_indices_response(response_body) + # 如果有数据,设置列名 + if response_data: + result_set.column_list = list(response_data[0].keys()) + result_set.rows = [tuple(row.values()) for row in response_data] + else: + result_set.column_list = [] + result_set.rows = [] + result_set.affected_rows = 0 + else: + # 执行搜索查询 + response = self.conn.search( + index=query_params.index, + body=query_params.query_body, + params=query_params.params, + ) + + # 提取查询结果 + hits = response.get("hits", {}).get("hits", []) + # 处理查询结果,将列表和字典转换为 JSON 字符串 + rows = [] + for hit in hits: + # 获取文档 ID 和 _source 数据 + doc_id = hit.get("_id") + source_data = hit.get("_source", {}) + + # 转换需要转换为 JSON 字符串的字段 + for key, value in source_data.items(): + if isinstance(value, (list, dict)): # 如果字段是列表或字典 + source_data[key] = json.dumps(value) # 转换为 JSON 字符串 + + # 构建结果行 + row = {"_id": doc_id, **source_data} + rows.append(row) + + # 如果有结果,获取字段名作为列名 + if rows: + first_row = rows[0] + column_list = list(first_row.keys()) + else: + column_list = [] + + # 构建结果集 + result_set.rows = [tuple(row.values()) for row in rows] # 只获取值 + result_set.column_list = column_list + result_set.affected_rows = len(result_set.rows) + return result_set + except Exception as e: + raise Exception(f"执行查询时出错: {str(e)}") + + def parse_cat_indices_response(self, response_text): + """解析cat indices结果""" + # 将响应文本按行分割 + lines = response_text.strip().splitlines() + # 获取列标题 + headers = lines[0].strip().split() + # 解析每一行数据 + indices_info = [] + for line in lines[1:]: + # 按空格分割,并与标题进行配对 + values = line.strip().split(maxsplit=len(headers) - 1) + index_info = dict(zip(headers, values)) + indices_info.append(index_info) + return indices_info + + def parse_es_select_query_to_query_params( + self, search_query_str: str, limit_num: int + ) -> QueryParamsSearch: + """解析 search query 字符串为 QueryParamsSearch 对象""" + + # 解析查询字符串 + lines = search_query_str.splitlines() + method_line = lines[0].strip() + + query_body = "\n".join(lines[1:]).strip() + # 如果 query_body 为空,使用默认查询体 + if not query_body: + query_body = json.dumps({"query": {"match_all": {}}}) + + # 确保 query_body 是有效的 JSON + try: + json_body = json.loads(query_body) + except json.JSONDecodeError as json_err: + raise ValueError(f"query_body:{query_body} 无法转为Json格式。{json_err},") + + # 提取方法和路径 + method, path_with_params = method_line.split(maxsplit=1) + # 确保路径以 '/' 开头 + if not path_with_params.startswith("/"): + path_with_params = "/" + path_with_params + + # 分离路径和查询参数 + path, params_str = ( + path_with_params.split("?", 1) + if "?" in path_with_params + else (path_with_params, "") + ) + params = {} + if params_str: + for pair in params_str.split("&"): + if "=" in pair: + key, value = pair.split("=", 1) + else: + key = pair + value = "" + params[key] = value + index_pattern = "" + # 判断路径类型并提取索引模式 + if path.startswith("/_cat/indices"): + # _cat API 路径 + path_parts = path.split("/") + if len(path_parts) > 3: + index_pattern = path_parts[3] + if not index_pattern: + index_pattern = "*" + elif "/_search" in path: + # 默认情况,处理常规索引路径 + # 提取索引名称 + path_parts = path.split("/") + if len(path_parts) > 1: + index_pattern = path_parts[1] + + if not index_pattern: + raise Exception("未找到索引名称。") + + size = limit_num if limit_num > 0 else 100 + # 检查 JSON 中是否已经有 size,如果没有就设置 + if "size" not in json_body: + json_body["size"] = size + + # 构建 QueryParams 对象 + query_params = QueryParamsSearch( + index=index_pattern, + path=path_with_params, + params=params, + method=method, + size=size, + query_body=json_body, + ) + + return query_params + + +class ElasticsearchEngine(ElasticsearchEngineBase): + """Elasticsearch 引擎实现""" + + def __init__(self, instance=None): + super().__init__(instance=instance) + + name: str = "Elasticsearch" + info: str = "Elasticsearch 引擎" + + def get_connection(self, db_name=None): + if self.conn: + return self.conn + if self.instance: + scheme = "https" if self.is_ssl else "http" + hosts = [ + { + "host": self.host, + "port": self.port, + "scheme": scheme, + "use_ssl": self.is_ssl, + } + ] + http_auth = ( + (self.user, self.password) if self.user and self.password else None + ) + self.db_name = (self.db_name or "") + "*" + try: + # 创建 Elasticsearch 连接,高版本有basic_auth + self.conn = Elasticsearch( + hosts=hosts, + http_auth=http_auth, + verify_certs=True, # 需要证书验证 + ) + except Exception as e: + raise Exception(f"Elasticsearch 连接建立失败: {str(e)}") + if not self.conn: + raise Exception("Elasticsearch 连接无法建立。") + return self.conn diff --git a/sql/engines/test_elasticsearch.py b/sql/engines/test_elasticsearch.py new file mode 100644 index 0000000000..025bc6277a --- /dev/null +++ b/sql/engines/test_elasticsearch.py @@ -0,0 +1,224 @@ +import json +import unittest +from unittest.mock import patch, Mock +from elasticsearch import Elasticsearch +from elasticsearch.exceptions import TransportError +from sql.engines import ResultSet, ReviewSet +from sql.engines.elasticsearch import ElasticsearchEngine +from sql.models import Instance + + +class TestElasticsearchEngine(unittest.TestCase): + def setUp(self): + # 创建一个模拟的 instance 对象,包含必要的属性 + self.mock_instance = Instance() + self.mock_instance.host = "localhost" + self.mock_instance.port = 9200 + self.mock_instance.user = "user" + self.mock_instance.password = "pass" + self.mock_instance.is_ssl = True + + # 初始化 ElasticsearchEngine,传入模拟的 instance + self.engine = ElasticsearchEngine(instance=self.mock_instance) + + @patch("sql.engines.elasticsearch.Elasticsearch") + def test_get_all_databases(self, mockElasticsearch): + mock_conn = Mock() + mock_conn.indices.get_alias.return_value = { + "test__index1": {}, + "test__index2": {}, + ".kibana_1": {}, + ".internal.index": {}, + } + mockElasticsearch.return_value = mock_conn + + result = self.engine.get_all_databases() + expected_result = [ + "other", + "system", + "system_internal", + "system_kibana", + "test", + ] + self.assertEqual(result.rows, expected_result) + + @patch("sql.engines.elasticsearch.Elasticsearch") + def test_get_all_tables(self, mockElasticsearch): + mock_conn = Mock() + mock_conn.indices.get_alias.return_value = { + "test__index1": {}, + "test__index2": {}, + "other_index": {}, + ".kibana_1": {}, + } + mockElasticsearch.return_value = mock_conn + + # Test specific database + result = self.engine.get_all_tables(db_name="test") + self.assertEqual(result.rows, ["test__index1", "test__index2"]) + + # Test system_kibana + result = self.engine.get_all_tables(db_name="system_kibana") + self.assertEqual(result.rows, [".kibana_1"]) + + @patch("sql.engines.elasticsearch.Elasticsearch") + def test_query(self, mockElasticsearch): + mock_conn = Mock() + mock_conn.search.return_value = { + "hits": { + "hits": [ + { + "_id": "1", + "_source": {"field1": "value1", "field2": ["val1", "val2"]}, + }, + { + "_id": "2", + "_source": { + "field1": {"subfield": "value3"}, + "field2": "value4", + }, + }, + ] + } + } + mockElasticsearch.return_value = mock_conn + + sql = "GET /test_index/_search" + result = self.engine.query(sql=sql) + expected_rows = [ + ("1", "value1", json.dumps(["val1", "val2"])), + ("2", json.dumps({"subfield": "value3"}), "value4"), + ] + self.assertEqual(result.rows, expected_rows) + self.assertEqual(result.column_list, ["_id", "field1", "field2"]) + + @patch("sql.engines.elasticsearch.Elasticsearch") + def test_query_cat_indices(self, mock_elasticsearch): + """test_query_cat_indices""" + mock_conn = Mock() + mock_elasticsearch.return_value = mock_conn + mock_response = Mock() + mock_response.body = "health status index uuid pri rep docs.count docs.deleted store.size pri.store.size dataset.size\nyellow open test__index 3yyJqzgHTJqRkKwhT5Fy7w 3 1 34256 0 4.4mb 4.4mb 4.4mb\nyellow open dmp__iv fzK3nKcpRNunVr5N6gOSsw 3 1 903 0 527.1kb 527.1kb 527.1kb\n" + mock_conn.cat.indices.return_value = mock_response + + sql = "GET /_cat/indices/*?v&s=docs.count:desc" + + # 执行测试的方法 + result = self.engine.query(sql=sql) + + # 验证结果 + expected_columns = [ + "health", + "status", + "index", + "uuid", + "pri", + "rep", + "docs.count", + "docs.deleted", + "store.size", + "pri.store.size", + "dataset.size", + ] + expected_rows = [ + ( + "yellow", + "open", + "test__index", + "3yyJqzgHTJqRkKwhT5Fy7w", + "3", + "1", + "34256", + "0", + "4.4mb", + "4.4mb", + "4.4mb", + ), + ( + "yellow", + "open", + "dmp__iv", + "fzK3nKcpRNunVr5N6gOSsw", + "3", + "1", + "903", + "0", + "527.1kb", + "527.1kb", + "527.1kb", + ), + ] + self.assertEqual(result.column_list, expected_columns) + self.assertEqual(result.rows, expected_rows) + + @patch("sql.engines.elasticsearch.Elasticsearch") + def test_get_all_columns_by_tb(self, mock_elasticsearch): + """测试获取表字段""" + + mock_conn = Mock() + mock_elasticsearch.return_value = mock_conn + + mock_mapping = { + "mappings": { + "properties": { + "field1": {"type": "text"}, + "field2": {"type": "keyword"}, + "field3": {"type": "integer"}, + } + } + } + + mock_conn.indices.get_mapping.return_value = {"test_table": mock_mapping} + + result = self.engine.get_all_columns_by_tb( + db_name="test_db", tb_name="test_table" + ) + + expected_columns = ["column_name"] + expected_rows = ["field1", "field2", "field3"] + + self.assertEqual(result.column_list, expected_columns) + self.assertEqual(result.rows, expected_rows) + + @patch("sql.engines.elasticsearch.Elasticsearch") + def test_describe_table(self, mock_elasticsearch): + """测试表结构""" + + mock_conn = Mock() + mock_elasticsearch.return_value = mock_conn + + mock_mapping = { + "mappings": { + "properties": { + "field1": { + "type": "text", + "fields": {"keyword": {"type": "keyword"}}, + }, + "field2": {"type": "integer"}, + "field3": {"type": "date"}, + } + } + } + mock_conn.indices.get_mapping.return_value = {"test_table": mock_mapping} + + result = self.engine.describe_table(db_name="test_db", tb_name="test_table") + + expected_columns = ["column_name", "type", "fields"] + expected_rows = [ + ("field1", "text", json.dumps({"keyword": {"type": "keyword"}})), + ("field2", "integer", "{}"), + ("field3", "date", "{}"), + ] + + # Assertions + self.assertEqual(result.column_list, expected_columns) + self.assertEqual(result.rows, expected_rows) + + def test_query_check(self): + valid_sql = "GET /test_index/_search" + result = self.engine.query_check(sql=valid_sql) + self.assertFalse(result["bad_query"]) + + invalid_sql = "PUT /test_index/_doc/1" + result = self.engine.query_check(sql=invalid_sql) + self.assertTrue(result["bad_query"]) diff --git a/sql/models.py b/sql/models.py index 4a6caf5063..f72cbdc845 100755 --- a/sql/models.py +++ b/sql/models.py @@ -134,6 +134,7 @@ class Meta: ("goinception", "goInception"), ("cassandra", "Cassandra"), ("doris", "Doris"), + ("elasticsearch", "Elasticsearch"), )