Skip to content

Commit

Permalink
improve sql wall, support rand
Browse files Browse the repository at this point in the history
  • Loading branch information
wenshao committed Sep 2, 2013
1 parent 3afd71a commit 5fab9ea
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ public Function getFunction(String funcName) {
public void registerFunction(String funcName, Function function) {
functions.put(funcName, function);
}

@Override
public void unregisterFunction(String funcName) {
functions.remove(funcName);
}

public boolean visit(SQLIdentifierExpr x) {
return SQLEvalVisitorUtils.visit(this, x);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,9 @@ public void registerFunction(String funcName, Function function) {
public boolean visit(SQLIdentifierExpr x) {
return SQLEvalVisitorUtils.visit(this, x);
}

@Override
public void unregisterFunction(String funcName) {
functions.remove(funcName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ public Function getFunction(String funcName) {
public void registerFunction(String funcName, Function function) {
functions.put(funcName, function);
}

@Override
public void unregisterFunction(String funcName) {
functions.remove(funcName);
}

public boolean visit(SQLIdentifierExpr x) {
return SQLEvalVisitorUtils.visit(this, x);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ public Function getFunction(String funcName) {
public void registerFunction(String funcName, Function function) {
functions.put(funcName, function);
}

@Override
public void unregisterFunction(String funcName) {
functions.remove(funcName);
}

public boolean visit(SQLIdentifierExpr x) {
return SQLEvalVisitorUtils.visit(this, x);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ public void registerFunction(String funcName, Function function) {
functions.put(funcName, function);
}

@Override
public void unregisterFunction(String funcName) {
functions.remove(funcName);
}

public boolean visit(SQLIdentifierExpr x) {
return SQLEvalVisitorUtils.visit(this, x);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public interface SQLEvalVisitor extends SQLASTVisitor {
Function getFunction(String funcName);

void registerFunction(String funcName, Function function);

void unregisterFunction(String funcName);

List<Object> getParameters();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,9 @@ public void registerFunction(String funcName, Function function) {
public boolean visit(SQLIdentifierExpr x) {
return SQLEvalVisitorUtils.visit(this, x);
}

@Override
public void unregisterFunction(String funcName) {
functions.remove(funcName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,11 @@ static void registerBaseFunctions() {
public static boolean visit(SQLEvalVisitor visitor, SQLMethodInvokeExpr x) {
String methodName = x.getMethodName().toLowerCase();

Function function = functions.get(methodName);
Function function = visitor.getFunction(methodName);

if (function == null) {
function = functions.get(methodName);
}

if (function != null) {
Object result = function.eval(visitor, x);
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/com/alibaba/druid/sql/visitor/functions/Nil.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package com.alibaba.druid.sql.visitor.functions;

import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.visitor.SQLEvalVisitor;

public class Nil implements Function {

public final static Nil instance = new Nil();

@Override
public Object eval(SQLEvalVisitor visitor, SQLMethodInvokeExpr x) {
return null;
}

}
18 changes: 15 additions & 3 deletions src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleMultiInsertStatement;
import com.alibaba.druid.sql.dialect.sqlserver.ast.stmt.SQLServerExecStatement;
import com.alibaba.druid.sql.visitor.ExportParameterVisitor;
import com.alibaba.druid.sql.visitor.SQLEvalVisitor;
import com.alibaba.druid.sql.visitor.SQLEvalVisitorUtils;
import com.alibaba.druid.sql.visitor.functions.Nil;
import com.alibaba.druid.support.logging.Log;
import com.alibaba.druid.support.logging.LogFactory;
import com.alibaba.druid.util.JdbcUtils;
Expand Down Expand Up @@ -739,7 +741,7 @@ public static Object getValue(WallVisitor visitor, SQLBinaryOpExpr x) {
dbType = wallContext.getDbType();
}

return SQLEvalVisitorUtils.eval(dbType, x, Collections.emptyList(), false);
return eval(dbType, x, Collections.emptyList());
}

public static SQLExpr getFirst(SQLExpr x) {
Expand Down Expand Up @@ -1042,7 +1044,7 @@ public static Object getValue(WallVisitor visitor, SQLExpr x) {
|| x instanceof SQLInListExpr //
|| x instanceof SQLUnaryExpr //
) {
return SQLEvalVisitorUtils.eval(dbType, x, Collections.emptyList(), false);
return eval(dbType, x, Collections.emptyList());
}

if (visitor != null && (!visitor.getConfig().isCaseConditionAllow()) && x instanceof SQLCaseExpr) {
Expand All @@ -1065,11 +1067,21 @@ public static Object getValue(WallVisitor visitor, SQLExpr x) {
}
}

return SQLEvalVisitorUtils.eval(dbType, x, Collections.emptyList(), false);
return eval(dbType, x, Collections.emptyList());
}

return null;
}

public static Object eval(String dbType, SQLObject sqlObject, List<Object> parameters) {
SQLEvalVisitor visitor = SQLEvalVisitorUtils.createEvalVisitor(dbType);
visitor.setParameters(parameters);
visitor.registerFunction("rand", Nil.instance);
sqlObject.accept(visitor);

Object value = SQLEvalVisitorUtils.getValue(sqlObject);
return value;
}

public static boolean isSimpleCountTableSource(WallVisitor visitor, SQLTableSource tableSource) {
if (!(tableSource instanceof SQLSubqueryTableSource)) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright 1999-2011 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.druid.bvt.filter.wall;

import junit.framework.TestCase;

import org.junit.Assert;

import com.alibaba.druid.wall.WallProvider;
import com.alibaba.druid.wall.spi.MySqlWallProvider;

public class MySqlWallTest144 extends TestCase {

public void test_false() throws Exception {
WallProvider provider = new MySqlWallProvider();

String sql = "select min(id) from wx_interact where activityid=1008 group by true_name,mobile having rand()<1";
Assert.assertTrue(provider.checkValid(sql));
}
}

0 comments on commit 5fab9ea

Please sign in to comment.