diff --git a/sql/data_dictionary.py b/sql/data_dictionary.py index 2ee48b8d5e..4a702d8075 100644 --- a/sql/data_dictionary.py +++ b/sql/data_dictionary.py @@ -29,6 +29,7 @@ def table_list(request): instance_name=instance_name, db_type=db_type ) query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) data = query_engine.get_group_tables_by_db(db_name=db_name) res = {"status": 0, "data": data} except Instance.DoesNotExist: @@ -50,6 +51,7 @@ def table_info(request): db_name = request.GET.get("db_name", "") tb_name = request.GET.get("tb_name", "") db_type = request.GET.get("db_type", "") + if instance_name and db_name and tb_name: data = {} try: @@ -57,6 +59,8 @@ def table_info(request): instance_name=instance_name, db_type=db_type ) query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + tb_name = query_engine.escape_string(tb_name) data["meta_data"] = query_engine.get_table_meta_data( db_name=db_name, tb_name=tb_name ) @@ -91,8 +95,6 @@ def export(request): """导出数据字典""" instance_name = request.GET.get("instance_name", "") db_name = request.GET.get("db_name", "") - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") try: instance = user_instances( @@ -104,7 +106,7 @@ def export(request): # 普通用户仅可以获取指定数据库的字典信息 if db_name: - dbs = [db_name] + dbs = [query_engine.escape_string(db_name)] # 管理员可以导出整个实例的字典信息 elif request.user.is_superuser: dbs = query_engine.get_all_databases().rows diff --git a/sql/engines/__init__.py b/sql/engines/__init__.py index a101abed13..7adf3df930 100644 --- a/sql/engines/__init__.py +++ b/sql/engines/__init__.py @@ -86,6 +86,10 @@ def info(self): """返回引擎简介""" return "Base engine" + def escape_string(self, value: str) -> str: + """参数转义""" + return value + @property def auto_backup(self): """是否支持备份""" diff --git a/sql/engines/clickhouse.py b/sql/engines/clickhouse.py index 22216c4be5..a776d7ed07 100644 --- a/sql/engines/clickhouse.py +++ b/sql/engines/clickhouse.py @@ -1,5 +1,6 @@ # -*- coding: UTF-8 -*- from clickhouse_driver import connect +from clickhouse_driver.util.escape import escape_chars_map from sql.utils.sql_utils import get_syntax_type from .models import ResultSet, ReviewResult, ReviewSet from common.utils.timer import FuncTimer @@ -49,6 +50,10 @@ def name(self): def info(self): return "ClickHouse engine" + def escape_string(self, value: str) -> str: + """字符串参数转义""" + return "'%s'" % "".join(escape_chars_map.get(c, c) for c in value) + @property def auto_backup(self): """是否支持备份""" diff --git a/sql/engines/mysql.py b/sql/engines/mysql.py index ee47fd4240..637a604b71 100644 --- a/sql/engines/mysql.py +++ b/sql/engines/mysql.py @@ -98,6 +98,10 @@ def name(self): def info(self): return "MySQL engine" + def escape_string(self, value: str) -> str: + """字符串参数转义""" + return MySQLdb.escape_string(value).decode("utf-8") + @property def auto_backup(self): """是否支持备份""" @@ -167,7 +171,7 @@ def get_all_tables(self, db_name, **kwargs): def get_group_tables_by_db(self, db_name): # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") + db_name = self.escape_string(db_name) data = {} sql = f"""SELECT TABLE_NAME, TABLE_COMMENT @@ -186,8 +190,8 @@ def get_group_tables_by_db(self, db_name): def get_table_meta_data(self, db_name, tb_name, **kwargs): """数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}""" # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") - tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") + db_name = self.escape_string(db_name) + tb_name = self.escape_string(tb_name) sql = f"""SELECT TABLE_NAME as table_name, ENGINE as engine, diff --git a/sql/instance.py b/sql/instance.py index c041f7035f..34accdc3c3 100644 --- a/sql/instance.py +++ b/sql/instance.py @@ -163,6 +163,9 @@ def param_edit(request): instance_id = request.POST.get("instance_id") variable_name = request.POST.get("variable_name") variable_value = request.POST.get("runtime_value") + # escape + variable_name = MySQLdb.escape_string(variable_name).decode("utf-8") + variable_value = MySQLdb.escape_string(variable_value).decode("utf-8") try: ins = Instance.objects.get(id=instance_id) @@ -320,12 +323,10 @@ def instance_resource(request): result = {"status": 0, "msg": "ok", "data": []} try: - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") - schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") - tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") - query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + schema_name = query_engine.escape_string(schema_name) + tb_name = query_engine.escape_string(tb_name) if resource_type == "database": resource = query_engine.get_all_databases() elif resource_type == "schema" and db_name: @@ -363,10 +364,14 @@ def describe(request): db_name = request.POST.get("db_name") schema_name = request.POST.get("schema_name") tb_name = request.POST.get("tb_name") + result = {"status": 0, "msg": "ok", "data": []} try: query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + schema_name = query_engine.escape_string(schema_name) + tb_name = query_engine.escape_string(tb_name) query_result = query_engine.describe_table( db_name, tb_name, schema_name=schema_name ) diff --git a/sql/instance_database.py b/sql/instance_database.py index 15d1572c35..87b51fcb4a 100644 --- a/sql/instance_database.py +++ b/sql/instance_database.py @@ -111,10 +111,8 @@ def create(request): except Users.DoesNotExist: return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []}) - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") - engine = get_engine(instance=instance) + db_name = engine.escape_string(db_name) exec_result = engine.execute( db_name="information_schema", sql=f"create database {db_name};" ) diff --git a/sql/sql_optimize.py b/sql/sql_optimize.py index b147fa9c98..62a9f8a80b 100644 --- a/sql/sql_optimize.py +++ b/sql/sql_optimize.py @@ -163,8 +163,6 @@ def optimize_sqltuning(request): except Instance.DoesNotExist: result = {"status": 1, "msg": "你所在组未关联该实例!", "data": []} return HttpResponse(json.dumps(result), content_type="application/json") - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") sql_tunning = SqlTuning( instance_name=instance_name, db_name=db_name, sqltext=sqltext @@ -235,6 +233,7 @@ def explain(request): # 执行获取执行计划语句 query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) sql_result = query_engine.query(str(db_name), sql_content).to_sep_dict() result["data"] = sql_result @@ -287,6 +286,7 @@ def optimize_sqltuningadvisor(request): # 执行获取优化报告 query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) sql_result = query_engine.sqltuningadvisor(str(db_name), sql_content).to_sep_dict() result["data"] = sql_result diff --git a/sql/sql_tuning.py b/sql/sql_tuning.py index 973406cba9..4dac2a46cf 100644 --- a/sql/sql_tuning.py +++ b/sql/sql_tuning.py @@ -13,7 +13,7 @@ def __init__(self, instance_name, db_name, sqltext): instance = Instance.objects.get(instance_name=instance_name) query_engine = get_engine(instance=instance) self.engine = query_engine - self.db_name = db_name + self.db_name = self.engine.escape_string(db_name) self.sqltext = sqltext self.sql_variable = """ select diff --git a/sql/tests.py b/sql/tests.py index 5a3e8202c9..357a1ccea0 100644 --- a/sql/tests.py +++ b/sql/tests.py @@ -2539,7 +2539,7 @@ def test_param_edit_variable_not_config( data = { "instance_id": self.master.id, "variable_name": "1", - "variable_value": "false", + "runtime_value": "false", } r = self.client.post(path="/param/edit/", data=data) self.assertEqual( diff --git a/sql_api/api_instance.py b/sql_api/api_instance.py index 4cb50b51ba..6787ca4149 100644 --- a/sql_api/api_instance.py +++ b/sql_api/api_instance.py @@ -187,12 +187,10 @@ def post(self, request): instance = Instance.objects.get(pk=instance_id) try: - # escape - db_name = MySQLdb.escape_string(db_name).decode("utf-8") - schema_name = MySQLdb.escape_string(schema_name).decode("utf-8") - tb_name = MySQLdb.escape_string(tb_name).decode("utf-8") - query_engine = get_engine(instance=instance) + db_name = query_engine.escape_string(db_name) + schema_name = query_engine.escape_string(schema_name) + tb_name = query_engine.escape_string(tb_name) if resource_type == "database": resource = query_engine.get_all_databases() elif resource_type == "schema" and db_name: diff --git a/sql_api/api_workflow.py b/sql_api/api_workflow.py index fa702cefb5..480a704e1a 100644 --- a/sql_api/api_workflow.py +++ b/sql_api/api_workflow.py @@ -1,3 +1,4 @@ +import MySQLdb from django.contrib.auth.decorators import permission_required from django.utils.decorators import method_decorator from rest_framework import views, generics, status, serializers, permissions @@ -60,9 +61,11 @@ def post(self, request): instance = serializer.get_instance() # 交给engine进行检测 try: + db_name = request.data["db_name"] check_engine = get_engine(instance=instance) + db_name = check_engine.escape_string(db_name) check_result = check_engine.execute_check( - db_name=request.data["db_name"], sql=request.data["full_sql"].strip() + db_name=db_name, sql=request.data["full_sql"].strip() ) except Exception as e: raise serializers.ValidationError({"errors": f"{e}"})