Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge from edp963/davinci dev-0.3 #3

Merged
merged 16 commits into from
Dec 5, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix: extract sql with fragment to optimize sql
  • Loading branch information
RichardShan committed Dec 3, 2019
commit bfc17f9dcc7f3f970b499db19984308280aaebab
4 changes: 3 additions & 1 deletion server/src/main/java/edp/core/consts/Consts.java
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,9 @@ public class Consts {
public static final String REG_SENSITIVE_SQL = "drop\\s|alter\\s|grant\\s|insert\\s|replace\\s|delete\\s|truncate\\s|update\\s|remove\\s";
public static final Pattern PATTERN_SENSITIVE_SQL = Pattern.compile(REG_SENSITIVE_SQL);

private static final String REG_WITH_SQL_FRAGMENT = "((?i)WITH[\\s\\S]+(?i)AS\\s*\\([\\s\\S]+\\))\\s*(?i)SELECT";
public static final Pattern WITH_SQL_FRAGMENT = Pattern.compile(REG_WITH_SQL_FRAGMENT);

/**
* 匹配多行sql注解正则
*/
Expand Down Expand Up @@ -196,5 +199,4 @@ public class Consts {
public static final String JDBC_DATASOURCE_DEFAULT_VERSION = "Default";

public static final String PATH_EXT_FORMATER = "ext" + File.separator + "%s" + File.separator + "%s" + File.separator;

}
53 changes: 35 additions & 18 deletions server/src/main/java/edp/davinci/core/utils/SqlParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.stringtemplate.v4.ST;

Expand All @@ -59,15 +58,17 @@ public class SqlParseUtils {

private static final String OR = "or";

@Autowired
private static final String QUERY_WHERE_TRUE = "1=1";
private static final String QUERY_WHERE_FALSE = "1=0";

private DacChannelUtil dacChannelUtil;

/**
* 解析sql
*
* @param sqlStr view sql 模版
* @param variables view 变量
* @param sqlTempDelimiter ST 模板界定符
* @param sqlStr view sql 模版
* @param variables view 变量
* @param sqlTempDelimiter ST 模板界定符
* @return
*/
public SqlEntity parseSql(String sqlStr, List<SqlVariable> variables, String sqlTempDelimiter) throws ServerException {
Expand Down Expand Up @@ -157,10 +158,10 @@ public List<String> getAuthVarValue(SqlVariable variable, String email) {
/**
* 替换参数
*
* @param sql sql 模板
* @param queryParamMap 普通查询变量
* @param authParamMap 权限变量
* @param sqlTempDelimiter ST 界定符
* @param sql sql 模板
* @param queryParamMap 普通查询变量
* @param authParamMap 权限变量
* @param sqlTempDelimiter ST 界定符
* @return
*/
public String replaceParams(String sql, Map<String, Object> queryParamMap, Map<String, List<String>> authParamMap, String sqlTempDelimiter) {
Expand Down Expand Up @@ -201,11 +202,11 @@ public String replaceParams(String sql, Map<String, Object> queryParamMap, Map<S

ST st = new ST(sql, delimiter, delimiter);
if (!CollectionUtils.isEmpty(authParamMap) && !CollectionUtils.isEmpty(expSet)) {
authParamMap.forEach((k, v) ->{
authParamMap.forEach((k, v) -> {
List values = authParamMap.get(k);
if(CollectionUtils.isEmpty(values) || (values.size()==1 && values.get(0).toString().contains(Constants.NO_AUTH_PERMISSION))){
if (CollectionUtils.isEmpty(values) || (values.size() == 1 && values.get(0).toString().contains(Constants.NO_AUTH_PERMISSION))) {
st.add(k, false);
}else{
} else {
st.add(k, true);
}
});
Expand Down Expand Up @@ -240,7 +241,7 @@ public List<String> getSqls(String sql, boolean isQuery) {
if (split.length > 0) {
list = new ArrayList<>();
for (String sqlStr : split) {
sqlStr = sqlStr.trim();
sqlStr = rebuildSqlWithFragment(sqlStr.trim());
boolean select = sqlStr.toLowerCase().startsWith(SELECT) || sqlStr.toLowerCase().startsWith(WITH);
if (isQuery) {
if (select) {
Expand All @@ -256,6 +257,22 @@ public List<String> getSqls(String sql, boolean isQuery) {
return list;
}

private static String rebuildSqlWithFragment(String sql) {
if (!sql.toLowerCase().startsWith(WITH)) {
Matcher matcher = WITH_SQL_FRAGMENT.matcher(sql);
if (matcher.find()) {
String withFragment = matcher.group();
if (withFragment.length() > 6) {
int lastSelectIndex = withFragment.length() - 6;
sql = sql.replace(withFragment, withFragment.substring(lastSelectIndex));
withFragment = withFragment.substring(0, lastSelectIndex);
}
sql = withFragment + SPACE + sql;
sql = sql.replaceAll(SPACE + "{2,}", SPACE);
}
}
return sql;
}

private static Map<String, String> getParsedExpression(Set<String> expSet, Map<String, List<String>> authParamMap, char sqlTempDelimiter) {
Iterator<String> iterator = expSet.iterator();
Expand All @@ -274,7 +291,7 @@ private static Map<String, String> getParsedExpression(Set<String> expSet, Map<S
private static String getAuthVarExpression(String srcExpression, Map<String, List<String>> authParamMap, char sqlTempDelimiter) throws Exception {

if (null == authParamMap) {
return "1=1";
return QUERY_WHERE_TRUE;
}

String originExpression = "";
Expand Down Expand Up @@ -320,7 +337,7 @@ private static String getAuthVarExpression(String srcExpression, Map<String, Lis
String v = list.get(0);
if (!StringUtils.isEmpty(v)) {
if (v.equals(NO_AUTH_PERMISSION)) {
return "1=0";
return QUERY_WHERE_FALSE;
} else {
if (sqlOperator == SqlOperatorEnum.IN) {
expBuilder
Expand All @@ -342,7 +359,7 @@ private static String getAuthVarExpression(String srcExpression, Map<String, Lis
}

} else {
return "1=1";
return QUERY_WHERE_TRUE;
}
} else {
List<String> collect = list.stream().filter(s -> !s.contains(NO_AUTH_PERMISSION)).collect(Collectors.toList());
Expand Down Expand Up @@ -379,7 +396,7 @@ private static String getAuthVarExpression(String srcExpression, Map<String, Lis
}
return expBuilder.toString();
} else {
return "1=1";
return QUERY_WHERE_TRUE;
}
} else {
Set<String> keySet = authParamMap.keySet();
Expand All @@ -396,7 +413,7 @@ private static String getAuthVarExpression(String srcExpression, Map<String, Lis
}
return String.join(EMPTY, left, SPACE, sqlOperator.getValue(), SPACE, v);
} else {
return "1=0";
return QUERY_WHERE_FALSE;
}
}
}
Expand Down