Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add parameters in clickhouse and mysql engine to avoid sql injection #2062

Merged
merged 2 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sql/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,15 @@ def filter_sql(self, sql="", limit_num=0):
"""给查询语句增加结果级限制或者改写语句, 返回修改后的语句"""
return sql.strip()

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
parameters=None,
limit_num=0,
close_conn=True,
**kwargs
):
"""实际查询 返回一个ResultSet"""
return ResultSet()

Expand Down
29 changes: 19 additions & 10 deletions sql/engines/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ def server_version(self):

def get_table_engine(self, tb_name):
"""获取某个table的engine type"""
[database, name] = tb_name.split(".")
sql = f"""select engine
from system.tables
where database='{tb_name.split('.')[0]}'
and name='{tb_name.split('.')[1]}'"""
query_result = self.query(sql=sql)
where database=%s
and name=%s"""
query_result = self.query(sql=sql, parameters=(database, name))
if query_result.rows:
result = {"status": 1, "engine": query_result.rows[0][0]}
else:
Expand Down Expand Up @@ -104,30 +105,38 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
from
system.columns
where
database = '{db_name}'
and table = '{tb_name}';"""
result = self.query(db_name=db_name, sql=sql)
database = %s
and table = %s;"""
result = self.query(db_name=db_name, sql=sql, parameters=(db_name, tb_name))
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result

def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)
sql = f"show create table %s;"
result = self.query(db_name=db_name, sql=sql, parameters=(tb_name,))

result.rows[0] = (tb_name,) + (
result.rows[0][0].replace("(", "(\n ").replace(",", ",\n "),
)
return result

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
parameters=None,
limit_num=0,
close_conn=True,
**kwargs,
):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
try:
conn = self.get_connection(db_name=db_name)
cursor = conn.cursor()
cursor.execute(sql)
cursor.execute(sql, parameters=parameters)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
Expand Down
64 changes: 38 additions & 26 deletions sql/engines/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def get_group_tables_by_db(self, db_name):
FROM
information_schema.TABLES
WHERE
TABLE_SCHEMA='{db_name}';"""
result = self.query(db_name=db_name, sql=sql)
TABLE_SCHEMA=%s;"""
result = self.query(db_name=db_name, sql=sql, parameters=(db_name,))
for row in result.rows:
table_name, table_cmt = row[0], row[1]
if table_name[0] not in data:
Expand Down Expand Up @@ -208,9 +208,9 @@ def get_table_meta_data(self, db_name, tb_name, **kwargs):
FROM
information_schema.TABLES
WHERE
TABLE_SCHEMA='{db_name}'
AND TABLE_NAME='{tb_name}'"""
_meta_data = self.query(db_name, sql)
TABLE_SCHEMA=%s
AND TABLE_NAME=%s"""
_meta_data = self.query(db_name, sql, parameters=(db_name, tb_name))
return {"column_list": _meta_data.column_list, "rows": _meta_data.rows[0]}

def get_table_desc_data(self, db_name, tb_name, **kwargs):
Expand All @@ -227,10 +227,10 @@ def get_table_desc_data(self, db_name, tb_name, **kwargs):
FROM
information_schema.COLUMNS
WHERE
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{tb_name}'
TABLE_SCHEMA = %s
AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION;"""
_desc_data = self.query(db_name, sql)
_desc_data = self.query(db_name, sql, parameters=(db_name, tb_name))
return {"column_list": _desc_data.column_list, "rows": _desc_data.rows}

def get_table_index_data(self, db_name, tb_name, **kwargs):
Expand All @@ -247,18 +247,19 @@ def get_table_index_data(self, db_name, tb_name, **kwargs):
FROM
information_schema.STATISTICS
WHERE
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{tb_name}';"""
_index_data = self.query(db_name, sql)
TABLE_SCHEMA = %s
AND TABLE_NAME = %s;"""
_index_data = self.query(db_name, sql, parameters=(db_name, tb_name))
return {"column_list": _index_data.column_list, "rows": _index_data.rows}

def get_tables_metas_data(self, db_name, **kwargs):
"""获取数据库所有表格信息,用作数据字典导出接口"""
sql_tbs = (
f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='{db_name}';"
)
sql_tbs = f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA=%s;"
tbs = self.query(
sql=sql_tbs, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False
sql=sql_tbs,
parameters=(db_name,),
cursorclass=MySQLdb.cursors.DictCursor,
close_conn=False,
).rows
table_metas = []
for tb in tbs:
Expand All @@ -275,9 +276,12 @@ def get_tables_metas_data(self, db_name, **kwargs):
_meta["ENGINE_KEYS"] = engine_keys
_meta["TABLE_INFO"] = tb
sql_cols = f"""SELECT * FROM INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA='{tb['TABLE_SCHEMA']}' AND TABLE_NAME='{tb['TABLE_NAME']}';"""
WHERE TABLE_SCHEMA=%s AND TABLE_NAME=%s;"""
_meta["COLUMNS"] = self.query(
sql=sql_cols, cursorclass=MySQLdb.cursors.DictCursor, close_conn=False
sql=sql_cols,
parameters=(tb["TABLE_SCHEMA"], tb["TABLE_NAME"]),
cursorclass=MySQLdb.cursors.DictCursor,
close_conn=False,
).rows
table_metas.append(_meta)
return table_metas
Expand All @@ -295,18 +299,18 @@ def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
FROM
information_schema.COLUMNS
WHERE
TABLE_SCHEMA = '{db_name}'
AND TABLE_NAME = '{tb_name}'
TABLE_SCHEMA = %s
AND TABLE_NAME = %s
ORDER BY ORDINAL_POSITION;"""
result = self.query(db_name=db_name, sql=sql)
result = self.query(db_name=db_name, sql=sql, parameters=(db_name, tb_name))
column_list = [row[0] for row in result.rows]
result.rows = column_list
return result

def describe_table(self, db_name, tb_name, **kwargs):
"""return ResultSet 类似查询"""
sql = f"show create table `{tb_name}`;"
result = self.query(db_name=db_name, sql=sql)
sql = f"show create table %s;"
result = self.query(db_name=db_name, sql=sql, parameters=(tb_name,))
return result

@staticmethod
Expand All @@ -325,7 +329,15 @@ def result_set_binary_as_hex(result_set):
result_set.rows = tuple(new_rows)
return result_set

def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
def query(
self,
db_name=None,
sql="",
parameters=None,
limit_num=0,
close_conn=True,
**kwargs,
):
"""返回 ResultSet"""
result_set = ResultSet(full_sql=sql)
max_execution_time = kwargs.get("max_execution_time", 0)
Expand All @@ -338,7 +350,7 @@ def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
cursor.execute(f"set session max_execution_time={max_execution_time};")
except MySQLdb.OperationalError:
pass
effect_row = cursor.execute(sql)
effect_row = cursor.execute(sql, args=parameters)
if int(limit_num) > 0:
rows = cursor.fetchmany(size=int(limit_num))
else:
Expand Down Expand Up @@ -518,14 +530,14 @@ def execute_workflow(self, workflow):
# inception执行
return self.inc_engine.execute(workflow)

def execute(self, db_name=None, sql="", close_conn=True):
def execute(self, db_name=None, sql="", parameters=None, close_conn=True):
"""原生执行语句"""
result = ResultSet(full_sql=sql)
conn = self.get_connection(db_name=db_name)
try:
cursor = conn.cursor()
for statement in sqlparse.split(sql):
cursor.execute(statement)
cursor.execute(statement, args=parameters)
conn.commit()
cursor.close()
except Exception as e:
Expand Down