11import io
22import json
33import os
4- import re
54import zipfile
65from datetime import datetime
76from sqlalchemy .ext .asyncio import AsyncSession
7+ from sqlglot import parse as sqlglot_parse
8+ from sqlglot .expressions import Add , Alter , Create , Delete , Drop , Expression , Insert , Table , TruncateTable , Update
89from typing import List
910from config .constant import GenConstant
10- from config .env import GenConfig
11+ from config .env import DataBaseConfig , GenConfig
1112from exceptions .exception import ServiceException
1213from module_admin .entity .vo .common_vo import CrudResponseModel
1314from module_admin .entity .vo .user_vo import CurrentUserModel
@@ -197,10 +198,11 @@ async def create_table_services(cls, query_db: AsyncSession, sql: str, current_u
197198 :param current_user: 当前用户信息对象
198199 :return: 创建表结构结果
199200 """
200- if cls .__is_valid_create_table (sql ):
201+ sql_statements = sqlglot_parse (sql , dialect = DataBaseConfig .sqlglot_parse_dialect )
202+ if cls .__is_valid_create_table (sql_statements ):
201203 try :
202- table_names = re . findall ( r'create\s+table\s+(\w+)' , sql , re . IGNORECASE )
203- await GenTableDao .create_table_by_sql_dao (query_db , sql )
204+ table_names = cls . __get_table_names ( sql_statements )
205+ await GenTableDao .create_table_by_sql_dao (query_db , sql_statements )
204206 gen_table_list = await cls .get_gen_db_table_list_by_name_services (query_db , table_names )
205207 await cls .import_gen_table_services (query_db , gen_table_list , current_user )
206208
@@ -211,22 +213,39 @@ async def create_table_services(cls, query_db: AsyncSession, sql: str, current_u
211213 raise ServiceException (message = '建表语句不合法' )
212214
213215 @classmethod
214- def __is_valid_create_table (cls , sql : str ):
216+ def __is_valid_create_table (cls , sql_statements : List [ Expression ] ):
215217 """
216218 校验sql语句是否为合法的建表语句
217219
218- :param sql: sql语句
220+ :param sql_statements: sql语句的ast列表
219221 :return: 校验结果
220222 """
221- create_table_pattern = r'^\s*CREATE\s+TABLE\s+'
222- if not re .search (create_table_pattern , sql , re .IGNORECASE ):
223+ validate_create = [isinstance (sql_statement , Create ) for sql_statement in sql_statements ]
224+ validate_forbidden_keywords = [
225+ isinstance (
226+ sql_statement ,
227+ (Add , Alter , Delete , Drop , Insert , TruncateTable , Update ),
228+ )
229+ for sql_statement in sql_statements
230+ ]
231+ if not any (validate_create ) or any (validate_forbidden_keywords ):
223232 return False
224- forbidden_keywords = ['INSERT' , 'UPDATE' , 'DELETE' , 'DROP' , 'ALTER' , 'TRUNCATE' ]
225- for keyword in forbidden_keywords :
226- if re .search (rf'\b{ keyword } \b' , sql , re .IGNORECASE ):
227- return False
228233 return True
229234
235+ @classmethod
236+ def __get_table_names (cls , sql_statements : List [Expression ]):
237+ """
238+ 获取sql语句中所有的建表表名
239+
240+ :param sql_statements: sql语句的ast列表
241+ :return: 建表表名列表
242+ """
243+ table_names = []
244+ for sql_statement in sql_statements :
245+ if isinstance (sql_statement , Create ):
246+ table_names .append (sql_statement .find (Table ).name )
247+ return table_names
248+
230249 @classmethod
231250 async def preview_code_services (cls , query_db : AsyncSession , table_id : int ):
232251 """
0 commit comments