Skip to content

Commit

Permalink
engine增加escape_string用于处理字符串参数转义
Browse files Browse the repository at this point in the history
  • Loading branch information
hhyo committed Mar 31, 2023
1 parent f275b56 commit c5866d4
Show file tree
Hide file tree
Showing 13 changed files with 62 additions and 24 deletions.
8 changes: 5 additions & 3 deletions sql/data_dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def table_list(request):
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
data = query_engine.get_group_tables_by_db(db_name=db_name)
res = {"status": 0, "data": data}
except Instance.DoesNotExist:
Expand All @@ -50,13 +51,16 @@ def table_info(request):
db_name = request.GET.get("db_name", "")
tb_name = request.GET.get("tb_name", "")
db_type = request.GET.get("db_type", "")

if instance_name and db_name and tb_name:
data = {}
try:
instance = Instance.objects.get(
instance_name=instance_name, db_type=db_type
)
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
tb_name = query_engine.escape_string(tb_name)
data["meta_data"] = query_engine.get_table_meta_data(
db_name=db_name, tb_name=tb_name
)
Expand Down Expand Up @@ -91,8 +95,6 @@ def export(request):
"""导出数据字典"""
instance_name = request.GET.get("instance_name", "")
db_name = request.GET.get("db_name", "")
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

try:
instance = user_instances(
Expand All @@ -104,7 +106,7 @@ def export(request):

# 普通用户仅可以获取指定数据库的字典信息
if db_name:
dbs = [db_name]
dbs = [query_engine.escape_string(db_name)]
# 管理员可以导出整个实例的字典信息
elif request.user.is_superuser:
dbs = query_engine.get_all_databases().rows
Expand Down
4 changes: 4 additions & 0 deletions sql/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ def info(self):
"""返回引擎简介"""
return "Base engine"

def escape_string(self, value: str) -> str:
"""参数转义"""
return value

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down
5 changes: 5 additions & 0 deletions sql/engines/clickhouse.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: UTF-8 -*-
from clickhouse_driver import connect
from clickhouse_driver.util.escape import escape_chars_map
from sql.utils.sql_utils import get_syntax_type
from .models import ResultSet, ReviewResult, ReviewSet
from common.utils.timer import FuncTimer
Expand Down Expand Up @@ -49,6 +50,10 @@ def name(self):
def info(self):
return "ClickHouse engine"

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return "'%s'" % "".join(escape_chars_map.get(c, c) for c in value)

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down
13 changes: 13 additions & 0 deletions sql/engines/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ def get_connection(self, db_name=None):
self.conn = pyodbc.connect(connstr)
return self.conn

@property
def name(self):
return "MSSQL"

@property
def info(self):
return "MSSQL engine"

@property
def auto_backup(self):
"""是否支持备份"""
return True

def get_all_databases(self):
"""获取数据库列表, 返回一个ResultSet"""
sql = "SELECT name FROM master.sys.databases order by name"
Expand Down
10 changes: 7 additions & 3 deletions sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def name(self):
def info(self):
return "MySQL engine"

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
return MySQLdb.escape_string(value).decode("utf-8")

@property
def auto_backup(self):
"""是否支持备份"""
Expand Down Expand Up @@ -167,7 +171,7 @@ def get_all_tables(self, db_name, **kwargs):

def get_group_tables_by_db(self, db_name):
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
db_name = self.escape_string(db_name)
data = {}
sql = f"""SELECT TABLE_NAME,
TABLE_COMMENT
Expand All @@ -186,8 +190,8 @@ def get_group_tables_by_db(self, db_name):
def get_table_meta_data(self, db_name, tb_name, **kwargs):
"""数据字典页面使用:获取表格的元信息,返回一个dict{column_list: [], rows: []}"""
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")
db_name = self.escape_string(db_name)
tb_name = self.escape_string(tb_name)
sql = f"""SELECT
TABLE_NAME as table_name,
ENGINE as engine,
Expand Down
6 changes: 6 additions & 0 deletions sql/engines/pgsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import traceback
import sqlparse

from psycopg2.extensions import quote_ident
from common.config import SysConfig
from common.utils.timer import FuncTimer
from sql.utils.sql_utils import get_syntax_type
Expand Down Expand Up @@ -48,6 +49,11 @@ def name(self):
def info(self):
return "PgSQL engine"

def escape_string(self, value: str) -> str:
"""字符串参数转义"""
self.conn = self.get_connection()
return quote_ident(value, self.conn)

def get_all_databases(self):
"""
获取数据库列表
Expand Down
15 changes: 10 additions & 5 deletions sql/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def param_edit(request):
instance_id = request.POST.get("instance_id")
variable_name = request.POST.get("variable_name")
variable_value = request.POST.get("runtime_value")
# escape
variable_name = MySQLdb.escape_string(variable_name).decode("utf-8")
variable_value = MySQLdb.escape_string(variable_value).decode("utf-8")

try:
ins = Instance.objects.get(id=instance_id)
Expand Down Expand Up @@ -320,12 +323,10 @@ def instance_resource(request):
result = {"status": 0, "msg": "ok", "data": []}

try:
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
schema_name = MySQLdb.escape_string(schema_name).decode("utf-8")
tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")

query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
schema_name = query_engine.escape_string(schema_name)
tb_name = query_engine.escape_string(tb_name)
if resource_type == "database":
resource = query_engine.get_all_databases()
elif resource_type == "schema" and db_name:
Expand Down Expand Up @@ -363,10 +364,14 @@ def describe(request):
db_name = request.POST.get("db_name")
schema_name = request.POST.get("schema_name")
tb_name = request.POST.get("tb_name")

result = {"status": 0, "msg": "ok", "data": []}

try:
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
schema_name = query_engine.escape_string(schema_name)
tb_name = query_engine.escape_string(tb_name)
query_result = query_engine.describe_table(
db_name, tb_name, schema_name=schema_name
)
Expand Down
4 changes: 1 addition & 3 deletions sql/instance_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,8 @@ def create(request):
except Users.DoesNotExist:
return JsonResponse({"status": 1, "msg": "负责人不存在", "data": []})

# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

engine = get_engine(instance=instance)
db_name = engine.escape_string(db_name)
exec_result = engine.execute(
db_name="information_schema", sql=f"create database {db_name};"
)
Expand Down
4 changes: 2 additions & 2 deletions sql/sql_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,6 @@ def optimize_sqltuning(request):
except Instance.DoesNotExist:
result = {"status": 1, "msg": "你所在组未关联该实例!", "data": []}
return HttpResponse(json.dumps(result), content_type="application/json")
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")

sql_tunning = SqlTuning(
instance_name=instance_name, db_name=db_name, sqltext=sqltext
Expand Down Expand Up @@ -235,6 +233,7 @@ def explain(request):

# 执行获取执行计划语句
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
sql_result = query_engine.query(str(db_name), sql_content).to_sep_dict()
result["data"] = sql_result

Expand Down Expand Up @@ -287,6 +286,7 @@ def optimize_sqltuningadvisor(request):

# 执行获取优化报告
query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
sql_result = query_engine.sqltuningadvisor(str(db_name), sql_content).to_sep_dict()
result["data"] = sql_result

Expand Down
2 changes: 1 addition & 1 deletion sql/sql_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, instance_name, db_name, sqltext):
instance = Instance.objects.get(instance_name=instance_name)
query_engine = get_engine(instance=instance)
self.engine = query_engine
self.db_name = db_name
self.db_name = self.engine.escape_string(db_name)
self.sqltext = sqltext
self.sql_variable = """
select
Expand Down
2 changes: 1 addition & 1 deletion sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2539,7 +2539,7 @@ def test_param_edit_variable_not_config(
data = {
"instance_id": self.master.id,
"variable_name": "1",
"variable_value": "false",
"runtime_value": "false",
}
r = self.client.post(path="/param/edit/", data=data)
self.assertEqual(
Expand Down
8 changes: 3 additions & 5 deletions sql_api/api_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,10 @@ def post(self, request):
instance = Instance.objects.get(pk=instance_id)

try:
# escape
db_name = MySQLdb.escape_string(db_name).decode("utf-8")
schema_name = MySQLdb.escape_string(schema_name).decode("utf-8")
tb_name = MySQLdb.escape_string(tb_name).decode("utf-8")

query_engine = get_engine(instance=instance)
db_name = query_engine.escape_string(db_name)
schema_name = query_engine.escape_string(schema_name)
tb_name = query_engine.escape_string(tb_name)
if resource_type == "database":
resource = query_engine.get_all_databases()
elif resource_type == "schema" and db_name:
Expand Down
5 changes: 4 additions & 1 deletion sql_api/api_workflow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import MySQLdb
from django.contrib.auth.decorators import permission_required
from django.utils.decorators import method_decorator
from rest_framework import views, generics, status, serializers, permissions
Expand Down Expand Up @@ -60,9 +61,11 @@ def post(self, request):
instance = serializer.get_instance()
# 交给engine进行检测
try:
db_name = request.data["db_name"]
check_engine = get_engine(instance=instance)
db_name = check_engine.escape_string(db_name)
check_result = check_engine.execute_check(
db_name=request.data["db_name"], sql=request.data["full_sql"].strip()
db_name=db_name, sql=request.data["full_sql"].strip()
)
except Exception as e:
raise serializers.ValidationError({"errors": f"{e}"})
Expand Down

0 comments on commit c5866d4

Please sign in to comment.