diff --git a/pom.xml b/pom.xml
index e039c3ac0c..6723aa5873 100644
--- a/pom.xml
+++ b/pom.xml
@@ -11,7 +11,7 @@
com.alibaba
druid
- 1.0.1
+ 1.0.2
jar
druid
diff --git a/src/main/java/com/alibaba/druid/VERSION.java b/src/main/java/com/alibaba/druid/VERSION.java
index cac253bb26..f08b14b2a0 100644
--- a/src/main/java/com/alibaba/druid/VERSION.java
+++ b/src/main/java/com/alibaba/druid/VERSION.java
@@ -19,7 +19,7 @@ public final class VERSION {
public final static int MajorVersion = 1;
public final static int MinorVersion = 0;
- public final static int RevisionVersion = 1;
+ public final static int RevisionVersion = 2;
public static String getVersionNumber() {
return VERSION.MajorVersion + "." + VERSION.MinorVersion + "." + VERSION.RevisionVersion;
diff --git a/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java b/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java
index 0964b714b3..3e1b126513 100644
--- a/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java
+++ b/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java
@@ -61,9 +61,9 @@
import com.alibaba.druid.support.logging.LogFactory;
import com.alibaba.druid.util.DruidPasswordCallback;
import com.alibaba.druid.util.Histogram;
-import com.alibaba.druid.util.Utils;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.druid.util.StringUtils;
+import com.alibaba.druid.util.Utils;
/**
* @author wenshao
@@ -1314,6 +1314,7 @@ public void setClearFiltersEnable(boolean clearFiltersEnable) {
protected final AtomicLong statementIdSeed = new AtomicLong(20000);
protected final AtomicLong resultSetIdSeed = new AtomicLong(50000);
protected final AtomicLong transactionIdSeed = new AtomicLong(60000);
+ protected final AtomicLong metaDataIdSeed = new AtomicLong(80000);
public long createConnectionId() {
return connectionIdSeed.incrementAndGet();
@@ -1323,6 +1324,10 @@ public long createStatementId() {
return statementIdSeed.getAndIncrement();
}
+ public long createMetaDataId() {
+ return metaDataIdSeed.getAndIncrement();
+ }
+
public long createResultSetId() {
return resultSetIdSeed.getAndIncrement();
}
diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java
index 960763c774..9c9c6c085c 100644
--- a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java
+++ b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java
@@ -47,6 +47,8 @@ public interface DataSourceProxy {
long createResultSetId();
+ long createMetaDataId();
+
long createTransactionId();
Properties getConnectProperties();
diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java
index 041f998f2b..01d1d20021 100644
--- a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java
+++ b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java
@@ -35,8 +35,8 @@
import com.alibaba.druid.filter.FilterChainImpl;
import com.alibaba.druid.stat.JdbcDataSourceStat;
import com.alibaba.druid.stat.JdbcStatManager;
-import com.alibaba.druid.util.Utils;
import com.alibaba.druid.util.JdbcUtils;
+import com.alibaba.druid.util.Utils;
/**
* @author wenshao
@@ -58,6 +58,7 @@ public class DataSourceProxyImpl implements DataSourceProxy, DataSourceProxyImpl
private final AtomicLong connectionIdSeed = new AtomicLong(10000);
private final AtomicLong statementIdSeed = new AtomicLong(20000);
private final AtomicLong resultSetIdSeed = new AtomicLong(50000);
+ private final AtomicLong metaDataIdSeed = new AtomicLong(100000);
private final AtomicLong transactionIdSeed = new AtomicLong(0);
private final JdbcDataSourceStat dataSourceStat;
@@ -364,6 +365,10 @@ public long createResultSetId() {
return resultSetIdSeed.getAndIncrement();
}
+ public long createMetaDataId() {
+ return metaDataIdSeed.getAndIncrement();
+ }
+
@Override
public long createTransactionId() {
return transactionIdSeed.getAndIncrement();
diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxy.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxy.java
new file mode 100644
index 0000000000..df2be10011
--- /dev/null
+++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxy.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright 1999-2011 Alibaba Group Holding Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.druid.proxy.jdbc;
+
+import java.sql.ResultSetMetaData;
+
+/**
+ * @author kiki
+ */
+public interface ResultSetMetaDataProxy extends ResultSetMetaData, WrapperProxy {
+
+ ResultSetMetaData getResultSetMetaDataRaw();
+
+}
diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxyImpl.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxyImpl.java
new file mode 100644
index 0000000000..dac5320593
--- /dev/null
+++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxyImpl.java
@@ -0,0 +1,154 @@
+/*
+ * Copyright 1999-2011 Alibaba Group Holding Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.alibaba.druid.proxy.jdbc;
+
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+
+import com.alibaba.druid.filter.FilterChain;
+
+/**
+ * @author kiki
+ */
+public class ResultSetMetaDataProxyImpl extends WrapperProxyImpl implements ResultSetMetaDataProxy {
+
+ private final ResultSetMetaData metaData;
+
+ private final ResultSetProxy resultSetProxy;
+
+ public ResultSetMetaDataProxyImpl(ResultSetMetaData metaData, long id, ResultSetProxy resultSetProxy){
+ super(metaData, id);
+ this.metaData = metaData;
+ this.resultSetProxy = resultSetProxy;
+ }
+
+ @Override
+ public int getColumnCount() throws SQLException {
+ return metaData.getColumnCount() - resultSetProxy.getHiddenColumnCount();
+ }
+
+ @Override
+ public boolean isAutoIncrement(int column) throws SQLException {
+ return metaData.isAutoIncrement(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public boolean isCaseSensitive(int column) throws SQLException {
+ return metaData.isCaseSensitive(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public boolean isSearchable(int column) throws SQLException {
+ return metaData.isSearchable(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public boolean isCurrency(int column) throws SQLException {
+ return metaData.isCurrency(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public int isNullable(int column) throws SQLException {
+ return metaData.isNullable(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public boolean isSigned(int column) throws SQLException {
+ return metaData.isSigned(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public int getColumnDisplaySize(int column) throws SQLException {
+ return metaData.getColumnDisplaySize(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public String getColumnLabel(int column) throws SQLException {
+ return metaData.getColumnLabel(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public String getColumnName(int column) throws SQLException {
+ return metaData.getColumnName(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public String getSchemaName(int column) throws SQLException {
+ return metaData.getSchemaName(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public int getPrecision(int column) throws SQLException {
+ return metaData.getPrecision(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public int getScale(int column) throws SQLException {
+ return metaData.getScale(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public String getTableName(int column) throws SQLException {
+ return metaData.getTableName(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public String getCatalogName(int column) throws SQLException {
+ return metaData.getCatalogName(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public int getColumnType(int column) throws SQLException {
+ return metaData.getColumnType(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public String getColumnTypeName(int column) throws SQLException {
+ return metaData.getColumnTypeName(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public boolean isReadOnly(int column) throws SQLException {
+ return metaData.isReadOnly(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public boolean isWritable(int column) throws SQLException {
+ return metaData.isWritable(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public boolean isDefinitelyWritable(int column) throws SQLException {
+ return metaData.isDefinitelyWritable(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public String getColumnClassName(int column) throws SQLException {
+ return metaData.getColumnClassName(resultSetProxy.getPhysicalColumn(column));
+ }
+
+ @Override
+ public ResultSetMetaData getResultSetMetaDataRaw() {
+ return metaData;
+ }
+
+ @Override
+ public FilterChain createChain() {
+ return null;
+ }
+
+}
diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java
index 4b4e327d8d..19881dc44b 100644
--- a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java
+++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java
@@ -16,6 +16,8 @@
package com.alibaba.druid.proxy.jdbc;
import java.sql.ResultSet;
+import java.util.List;
+import java.util.Map;
import com.alibaba.druid.stat.JdbcSqlStat;
@@ -29,7 +31,7 @@ public interface ResultSetProxy extends ResultSet, WrapperProxy {
StatementProxy getStatementProxy();
String getSql();
-
+
JdbcSqlStat getSqlStat();
int getCursorIndex();
@@ -41,22 +43,35 @@ public interface ResultSetProxy extends ResultSet, WrapperProxy {
void setConstructNano(long constructNano);
void setConstructNano();
-
+
int getCloseCount();
-
+
void addReadStringLength(int length);
-
+
long getReadStringLength();
-
+
void addReadBytesLength(int length);
-
+
long getReadBytesLength();
-
+
void incrementOpenInputStreamCount();
-
+
int getOpenInputStreamCount();
-
+
void incrementOpenReaderCount();
-
+
int getOpenReaderCount();
+
+ int getPhysicalColumn(int logicColumn);
+
+ List getHiddenColumns();
+
+ int getHiddenColumnCount();
+
+ void setLogicColumnMap(Map logicColumnMap);
+
+ void setPhysicalColumnMap(Map physicalColumnMap);
+
+ void setHiddenColumns(List hiddenColumns);
+
}
diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java
index 295e56d634..91d58af6fb 100644
--- a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java
+++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java
@@ -36,6 +36,7 @@
import java.sql.Time;
import java.sql.Timestamp;
import java.util.Calendar;
+import java.util.List;
import java.util.Map;
import com.alibaba.druid.filter.FilterChainImpl;
@@ -46,23 +47,27 @@
*/
public class ResultSetProxyImpl extends WrapperProxyImpl implements ResultSetProxy {
- private final ResultSet resultSet;
- private final StatementProxy statement;
- private final String sql;
+ private final ResultSet resultSet;
+ private final StatementProxy statement;
+ private final String sql;
- protected int cursorIndex = 0;
- protected int fetchRowCount = 0;
- protected long constructNano;
- protected final JdbcSqlStat sqlStat;
- private int closeCount = 0;
+ protected int cursorIndex = 0;
+ protected int fetchRowCount = 0;
+ protected long constructNano;
+ protected final JdbcSqlStat sqlStat;
+ private int closeCount = 0;
- private long readStringLength = 0;
- private long readBytesLength = 0;
+ private long readStringLength = 0;
+ private long readBytesLength = 0;
- private int openInputStreamCount = 0;
- private int openReaderCount = 0;
+ private int openInputStreamCount = 0;
+ private int openReaderCount = 0;
- private FilterChainImpl filterChain = null;
+ private Map logicColumnMap = null;
+ private Map physicalColumnMap = null;
+ private List hiddenColumns = null;
+
+ private FilterChainImpl filterChain = null;
public ResultSetProxyImpl(StatementProxy statement, ResultSet resultSet, long id, String sql){
super(resultSet, id);
@@ -117,10 +122,10 @@ public FilterChainImpl createChain() {
} else {
this.filterChain = null;
}
-
+
return chain;
}
-
+
public void recycleFilterChain(FilterChainImpl chain) {
chain.reset();
this.filterChain = chain;
@@ -180,9 +185,10 @@ public void deleteRow() throws SQLException {
@Override
public int findColumn(String columnLabel) throws SQLException {
FilterChainImpl chain = createChain();
- int value = chain.resultSet_findColumn(this, columnLabel);
+ int physicalColumn = chain.resultSet_findColumn(this, columnLabel);
recycleFilterChain(chain);
- return value;
+
+ return getLogicColumn(physicalColumn);
}
@Override
@@ -196,7 +202,7 @@ public boolean first() throws SQLException {
@Override
public Array getArray(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Array value = chain.resultSet_getArray(this, columnIndex);
+ Array value = chain.resultSet_getArray(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -212,7 +218,7 @@ public Array getArray(String columnLabel) throws SQLException {
@Override
public InputStream getAsciiStream(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- InputStream value = chain.resultSet_getAsciiStream(this, columnIndex);
+ InputStream value = chain.resultSet_getAsciiStream(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -228,7 +234,7 @@ public InputStream getAsciiStream(String columnLabel) throws SQLException {
@Override
public BigDecimal getBigDecimal(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- BigDecimal value = chain.resultSet_getBigDecimal(this, columnIndex);
+ BigDecimal value = chain.resultSet_getBigDecimal(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -244,7 +250,7 @@ public BigDecimal getBigDecimal(String columnLabel) throws SQLException {
@Override
public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException {
FilterChainImpl chain = createChain();
- BigDecimal value = chain.resultSet_getBigDecimal(this, columnIndex, scale);
+ BigDecimal value = chain.resultSet_getBigDecimal(this, getPhysicalColumn(columnIndex), scale);
recycleFilterChain(chain);
return value;
}
@@ -260,7 +266,7 @@ public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLExcepti
@Override
public InputStream getBinaryStream(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- InputStream value = chain.resultSet_getBinaryStream(this, columnIndex);
+ InputStream value = chain.resultSet_getBinaryStream(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -276,7 +282,7 @@ public InputStream getBinaryStream(String columnLabel) throws SQLException {
@Override
public Blob getBlob(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Blob value = chain.resultSet_getBlob(this, columnIndex);
+ Blob value = chain.resultSet_getBlob(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -292,7 +298,7 @@ public Blob getBlob(String columnLabel) throws SQLException {
@Override
public boolean getBoolean(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- boolean value = chain.resultSet_getBoolean(this, columnIndex);
+ boolean value = chain.resultSet_getBoolean(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -308,7 +314,7 @@ public boolean getBoolean(String columnLabel) throws SQLException {
@Override
public byte getByte(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- byte value = chain.resultSet_getByte(this, columnIndex);
+ byte value = chain.resultSet_getByte(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -324,7 +330,7 @@ public byte getByte(String columnLabel) throws SQLException {
@Override
public byte[] getBytes(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- byte[] value = chain.resultSet_getBytes(this, columnIndex);
+ byte[] value = chain.resultSet_getBytes(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -340,7 +346,7 @@ public byte[] getBytes(String columnLabel) throws SQLException {
@Override
public Reader getCharacterStream(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Reader value = chain.resultSet_getCharacterStream(this, columnIndex);
+ Reader value = chain.resultSet_getCharacterStream(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -356,7 +362,7 @@ public Reader getCharacterStream(String columnLabel) throws SQLException {
@Override
public Clob getClob(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Clob value = chain.resultSet_getClob(this, columnIndex);
+ Clob value = chain.resultSet_getClob(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -388,7 +394,7 @@ public String getCursorName() throws SQLException {
@Override
public Date getDate(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Date value = chain.resultSet_getDate(this, columnIndex);
+ Date value = chain.resultSet_getDate(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -404,7 +410,7 @@ public Date getDate(String columnLabel) throws SQLException {
@Override
public Date getDate(int columnIndex, Calendar cal) throws SQLException {
FilterChainImpl chain = createChain();
- Date value = chain.resultSet_getDate(this, columnIndex, cal);
+ Date value = chain.resultSet_getDate(this, getPhysicalColumn(columnIndex), cal);
recycleFilterChain(chain);
return value;
}
@@ -420,7 +426,7 @@ public Date getDate(String columnLabel, Calendar cal) throws SQLException {
@Override
public double getDouble(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- double value = chain.resultSet_getDouble(this, columnIndex);
+ double value = chain.resultSet_getDouble(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -444,7 +450,7 @@ public int getFetchDirection() throws SQLException {
@Override
public int getFetchSize() throws SQLException {
FilterChainImpl chain = createChain();
- int value = chain.resultSet_getFetchSize(this);
+ int value = chain.resultSet_getFetchSize(this);
recycleFilterChain(chain);
return value;
}
@@ -452,7 +458,7 @@ public int getFetchSize() throws SQLException {
@Override
public float getFloat(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- float value = chain.resultSet_getFloat(this, columnIndex);
+ float value = chain.resultSet_getFloat(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -476,7 +482,7 @@ public int getHoldability() throws SQLException {
@Override
public int getInt(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- int value = chain.resultSet_getInt(this, columnIndex);
+ int value = chain.resultSet_getInt(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -492,7 +498,7 @@ public int getInt(String columnLabel) throws SQLException {
@Override
public long getLong(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- long value = chain.resultSet_getLong(this, columnIndex);
+ long value = chain.resultSet_getLong(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -516,7 +522,7 @@ public ResultSetMetaData getMetaData() throws SQLException {
@Override
public Reader getNCharacterStream(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Reader value = chain.resultSet_getNCharacterStream(this, columnIndex);
+ Reader value = chain.resultSet_getNCharacterStream(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -532,7 +538,7 @@ public Reader getNCharacterStream(String columnLabel) throws SQLException {
@Override
public NClob getNClob(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- NClob value = chain.resultSet_getNClob(this, columnIndex);
+ NClob value = chain.resultSet_getNClob(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -548,7 +554,7 @@ public NClob getNClob(String columnLabel) throws SQLException {
@Override
public String getNString(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- String value = chain.resultSet_getNString(this, columnIndex);
+ String value = chain.resultSet_getNString(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -564,7 +570,7 @@ public String getNString(String columnLabel) throws SQLException {
@Override
public Object getObject(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Object value = chain.resultSet_getObject(this, columnIndex);
+ Object value = chain.resultSet_getObject(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -580,7 +586,7 @@ public Object getObject(String columnLabel) throws SQLException {
@Override
public Object getObject(int columnIndex, Map> map) throws SQLException {
FilterChainImpl chain = createChain();
- Object value = chain.resultSet_getObject(this, columnIndex, map);
+ Object value = chain.resultSet_getObject(this, getPhysicalColumn(columnIndex), map);
recycleFilterChain(chain);
return value;
}
@@ -596,7 +602,7 @@ public Object getObject(String columnLabel, Map> map) throws SQ
@Override
public Ref getRef(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Ref value = chain.resultSet_getRef(this, columnIndex);
+ Ref value = chain.resultSet_getRef(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -620,7 +626,7 @@ public int getRow() throws SQLException {
@Override
public RowId getRowId(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- RowId value = chain.resultSet_getRowId(this, columnIndex);
+ RowId value = chain.resultSet_getRowId(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -636,7 +642,7 @@ public RowId getRowId(String columnLabel) throws SQLException {
@Override
public SQLXML getSQLXML(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- SQLXML value = chain.resultSet_getSQLXML(this, columnIndex);
+ SQLXML value = chain.resultSet_getSQLXML(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -652,7 +658,7 @@ public SQLXML getSQLXML(String columnLabel) throws SQLException {
@Override
public short getShort(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- short value = chain.resultSet_getShort(this, columnIndex);
+ short value = chain.resultSet_getShort(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -676,7 +682,7 @@ public Statement getStatement() throws SQLException {
@Override
public String getString(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- String value = chain.resultSet_getString(this, columnIndex);
+ String value = chain.resultSet_getString(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -692,7 +698,7 @@ public String getString(String columnLabel) throws SQLException {
@Override
public Time getTime(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Time value = chain.resultSet_getTime(this, columnIndex);
+ Time value = chain.resultSet_getTime(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -708,7 +714,7 @@ public Time getTime(String columnLabel) throws SQLException {
@Override
public Time getTime(int columnIndex, Calendar cal) throws SQLException {
FilterChainImpl chain = createChain();
- Time value = chain.resultSet_getTime(this, columnIndex, cal);
+ Time value = chain.resultSet_getTime(this, getPhysicalColumn(columnIndex), cal);
recycleFilterChain(chain);
return value;
}
@@ -724,7 +730,7 @@ public Time getTime(String columnLabel, Calendar cal) throws SQLException {
@Override
public Timestamp getTimestamp(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- Timestamp value = chain.resultSet_getTimestamp(this, columnIndex);
+ Timestamp value = chain.resultSet_getTimestamp(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -740,7 +746,7 @@ public Timestamp getTimestamp(String columnLabel) throws SQLException {
@Override
public Timestamp getTimestamp(int columnIndex, Calendar cal) throws SQLException {
FilterChainImpl chain = createChain();
- Timestamp value = chain.resultSet_getTimestamp(this, columnIndex, cal);
+ Timestamp value = chain.resultSet_getTimestamp(this, getPhysicalColumn(columnIndex), cal);
recycleFilterChain(chain);
return value;
}
@@ -764,7 +770,7 @@ public int getType() throws SQLException {
@Override
public URL getURL(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- URL value = chain.resultSet_getURL(this, columnIndex);
+ URL value = chain.resultSet_getURL(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -780,7 +786,7 @@ public URL getURL(String columnLabel) throws SQLException {
@Override
public InputStream getUnicodeStream(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- InputStream value = chain.resultSet_getUnicodeStream(this, columnIndex);
+ InputStream value = chain.resultSet_getUnicodeStream(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
return value;
}
@@ -881,7 +887,7 @@ public boolean next() throws SQLException {
fetchRowCount = cursorIndex;
}
}
-
+
recycleFilterChain(chain);
return moreRows;
}
@@ -955,7 +961,7 @@ public void setFetchSize(int rows) throws SQLException {
@Override
public void updateArray(int columnIndex, Array x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateArray(this, columnIndex, x);
+ chain.resultSet_updateArray(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -969,7 +975,7 @@ public void updateArray(String columnLabel, Array x) throws SQLException {
@Override
public void updateAsciiStream(int columnIndex, InputStream x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateAsciiStream(this, columnIndex, x);
+ chain.resultSet_updateAsciiStream(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -983,7 +989,7 @@ public void updateAsciiStream(String columnLabel, InputStream x) throws SQLExcep
@Override
public void updateAsciiStream(int columnIndex, InputStream x, int length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateAsciiStream(this, columnIndex, x, length);
+ chain.resultSet_updateAsciiStream(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -997,7 +1003,7 @@ public void updateAsciiStream(String columnLabel, InputStream x, int length) thr
@Override
public void updateAsciiStream(int columnIndex, InputStream x, long length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateAsciiStream(this, columnIndex, x, length);
+ chain.resultSet_updateAsciiStream(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1011,7 +1017,7 @@ public void updateAsciiStream(String columnLabel, InputStream x, long length) th
@Override
public void updateBigDecimal(int columnIndex, BigDecimal x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBigDecimal(this, columnIndex, x);
+ chain.resultSet_updateBigDecimal(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1025,7 +1031,7 @@ public void updateBigDecimal(String columnLabel, BigDecimal x) throws SQLExcepti
@Override
public void updateBinaryStream(int columnIndex, InputStream x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBinaryStream(this, columnIndex, x);
+ chain.resultSet_updateBinaryStream(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1039,7 +1045,7 @@ public void updateBinaryStream(String columnLabel, InputStream x) throws SQLExce
@Override
public void updateBinaryStream(int columnIndex, InputStream x, int length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBinaryStream(this, columnIndex, x, length);
+ chain.resultSet_updateBinaryStream(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1053,7 +1059,7 @@ public void updateBinaryStream(String columnLabel, InputStream x, int length) th
@Override
public void updateBinaryStream(int columnIndex, InputStream x, long length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBinaryStream(this, columnIndex, x, length);
+ chain.resultSet_updateBinaryStream(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1081,7 +1087,7 @@ public void updateBlob(String columnLabel, Blob x) throws SQLException {
@Override
public void updateBlob(int columnIndex, InputStream x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBlob(this, columnIndex, x);
+ chain.resultSet_updateBlob(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1095,7 +1101,7 @@ public void updateBlob(String columnLabel, InputStream x) throws SQLException {
@Override
public void updateBlob(int columnIndex, InputStream x, long length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBlob(this, columnIndex, x, length);
+ chain.resultSet_updateBlob(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1109,7 +1115,7 @@ public void updateBlob(String columnLabel, InputStream x, long length) throws SQ
@Override
public void updateBoolean(int columnIndex, boolean x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBoolean(this, columnIndex, x);
+ chain.resultSet_updateBoolean(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1123,7 +1129,7 @@ public void updateBoolean(String columnLabel, boolean x) throws SQLException {
@Override
public void updateByte(int columnIndex, byte x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateByte(this, columnIndex, x);
+ chain.resultSet_updateByte(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1137,7 +1143,7 @@ public void updateByte(String columnLabel, byte x) throws SQLException {
@Override
public void updateBytes(int columnIndex, byte[] x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateBytes(this, columnIndex, x);
+ chain.resultSet_updateBytes(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1151,7 +1157,7 @@ public void updateBytes(String columnLabel, byte[] x) throws SQLException {
@Override
public void updateCharacterStream(int columnIndex, Reader x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateCharacterStream(this, columnIndex, x);
+ chain.resultSet_updateCharacterStream(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1165,7 +1171,7 @@ public void updateCharacterStream(String columnLabel, Reader x) throws SQLExcept
@Override
public void updateCharacterStream(int columnIndex, Reader x, int length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateCharacterStream(this, columnIndex, x, length);
+ chain.resultSet_updateCharacterStream(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1179,7 +1185,7 @@ public void updateCharacterStream(String columnLabel, Reader x, int length) thro
@Override
public void updateCharacterStream(int columnIndex, Reader x, long length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateCharacterStream(this, columnIndex, x, length);
+ chain.resultSet_updateCharacterStream(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1193,7 +1199,7 @@ public void updateCharacterStream(String columnLabel, Reader x, long length) thr
@Override
public void updateClob(int columnIndex, Clob x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateClob(this, columnIndex, x);
+ chain.resultSet_updateClob(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1207,7 +1213,7 @@ public void updateClob(String columnLabel, Clob x) throws SQLException {
@Override
public void updateClob(int columnIndex, Reader x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateClob(this, columnIndex, x);
+ chain.resultSet_updateClob(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1221,7 +1227,7 @@ public void updateClob(String columnLabel, Reader x) throws SQLException {
@Override
public void updateClob(int columnIndex, Reader x, long length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateClob(this, columnIndex, x, length);
+ chain.resultSet_updateClob(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1235,7 +1241,7 @@ public void updateClob(String columnLabel, Reader x, long length) throws SQLExce
@Override
public void updateDate(int columnIndex, Date x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateDate(this, columnIndex, x);
+ chain.resultSet_updateDate(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1249,7 +1255,7 @@ public void updateDate(String columnLabel, Date x) throws SQLException {
@Override
public void updateDouble(int columnIndex, double x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateDouble(this, columnIndex, x);
+ chain.resultSet_updateDouble(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1263,7 +1269,7 @@ public void updateDouble(String columnLabel, double x) throws SQLException {
@Override
public void updateFloat(int columnIndex, float x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateFloat(this, columnIndex, x);
+ chain.resultSet_updateFloat(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1277,7 +1283,7 @@ public void updateFloat(String columnLabel, float x) throws SQLException {
@Override
public void updateInt(int columnIndex, int x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateInt(this, columnIndex, x);
+ chain.resultSet_updateInt(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1291,7 +1297,7 @@ public void updateInt(String columnLabel, int x) throws SQLException {
@Override
public void updateLong(int columnIndex, long x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateLong(this, columnIndex, x);
+ chain.resultSet_updateLong(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1305,7 +1311,7 @@ public void updateLong(String columnLabel, long x) throws SQLException {
@Override
public void updateNCharacterStream(int columnIndex, Reader x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateNCharacterStream(this, columnIndex, x);
+ chain.resultSet_updateNCharacterStream(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1319,7 +1325,7 @@ public void updateNCharacterStream(String columnLabel, Reader x) throws SQLExcep
@Override
public void updateNCharacterStream(int columnIndex, Reader x, long length) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateNCharacterStream(this, columnIndex, x, length);
+ chain.resultSet_updateNCharacterStream(this, getPhysicalColumn(columnIndex), x, length);
recycleFilterChain(chain);
}
@@ -1333,7 +1339,7 @@ public void updateNCharacterStream(String columnLabel, Reader x, long length) th
@Override
public void updateNClob(int columnIndex, NClob x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateNClob(this, columnIndex, x);
+ chain.resultSet_updateNClob(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1347,7 +1353,7 @@ public void updateNClob(String columnLabel, NClob x) throws SQLException {
@Override
public void updateNClob(int columnIndex, Reader x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateNClob(this, columnIndex, x);
+ chain.resultSet_updateNClob(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1389,7 +1395,7 @@ public void updateNString(String columnLabel, String x) throws SQLException {
@Override
public void updateNull(int columnIndex) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateNull(this, columnIndex);
+ chain.resultSet_updateNull(this, getPhysicalColumn(columnIndex));
recycleFilterChain(chain);
}
@@ -1403,7 +1409,7 @@ public void updateNull(String columnLabel) throws SQLException {
@Override
public void updateObject(int columnIndex, Object x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateObject(this, columnIndex, x);
+ chain.resultSet_updateObject(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1417,7 +1423,7 @@ public void updateObject(String columnLabel, Object x) throws SQLException {
@Override
public void updateObject(int columnIndex, Object x, int scaleOrLength) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateObject(this, columnIndex, x, scaleOrLength);
+ chain.resultSet_updateObject(this, getPhysicalColumn(columnIndex), x, scaleOrLength);
recycleFilterChain(chain);
}
@@ -1431,7 +1437,7 @@ public void updateObject(String columnLabel, Object x, int scaleOrLength) throws
@Override
public void updateRef(int columnIndex, Ref x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateRef(this, columnIndex, x);
+ chain.resultSet_updateRef(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1452,7 +1458,7 @@ public void updateRow() throws SQLException {
@Override
public void updateRowId(int columnIndex, RowId x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateRowId(this, columnIndex, x);
+ chain.resultSet_updateRowId(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1466,7 +1472,7 @@ public void updateRowId(String columnLabel, RowId x) throws SQLException {
@Override
public void updateSQLXML(int columnIndex, SQLXML x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateSQLXML(this, columnIndex, x);
+ chain.resultSet_updateSQLXML(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1480,7 +1486,7 @@ public void updateSQLXML(String columnLabel, SQLXML x) throws SQLException {
@Override
public void updateShort(int columnIndex, short x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateShort(this, columnIndex, x);
+ chain.resultSet_updateShort(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1494,7 +1500,7 @@ public void updateShort(String columnLabel, short x) throws SQLException {
@Override
public void updateString(int columnIndex, String x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateString(this, columnIndex, x);
+ chain.resultSet_updateString(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1508,7 +1514,7 @@ public void updateString(String columnLabel, String x) throws SQLException {
@Override
public void updateTime(int columnIndex, Time x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateTime(this, columnIndex, x);
+ chain.resultSet_updateTime(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1522,7 +1528,7 @@ public void updateTime(String columnLabel, Time x) throws SQLException {
@Override
public void updateTimestamp(int columnIndex, Timestamp x) throws SQLException {
FilterChainImpl chain = createChain();
- chain.resultSet_updateTimestamp(this, columnIndex, x);
+ chain.resultSet_updateTimestamp(this, getPhysicalColumn(columnIndex), x);
recycleFilterChain(chain);
}
@@ -1537,7 +1543,7 @@ public void updateTimestamp(String columnLabel, Timestamp x) throws SQLException
public boolean wasNull() throws SQLException {
FilterChainImpl chain = createChain();
boolean result = chain.resultSet_wasNull(this);
-
+
recycleFilterChain(chain);
return result;
}
@@ -1599,15 +1605,59 @@ public T unwrap(Class iface) throws SQLException {
if (iface == ResultSetProxy.class || iface == ResultSetProxyImpl.class) {
return (T) this;
}
-
+
return super.unwrap(iface);
}
-
+
public boolean isWrapperFor(Class> iface) throws SQLException {
if (iface == ResultSetProxy.class || iface == ResultSetProxyImpl.class) {
return true;
}
-
+
return super.isWrapperFor(iface);
}
+
+ @Override
+ public int getPhysicalColumn(int logicColumn) {
+ if (logicColumnMap == null) {
+ return logicColumn;
+ }
+ return logicColumnMap.get(logicColumn);
+ }
+
+ private int getLogicColumn(int physicalColumn) {
+ if (physicalColumnMap == null) {
+ return physicalColumn;
+ }
+ return physicalColumnMap.get(physicalColumn);
+ }
+
+ @Override
+ public int getHiddenColumnCount() {
+ if (hiddenColumns == null) {
+ return 0;
+ }
+ return hiddenColumns.size();
+ }
+
+ @Override
+ public List getHiddenColumns() {
+ return this.hiddenColumns;
+ }
+
+ @Override
+ public void setLogicColumnMap(Map logicColumnMap) {
+ this.logicColumnMap = logicColumnMap;
+ }
+
+ @Override
+ public void setPhysicalColumnMap(Map physicalColumnMap) {
+ this.physicalColumnMap = physicalColumnMap;
+ }
+
+ @Override
+ public void setHiddenColumns(List hiddenColumns) {
+ this.hiddenColumns = hiddenColumns;
+ }
+
}
diff --git a/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java b/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java
index fc4ed2d06a..3391b4e856 100644
--- a/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java
+++ b/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java
@@ -56,6 +56,11 @@ public ValuesClause(List values){
}
}
+ public void addValue(SQLExpr value) {
+ value.setParent(this);
+ values.add(value);
+ }
+
public List getValues() {
return values;
}
diff --git a/src/main/java/com/alibaba/druid/wall/WallConfig.java b/src/main/java/com/alibaba/druid/wall/WallConfig.java
index d6d975936a..942e444367 100644
--- a/src/main/java/com/alibaba/druid/wall/WallConfig.java
+++ b/src/main/java/com/alibaba/druid/wall/WallConfig.java
@@ -106,6 +106,7 @@ public class WallConfig implements WallConfigMBean {
private String tenantTablePattern;
private String tenantColumn;
+ private TenantCallBack tenantCallBack;
private boolean wrapAllow = true;
private boolean metadataAllow = true;
@@ -239,6 +240,14 @@ public void setTenantColumn(String tenantColumn) {
this.tenantColumn = tenantColumn;
}
+ public TenantCallBack getTenantCallBack() {
+ return tenantCallBack;
+ }
+
+ public void setTenantCallBack(TenantCallBack tenantCallBack) {
+ this.tenantCallBack = tenantCallBack;
+ }
+
public boolean isMetadataAllow() {
return metadataAllow;
}
@@ -695,4 +704,30 @@ public void setCallAllow(boolean callAllow) {
this.callAllow = callAllow;
}
+ public static abstract interface TenantCallBack {
+
+ public static enum StatementType {
+ SELECT, UPDATE, INSERT, DELETE
+ }
+
+ Object getTenantValue(StatementType statementType, String tableName);
+
+ String getTenantColumn(StatementType statementType, String tableName);
+
+ /**
+ * 返回resultset隐藏列名
+ *
+ * @param tableName
+ * @return
+ */
+ String getHiddenColumn(String tableName);
+
+ /**
+ * resultset返回值中如果包含hiddenColumn的回调函数
+ *
+ * @param value hiddenColumn对应的值
+ */
+ void resultset_hiddenColumn(Object value);
+ }
+
}
diff --git a/src/main/java/com/alibaba/druid/wall/WallFilter.java b/src/main/java/com/alibaba/druid/wall/WallFilter.java
index 279032cc58..982b19385c 100644
--- a/src/main/java/com/alibaba/druid/wall/WallFilter.java
+++ b/src/main/java/com/alibaba/druid/wall/WallFilter.java
@@ -18,10 +18,14 @@
import static com.alibaba.druid.util.Utils.getBoolean;
import java.sql.DatabaseMetaData;
+import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Wrapper;
+import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Properties;
import java.util.Set;
@@ -31,11 +35,15 @@
import com.alibaba.druid.proxy.jdbc.ConnectionProxy;
import com.alibaba.druid.proxy.jdbc.DataSourceProxy;
import com.alibaba.druid.proxy.jdbc.PreparedStatementProxy;
+import com.alibaba.druid.proxy.jdbc.ResultSetMetaDataProxyImpl;
import com.alibaba.druid.proxy.jdbc.ResultSetProxy;
import com.alibaba.druid.proxy.jdbc.StatementProxy;
import com.alibaba.druid.support.logging.Log;
import com.alibaba.druid.support.logging.LogFactory;
import com.alibaba.druid.util.JdbcUtils;
+import com.alibaba.druid.util.ServletPathMatcher;
+import com.alibaba.druid.util.StringUtils;
+import com.alibaba.druid.wall.WallConfig.TenantCallBack;
import com.alibaba.druid.wall.spi.DB2WallProvider;
import com.alibaba.druid.wall.spi.MySqlWallProvider;
import com.alibaba.druid.wall.spi.OracleWallProvider;
@@ -469,7 +477,9 @@ public ResultSetProxy statement_executeQuery(FilterChain chain, StatementProxy s
createWallContext(statement);
try {
sql = check(sql);
- return chain.statement_executeQuery(statement, sql);
+ ResultSetProxy resultSetProxy = chain.statement_executeQuery(statement, sql);
+ preprocessResultSet(resultSetProxy);
+ return resultSetProxy;
} catch (SQLException ex) {
incrementExecuteErrorCount();
throw ex;
@@ -579,7 +589,9 @@ public boolean preparedStatement_execute(FilterChain chain, PreparedStatementPro
public ResultSetProxy preparedStatement_executeQuery(FilterChain chain, PreparedStatementProxy statement)
throws SQLException {
try {
- return chain.preparedStatement_executeQuery(statement);
+ ResultSetProxy resultSetProxy = chain.preparedStatement_executeQuery(statement);
+ preprocessResultSet(resultSetProxy);
+ return resultSetProxy;
} catch (SQLException ex) {
incrementExecuteErrorCount(statement);
throw ex;
@@ -601,6 +613,20 @@ public int preparedStatement_executeUpdate(FilterChain chain, PreparedStatementP
}
}
+ @Override
+ public ResultSetProxy statement_getResultSet(FilterChain chain, StatementProxy statement) throws SQLException {
+ ResultSetProxy resultSetProxy = chain.statement_getResultSet(statement);
+ preprocessResultSet(resultSetProxy);
+ return resultSetProxy;
+ }
+
+ @Override
+ public ResultSetProxy statement_getGeneratedKeys(FilterChain chain, StatementProxy statement) throws SQLException {
+ ResultSetProxy resultSetProxy = chain.statement_getGeneratedKeys(statement);
+ preprocessResultSet(resultSetProxy);
+ return resultSetProxy;
+ }
+
public void setSqlStatAttribute(StatementProxy stmt) {
WallContext context = WallContext.current();
if (context == null) {
@@ -630,7 +656,7 @@ public void statExecuteUpdate(int updateCount) {
provider.addUpdateCount(sqlStat, updateCount);
}
}
-
+
public void incrementExecuteErrorCount(PreparedStatementProxy statement) {
WallSqlStat sqlStat = (WallSqlStat) statement.getAttribute(ATTR_SQL_STAT);
if (sqlStat != null) {
@@ -736,6 +762,34 @@ public void resultSet_close(FilterChain chain, ResultSetProxy resultSet) throws
provider.addFetchRowCount(sqlStat, fetchRowCount);
}
+ // ////////////////
+
+ @Override
+ public boolean resultSet_next(FilterChain chain, ResultSetProxy resultSet) throws SQLException {
+ boolean hasNext = chain.resultSet_next(resultSet);
+ TenantCallBack callback = provider.getConfig().getTenantCallBack();
+ if (callback != null && hasNext) {
+ List hiddenColumns = resultSet.getHiddenColumns();
+ if (hiddenColumns != null && hiddenColumns.size() > 0) {
+ for (Integer columnIndex : hiddenColumns) {
+ Object value = resultSet.getResultSetRaw().getObject(columnIndex);
+ callback.resultset_hiddenColumn(value);
+ }
+ }
+ }
+ return hasNext;
+ }
+
+ @Override
+ public ResultSetMetaData resultSet_getMetaData(FilterChain chain, ResultSetProxy resultSet) throws SQLException {
+ ResultSetMetaData metaData = chain.resultSet_getMetaData(resultSet);
+ if (metaData == null) {
+ return null;
+ }
+
+ return new ResultSetMetaDataProxyImpl(metaData, chain.getDataSource().createMetaDataId(), resultSet);
+ }
+
public long getViolationCount() {
return this.provider.getViolationCount();
}
@@ -751,4 +805,57 @@ public void clearWhiteList() {
public boolean checkValid(String sql) {
return provider.checkValid(sql);
}
+
+ private void preprocessResultSet(ResultSetProxy resultSet) throws SQLException {
+ if (resultSet == null) {
+ return;
+ }
+
+ ResultSetMetaData metaData = resultSet.getResultSetRaw().getMetaData();
+ if (metaData == null) {
+ return;
+ }
+
+ TenantCallBack tenantCallBack = provider.getConfig().getTenantCallBack();
+ String tenantTablePattern = provider.getConfig().getTenantTablePattern();
+ if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) {
+ return;
+ }
+
+ Map logicColumnMap = new HashMap();
+ Map physicalColumnMap = new HashMap();
+ List hiddenColumns = new ArrayList();
+ for (int physicalColumn = 1, logicColumn = 1; physicalColumn <= metaData.getColumnCount(); physicalColumn++) {
+ boolean isHidden = false;
+ String tableName = metaData.getTableName(physicalColumn);
+
+ String hiddenColumn = null;
+ if (tenantCallBack != null) {
+ hiddenColumn = tenantCallBack.getHiddenColumn(tableName);
+ }
+ if (StringUtils.isEmpty(hiddenColumn)
+ && (tableName == null || ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName))) {
+ hiddenColumn = provider.getConfig().getTenantColumn();
+ }
+
+ if (!StringUtils.isEmpty(hiddenColumn)) {
+ String columnName = metaData.getColumnName(physicalColumn);
+ if (hiddenColumn.equalsIgnoreCase(columnName)) {
+ hiddenColumns.add(physicalColumn);
+ isHidden = true;
+ }
+ }
+ if (!isHidden) {
+ logicColumnMap.put(logicColumn, physicalColumn);
+ physicalColumnMap.put(physicalColumn, logicColumn);
+ logicColumn++;
+ }
+ }
+
+ if (hiddenColumns.size() > 0) {
+ resultSet.setLogicColumnMap(logicColumnMap);
+ resultSet.setPhysicalColumnMap(physicalColumnMap);
+ resultSet.setHiddenColumns(hiddenColumns);
+ }
+ }
}
diff --git a/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java b/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java
index 3803409e69..0e6c74905d 100644
--- a/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java
+++ b/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java
@@ -27,6 +27,7 @@
import java.util.Enumeration;
import java.util.List;
import java.util.Set;
+import java.util.Stack;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
@@ -71,6 +72,7 @@
import com.alibaba.druid.sql.ast.statement.SQLExprTableSource;
import com.alibaba.druid.sql.ast.statement.SQLInsertInto;
import com.alibaba.druid.sql.ast.statement.SQLInsertStatement;
+import com.alibaba.druid.sql.ast.statement.SQLInsertStatement.ValuesClause;
import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource;
import com.alibaba.druid.sql.ast.statement.SQLRollbackStatement;
import com.alibaba.druid.sql.ast.statement.SQLSelect;
@@ -85,6 +87,7 @@
import com.alibaba.druid.sql.ast.statement.SQLTruncateStatement;
import com.alibaba.druid.sql.ast.statement.SQLUnionOperator;
import com.alibaba.druid.sql.ast.statement.SQLUnionQuery;
+import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem;
import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement;
import com.alibaba.druid.sql.ast.statement.SQLUseStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlBooleanExpr;
@@ -92,6 +95,7 @@
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlCommitStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDescribeStatement;
+import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlReplaceStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSetCharSetStatement;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSetNamesStatement;
@@ -102,6 +106,7 @@
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleMergeStatement;
import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleMultiInsertStatement;
import com.alibaba.druid.sql.dialect.sqlserver.ast.stmt.SQLServerExecStatement;
+import com.alibaba.druid.sql.dialect.sqlserver.ast.stmt.SQLServerInsertStatement;
import com.alibaba.druid.sql.visitor.ExportParameterVisitor;
import com.alibaba.druid.sql.visitor.SQLEvalVisitor;
import com.alibaba.druid.sql.visitor.SQLEvalVisitorUtils;
@@ -110,7 +115,10 @@
import com.alibaba.druid.support.logging.LogFactory;
import com.alibaba.druid.util.JdbcUtils;
import com.alibaba.druid.util.ServletPathMatcher;
+import com.alibaba.druid.util.StringUtils;
import com.alibaba.druid.wall.WallConfig;
+import com.alibaba.druid.wall.WallConfig.TenantCallBack;
+import com.alibaba.druid.wall.WallConfig.TenantCallBack.StatementType;
import com.alibaba.druid.wall.WallContext;
import com.alibaba.druid.wall.WallProvider;
import com.alibaba.druid.wall.WallSqlTableStat;
@@ -235,6 +243,8 @@ public static void checkInsert(WallVisitor visitor, SQLInsertInto x) {
if (!visitor.getConfig().isInsertAllow()) {
addViolation(visitor, ErrorCode.INSERT_NOT_ALLOW, "insert not allow", x);
}
+
+ checkInsertForMultiTenant(visitor, x);
}
public static void checkSelelct(WallVisitor visitor, SQLSelectQueryBlock x) {
@@ -273,7 +283,8 @@ public static void checkSelelct(WallVisitor visitor, SQLSelectQueryBlock x) {
}
}
}
- checkConditionForMultiTenant(visitor, x.getWhere(), x);
+ checkSelectForMultiTenant(visitor, x);
+ // checkConditionForMultiTenant(visitor, x.getWhere(), x);
}
public static void checkHaving(WallVisitor visitor, SQLExpr x) {
@@ -327,7 +338,7 @@ public static void checkDelete(WallVisitor visitor, SQLDeleteStatement x) {
}
}
- checkConditionForMultiTenant(visitor, x.getWhere(), x);
+ // checkConditionForMultiTenant(visitor, x.getWhere(), x);
}
private static boolean isSimpleConstExpr(SQLExpr sqlExpr) {
@@ -377,6 +388,316 @@ private static void checkCondition(WallVisitor visitor, SQLExpr x) {
}
+ private static void checkJoinSelectForMultiTenant(WallVisitor visitor, SQLJoinTableSource join,
+ SQLSelectQueryBlock x) {
+ TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack();
+ String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
+ if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) {
+ return;
+ }
+
+ SQLTableSource right = join.getRight();
+ if (right instanceof SQLExprTableSource) {
+ SQLExpr tableExpr = ((SQLExprTableSource) right).getExpr();
+
+ if (tableExpr instanceof SQLIdentifierExpr) {
+ String tableName = ((SQLIdentifierExpr) tableExpr).getName();
+
+ String alias = null;
+ String tenantColumn = null;
+ if (tenantCallBack != null) {
+ tenantColumn = tenantCallBack.getTenantColumn(StatementType.SELECT, tableName);
+ }
+
+ if (StringUtils.isEmpty(tenantColumn)
+ && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
+ tenantColumn = visitor.getConfig().getTenantColumn();
+ }
+
+ if (!StringUtils.isEmpty(tenantColumn)) {
+ alias = right.getAlias();
+ if (alias == null) {
+ alias = tableName;
+ }
+
+ SQLExpr item = null;
+ if (alias != null) {
+ item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
+ } else {
+ item = new SQLIdentifierExpr(tenantColumn);
+ }
+ SQLSelectItem selectItem = new SQLSelectItem(item);
+ x.getSelectList().add(selectItem);
+ visitor.setSqlModified(true);
+ }
+ }
+ }
+ }
+
+ private static boolean isSelectStatmentForMultiTenant(SQLSelectQueryBlock queryBlock) {
+
+ SQLObject parent = queryBlock.getParent();
+ while (parent != null) {
+
+ if (parent instanceof SQLUnionQuery) {
+ SQLObject x = parent;
+ parent = x.getParent();
+ } else {
+ break;
+ }
+ }
+
+ if (!(parent instanceof SQLSelect)) {
+ return false;
+ }
+
+ parent = ((SQLSelect) parent).getParent();
+ if (parent instanceof SQLSelectStatement) {
+ return true;
+ }
+
+ return false;
+ }
+
+ private static void checkSelectForMultiTenant(WallVisitor visitor, SQLSelectQueryBlock x) {
+ TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack();
+ String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
+ if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) {
+ return;
+ }
+
+ if (x == null) {
+ throw new IllegalStateException("x is null");
+ }
+
+ if (!isSelectStatmentForMultiTenant(x)) {
+ return;
+ }
+
+ SQLTableSource tableSource = x.getFrom();
+ String alias = null;
+ String matchTableName = null;
+ String tenantColumn = null;
+ if (tableSource instanceof SQLExprTableSource) {
+ SQLExpr tableExpr = ((SQLExprTableSource) tableSource).getExpr();
+
+ if (tableExpr instanceof SQLIdentifierExpr) {
+ String tableName = ((SQLIdentifierExpr) tableExpr).getName();
+
+ if (tenantCallBack != null) {
+ tenantColumn = tenantCallBack.getTenantColumn(StatementType.SELECT, tableName);
+ }
+
+ if (StringUtils.isEmpty(tenantColumn)
+ && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
+ tenantColumn = visitor.getConfig().getTenantColumn();
+ }
+
+ if (!StringUtils.isEmpty(tenantColumn)) {
+ matchTableName = tableName;
+ alias = tableSource.getAlias();
+ }
+ }
+ } else if (tableSource instanceof SQLJoinTableSource) {
+ SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
+ if (join.getLeft() instanceof SQLExprTableSource) {
+ SQLExpr tableExpr = ((SQLExprTableSource) join.getLeft()).getExpr();
+
+ if (tableExpr instanceof SQLIdentifierExpr) {
+ String tableName = ((SQLIdentifierExpr) tableExpr).getName();
+
+ if (tenantCallBack != null) {
+ tenantColumn = tenantCallBack.getTenantColumn(StatementType.SELECT, tableName);
+ }
+
+ if (StringUtils.isEmpty(tenantColumn)
+ && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
+ tenantColumn = visitor.getConfig().getTenantColumn();
+ }
+
+ if (!StringUtils.isEmpty(tenantColumn)) {
+ matchTableName = tableName;
+ alias = join.getLeft().getAlias();
+
+ if (alias == null) {
+ alias = tableName;
+ }
+ }
+ }
+ checkJoinSelectForMultiTenant(visitor, join, x);
+ } else {
+ checkJoinSelectForMultiTenant(visitor, join, x);
+ }
+ }
+
+ if (matchTableName == null) {
+ return;
+ }
+
+ SQLExpr item = null;
+ if (alias != null) {
+ item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
+ } else {
+ item = new SQLIdentifierExpr(tenantColumn);
+ }
+ SQLSelectItem selectItem = new SQLSelectItem(item);
+ x.getSelectList().add(selectItem);
+ visitor.setSqlModified(true);
+ }
+
+ private static void checkUpdateForMultiTenant(WallVisitor visitor, SQLUpdateStatement x) {
+ TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack();
+ String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
+ if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) {
+ return;
+ }
+
+ if (x == null) {
+ throw new IllegalStateException("x is null");
+ }
+
+ SQLTableSource tableSource = x.getTableSource();
+ String alias = null;
+ String matchTableName = null;
+ String tenantColumn = null;
+ if (tableSource instanceof SQLExprTableSource) {
+ SQLExpr tableExpr = ((SQLExprTableSource) tableSource).getExpr();
+ if (tableExpr instanceof SQLIdentifierExpr) {
+ String tableName = ((SQLIdentifierExpr) tableExpr).getName();
+
+ if (tenantCallBack != null) {
+ tenantColumn = tenantCallBack.getTenantColumn(StatementType.UPDATE, tableName);
+ }
+ if (StringUtils.isEmpty(tenantColumn)
+ && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
+ tenantColumn = visitor.getConfig().getTenantColumn();
+ }
+
+ if (!StringUtils.isEmpty(tenantColumn)) {
+ matchTableName = tableName;
+ alias = tableSource.getAlias();
+ }
+ }
+ }
+
+ if (matchTableName == null) {
+ return;
+ }
+
+ SQLExpr item = null;
+ if (alias != null) {
+ item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
+ } else {
+ item = new SQLIdentifierExpr(tenantColumn);
+ }
+ SQLExpr value = generateTenantValue(visitor, alias, StatementType.UPDATE, matchTableName);
+
+ SQLUpdateSetItem updateSetItem = new SQLUpdateSetItem();
+ updateSetItem.setColumn(item);
+ updateSetItem.setValue(value);
+
+ x.getItems().add(updateSetItem);
+ visitor.setSqlModified(true);
+ }
+
+ private static void checkInsertForMultiTenant(WallVisitor visitor, SQLInsertInto x) {
+ TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack();
+ String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
+ if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) {
+ return;
+ }
+
+ if (x == null) {
+ throw new IllegalStateException("x is null");
+ }
+
+ SQLExprTableSource tableSource = x.getTableSource();
+ String alias = null;
+ String matchTableName = null;
+ String tenantColumn = null;
+ SQLExpr tableExpr = tableSource.getExpr();
+ if (tableExpr instanceof SQLIdentifierExpr) {
+ String tableName = ((SQLIdentifierExpr) tableExpr).getName();
+
+ if (tenantCallBack != null) {
+ tenantColumn = tenantCallBack.getTenantColumn(StatementType.INSERT, tableName);
+ }
+ if (StringUtils.isEmpty(tenantColumn)
+ && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) {
+ tenantColumn = visitor.getConfig().getTenantColumn();
+ }
+
+ if (!StringUtils.isEmpty(tenantColumn)) {
+ matchTableName = tableName;
+ alias = tableSource.getAlias();
+ }
+ }
+
+ if (matchTableName == null) {
+ return;
+ }
+
+ SQLExpr item = null;
+ if (alias != null) {
+ item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn);
+ } else {
+ item = new SQLIdentifierExpr(tenantColumn);
+ }
+ SQLExpr value = generateTenantValue(visitor, alias, StatementType.INSERT, matchTableName);
+
+ // add insert item and value
+ x.getColumns().add(item);
+
+ List valuesClauses = null;
+ ValuesClause valuesClause = null;
+ if (x instanceof MySqlInsertStatement) {
+ valuesClauses = ((MySqlInsertStatement) x).getValuesList();
+ } else if (x instanceof SQLServerInsertStatement) {
+ valuesClauses = ((MySqlInsertStatement) x).getValuesList();
+ } else {
+ valuesClause = x.getValues();
+ }
+
+ if (valuesClauses != null && valuesClauses.size() > 0) {
+ for (ValuesClause clause : valuesClauses) {
+ clause.addValue(value);
+ }
+ }
+ if (valuesClause != null) {
+ valuesClause.addValue(value);
+ }
+
+ // insert .. select
+ SQLSelect select = x.getQuery();
+ if (select != null) {
+ List queryBlocks = splitSQLSelectQuery(select.getQuery());
+ for (SQLSelectQueryBlock queryBlock : queryBlocks) {
+ queryBlock.getSelectList().add(new SQLSelectItem(value));
+ }
+ }
+
+ visitor.setSqlModified(true);
+ }
+
+ private static List splitSQLSelectQuery(SQLSelectQuery x) {
+ List groupList = new ArrayList();
+ Stack stack = new Stack();
+
+ stack.push(x);
+ do {
+ SQLSelectQuery query = stack.pop();
+ if (query instanceof SQLSelectQueryBlock) {
+ groupList.add((SQLSelectQueryBlock) query);
+ } else if (query instanceof SQLUnionQuery) {
+ SQLUnionQuery unionQuery = (SQLUnionQuery) query;
+ stack.push(unionQuery.getLeft());
+ stack.push(unionQuery.getRight());
+ }
+ } while (!stack.empty());
+ return groupList;
+ }
+
+ @Deprecated
public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x, SQLObject parent) {
String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
if (tenantTablePattern == null || tenantTablePattern.length() == 0) {
@@ -389,12 +710,16 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x,
String alias = null;
SQLTableSource tableSource;
+ StatementType statementType = null;
if (parent instanceof SQLDeleteStatement) {
tableSource = ((SQLDeleteStatement) parent).getTableSource();
+ statementType = StatementType.DELETE;
} else if (parent instanceof SQLUpdateStatement) {
tableSource = ((SQLUpdateStatement) parent).getTableSource();
+ statementType = StatementType.UPDATE;
} else if (parent instanceof SQLSelectQueryBlock) {
tableSource = ((SQLSelectQueryBlock) parent).getFrom();
+ statementType = StatementType.SELECT;
} else {
throw new IllegalStateException("not support parent : " + parent.getClass());
}
@@ -423,9 +748,9 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x,
}
}
- checkJoinConditionForMultiTenant(visitor, join, false);
+ checkJoinConditionForMultiTenant(visitor, join, false, statementType);
} else {
- checkJoinConditionForMultiTenant(visitor, join, true);
+ checkJoinConditionForMultiTenant(visitor, join, true, statementType);
}
}
@@ -433,7 +758,7 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x,
return;
}
- SQLBinaryOpExpr tenantCondition = cretateTenantCondition(visitor, alias);
+ SQLBinaryOpExpr tenantCondition = createTenantCondition(visitor, alias, statementType, matchTableName);
SQLExpr condition;
if (x == null) {
@@ -457,7 +782,9 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x,
}
}
- public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoinTableSource join, boolean checkLeft) {
+ @Deprecated
+ public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoinTableSource join,
+ boolean checkLeft, StatementType statementType) {
String tenantTablePattern = visitor.getConfig().getTenantTablePattern();
if (tenantTablePattern == null || tenantTablePattern.length() == 0) {
return;
@@ -476,7 +803,7 @@ public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoin
if (alias == null) {
alias = tableName;
}
- SQLBinaryOpExpr tenantCondition = cretateTenantCondition(visitor, alias);
+ SQLBinaryOpExpr tenantCondition = createTenantCondition(visitor, alias, statementType, tableName);
if (condition == null) {
condition = tenantCondition;
@@ -493,25 +820,39 @@ public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoin
}
}
- private static SQLBinaryOpExpr cretateTenantCondition(WallVisitor visitor, String alias) {
+ @Deprecated
+ private static SQLBinaryOpExpr createTenantCondition(WallVisitor visitor, String alias,
+ StatementType statementType, String tableName) {
SQLExpr left, right;
if (alias != null) {
left = new SQLPropertyExpr(new SQLIdentifierExpr(alias), visitor.getConfig().getTenantColumn());
} else {
left = new SQLIdentifierExpr(visitor.getConfig().getTenantColumn());
}
+ right = generateTenantValue(visitor, alias, statementType, tableName);
+
+ SQLBinaryOpExpr tenantCondition = new SQLBinaryOpExpr(left, SQLBinaryOperator.Equality, right);
+ return tenantCondition;
+ }
+
+ private static SQLExpr generateTenantValue(WallVisitor visitor, String alias, StatementType statementType,
+ String tableName) {
+ SQLExpr value;
+ TenantCallBack callBack = visitor.getConfig().getTenantCallBack();
+ if (callBack != null) {
+ WallProvider.setTenantValue(callBack.getTenantValue(statementType, tableName));
+ }
Object tenantValue = WallProvider.getTenantValue();
if (tenantValue instanceof Number) {
- right = new SQLNumberExpr((Number) tenantValue);
+ value = new SQLNumberExpr((Number) tenantValue);
} else if (tenantValue instanceof String) {
- right = new SQLCharExpr((String) tenantValue);
+ value = new SQLCharExpr((String) tenantValue);
} else {
throw new IllegalStateException("tenant value not support type " + tenantValue);
}
- SQLBinaryOpExpr tenantCondition = new SQLBinaryOpExpr(left, SQLBinaryOperator.Equality, right);
- return tenantCondition;
+ return value;
}
public static void checkReadOnly(WallVisitor visitor, SQLTableSource tableSource) {
@@ -573,7 +914,7 @@ public static void checkUpdate(WallVisitor visitor, SQLUpdateStatement x) {
}
}
- checkConditionForMultiTenant(visitor, where, x);
+ checkUpdateForMultiTenant(visitor, x);
}
public static Object getValue(WallVisitor visitor, SQLBinaryOpExpr x) {
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java
index b368d1a0b5..e684e2fc13 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java
@@ -41,7 +41,7 @@ public void testWall() throws Exception {
Assert.assertFalse(WallUtils.isValidateMySql("SELECT *FROM T UNION select 1 from mysql.user"));
Assert.assertTrue(WallUtils.isValidateMySql("select 'mysql.user'"));
- Assert.assertFalse(WallUtils.isValidateMySql("select 0x3C3F706870206576616C28245F504F53545B2763275D293F3E into outfile '\\www\\edu\\1.php'"));
+ Assert.assertFalse(WallUtils.isValidateMySql("select * FROM T WHERE id = 1 AND select 0x3C3F706870206576616C28245F504F53545B2763275D293F3E into outfile '\\www\\edu\\1.php'"));
Assert.assertTrue(WallUtils.isValidateMySql("select 'outfile'"));
Assert.assertFalse(WallUtils.isValidateMySql("select f1, f2 from t union select 1, 2"));
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java
index 61510e02d7..37475e1942 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java
@@ -139,22 +139,22 @@ public void test_true2() throws Exception {
}
}
- public void test_true3() throws Exception {
+ public void test_false9() throws Exception {
WallProvider provider = initWallProvider();
{
String sql = "select PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, CURRENT_TIMESTAMP() start_time from (SELECT PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, @rank := @rank + 1 AS rank FROM ( SELECT PROJECT_NAME, TABLE_NAME, ( SELECT CASE WHEN GROUP_CONCAT(COLUMN_name) LIKE 'ID,%' THEN SUBSTR( GROUP_CONCAT(COLUMN_name), 4 ) ELSE GROUP_CONCAT(COLUMN_name) END FROM Information_schema.`COLUMNS` A WHERE A.table_name = B.TABLE_NAME ORDER BY ORDINAL_POSITION ) EXPORT_COLUMNS FROM ETL_EXPORT b ORDER BY PROJECT_NAME, TABLE_NAME ) tmp, (SELECT @rank := 0) a) b WHERE rank='2';";
- Assert.assertTrue(provider.checkValid(sql));
+ Assert.assertFalse(provider.checkValid(sql));
}
{
String sql = "select PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, CURRENT_TIMESTAMP() start_time, case when type=1 then ' where day_id = 20130101 ' when type=2 and substr('20130101','7,2')='01' then 'where month_id=201301 ' else 'where 3=5 ' end export_where_data from (SELECT PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, type, @rank := @rank + 1 AS rank FROM ( SELECT PROJECT_NAME, TABLE_NAME, type, ( SELECT CASE WHEN GROUP_CONCAT(COLUMN_name) LIKE 'ID,%' THEN SUBSTR( GROUP_CONCAT(COLUMN_name), 4 ) ELSE GROUP_CONCAT(COLUMN_name) END FROM Information_schema.`COLUMNS` A WHERE A.table_name = concat(B.TABLE_NAME,'_','201301') ORDER BY ORDINAL_POSITION ) EXPORT_COLUMNS FROM ETL_EXPORT b where project_name in ('acc') ORDER BY PROJECT_NAME, TABLE_NAME ) tmp, (SELECT @rank := 0) a) b WHERE rank='3';";
- Assert.assertTrue(provider.checkValid(sql));
+ Assert.assertFalse(provider.checkValid(sql));
}
}
public void test_true4() {
WallProvider provider = initWallProvider();
{
- String sql = "SELECT 10006,@";
+ String sql = "SELECT 10006, @";
Assert.assertTrue(provider.checkValid(sql));
}
}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java
index 7b0d2f5216..977d2e76c5 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java
@@ -45,7 +45,6 @@ public void testMySql() throws Exception {
String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
Assert.assertEquals("DELETE FROM orders" + //
- "\nWHERE tenant = 123" + //
- "\n\tAND FID = ?", resultSql);
+ "\nWHERE FID = ?", resultSql);
}
}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantInsertTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantInsertTest.java
new file mode 100644
index 0000000000..13fbaf0d3c
--- /dev/null
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantInsertTest.java
@@ -0,0 +1,159 @@
+/*
+ * Copyright 2013 Alibaba Group Holding Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// Created on 2013-10-17
+// $Id$
+
+package com.alibaba.druid.bvt.filter.wall;
+
+import junit.framework.TestCase;
+
+import org.junit.Assert;
+
+import com.alibaba.druid.sql.SQLUtils;
+import com.alibaba.druid.util.JdbcConstants;
+import com.alibaba.druid.wall.WallCheckResult;
+import com.alibaba.druid.wall.WallConfig;
+import com.alibaba.druid.wall.WallProvider;
+import com.alibaba.druid.wall.spi.MySqlWallProvider;
+
+/**
+ * @author kiki
+ */
+public class TenantInsertTest extends TestCase {
+
+ private WallConfig config = new WallConfig();
+ private WallConfig config_callback = new WallConfig();
+
+ protected void setUp() throws Exception {
+ config.setTenantTablePattern("*");
+ config.setTenantColumn("tenant");
+
+ config_callback.setTenantCallBack(new TenantTestCallBack());
+ }
+
+ public void testMySql3() throws Exception {
+ String insert_sql = "INSERT INTO orders (ID, NAME) VALUES (1, \"KIKI\")";
+ String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + //
+ "\nVALUES (1, 'KIKI', 123)";
+ {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ {
+ WallProvider.setTenantValue(123);
+ MySqlWallProvider provider = new MySqlWallProvider(config);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ }
+
+ public void testMySql4() throws Exception {
+ String insert_sql = "INSERT INTO orders (ID, NAME) VALUES (1, \"KIKI\"), (1, \"CICI\")";
+ String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + //
+ "\nVALUES (1, 'KIKI', 123)," + //
+ "\n\t(1, 'CICI', 123)";
+
+ {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ {
+ WallProvider.setTenantValue(123);
+ MySqlWallProvider provider = new MySqlWallProvider(config);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ }
+
+ public void testMySql5() throws Exception {
+ String insert_sql = "INSERT INTO orders (ID, NAME) SELECT ID, NAME FROM temp WHERE age = 18";
+ String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + //
+ "\nSELECT ID, NAME, 123" + //
+ "\nFROM temp" + //
+ "\nWHERE age = 18";
+
+ {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ {
+ WallProvider.setTenantValue(123);
+ MySqlWallProvider provider = new MySqlWallProvider(config);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+ }
+
+ public void testMySql6() throws Exception {
+ String insert_sql = "INSERT INTO orders (ID, NAME) SELECT ID, NAME FROM temp1 WHERE age = 18 UNION SELECT ID, NAME FROM temp2 UNION ALL SELECT ID, NAME FROM temp3";
+ String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + //
+ "\nSELECT ID, NAME, 123" + //
+ "\nFROM temp1" + //
+ "\nWHERE age = 18" + //
+ "\nUNION" + //
+ "\nSELECT ID, NAME, 123" + //
+ "\nFROM temp2" + //
+ "\nUNION ALL" + //
+ "\nSELECT ID, NAME, 123" + //
+ "\nFROM temp3";
+
+ {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ {
+ WallProvider.setTenantValue(123);
+ MySqlWallProvider provider = new MySqlWallProvider(config);
+ WallCheckResult checkResult = provider.check(insert_sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+ }
+
+}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java
index 0b0bc4d051..8e82976717 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java
@@ -28,13 +28,19 @@
public class TenantSelectTest extends TestCase {
- private String sql = "SELECT ID, NAME FROM orders WHERE FID = ?";
+ private String sql = "SELECT ID, NAME FROM orders WHERE FID = ?";
+ private String expect_sql = "SELECT ID, NAME, tenant" + //
+ "\nFROM orders" + //
+ "\nWHERE FID = ?";
- private WallConfig config = new WallConfig();
+ private WallConfig config = new WallConfig();
+ private WallConfig config_callback = new WallConfig();
protected void setUp() throws Exception {
config.setTenantTablePattern("*");
config.setTenantColumn("tenant");
+
+ config_callback.setTenantCallBack(new TenantTestCallBack());
}
public void testMySql() throws Exception {
@@ -44,9 +50,15 @@ public void testMySql() throws Exception {
Assert.assertEquals(0, checkResult.getViolations().size());
String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
- Assert.assertEquals("SELECT ID, NAME" + //
- "\nFROM orders" + //
- "\nWHERE tenant = 123" + //
- "\n\tAND FID = ?", resultSql);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ public void testMySql2() throws Exception {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
}
}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java
index 475d251040..4580d5b70d 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java
@@ -28,13 +28,20 @@
public class TenantSelectTest2 extends TestCase {
- private String sql = "SELECT ID, NAME FROM orders WHERE FID = ? OR FID = ?";
+ private String sql = "SELECT ID, NAME FROM orders WHERE FID = ? OR FID = ?";
+ private String expect_sql = "SELECT ID, NAME, tenant" + //
+ "\nFROM orders" + //
+ "\nWHERE FID = ?" + //
+ "\n\tOR FID = ?";
- private WallConfig config = new WallConfig();
+ private WallConfig config = new WallConfig();
+ private WallConfig config_callback = new WallConfig();
protected void setUp() throws Exception {
config.setTenantTablePattern("*");
config.setTenantColumn("tenant");
+
+ config_callback.setTenantCallBack(new TenantTestCallBack());
}
public void testMySql() throws Exception {
@@ -44,10 +51,15 @@ public void testMySql() throws Exception {
Assert.assertEquals(0, checkResult.getViolations().size());
String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
- Assert.assertEquals("SELECT ID, NAME" + //
- "\nFROM orders" + //
- "\nWHERE tenant = 123" + //
- "\n\tAND (FID = ?" +
- "\n\t\tOR FID = ?)", resultSql);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ public void testMySql2() throws Exception {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
}
}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java
index 412def124d..a01167bf0e 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java
@@ -28,15 +28,23 @@
public class TenantSelectTest3 extends TestCase {
- private String sql = "SELECT ID, NAME " + //
- "FROM orders o inner join users u ON o.userid = u.id " + //
- "WHERE FID = ? OR FID = ?";
+ private String sql = "SELECT ID, NAME " + //
+ "FROM orders o inner join users u ON o.userid = u.id " + //
+ "WHERE FID = ? OR FID = ?";
+ private String expect_sql = "SELECT ID, NAME, u.tenant, o.tenant" + //
+ "\nFROM orders o" + //
+ "\n\tINNER JOIN users u ON o.userid = u.id" + //
+ "\nWHERE FID = ?" + //
+ "\n\tOR FID = ?";
- private WallConfig config = new WallConfig();
+ private WallConfig config = new WallConfig();
+ private WallConfig config_callback = new WallConfig();
protected void setUp() throws Exception {
config.setTenantTablePattern("*");
config.setTenantColumn("tenant");
+
+ config_callback.setTenantCallBack(new TenantTestCallBack());
}
public void testMySql() throws Exception {
@@ -46,12 +54,15 @@ public void testMySql() throws Exception {
Assert.assertEquals(0, checkResult.getViolations().size());
String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
- Assert.assertEquals("SELECT ID, NAME" + //
- "\nFROM orders o" + //
- "\n\tINNER JOIN users u ON u.tenant = 123" + //
- "\n\t\tAND o.userid = u.id" + //
- "\nWHERE o.tenant = 123" + //
- "\n\tAND (FID = ?" + //
- "\n\t\tOR FID = ?)", resultSql);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ public void testMySql2() throws Exception {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
}
}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java
index df461be3d8..508ce48f52 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java
@@ -28,15 +28,23 @@
public class TenantSelectTest4 extends TestCase {
- private String sql = "SELECT a.*,b.name " + //
- "FROM vote_info a left join vote_item b on a.item_id=b.id " + //
- "where 1=1 limit 1,10";
+ private String sql = "SELECT a.*,b.name " + //
+ "FROM vote_info a left join vote_item b on a.item_id=b.id " + //
+ "where 1=1 limit 1,10";
+ private String expect_sql = "SELECT a.*, b.name, b.tenant, a.tenant" + //
+ "\nFROM vote_info a" + //
+ "\n\tLEFT JOIN vote_item b ON a.item_id = b.id" + //
+ "\nWHERE 1 = 1" + //
+ "\nLIMIT 1, 10";
- private WallConfig config = new WallConfig();
+ private WallConfig config = new WallConfig();
+ private WallConfig config_callback = new WallConfig();
protected void setUp() throws Exception {
config.setTenantTablePattern("*");
config.setTenantColumn("tenant");
+
+ config_callback.setTenantCallBack(new TenantTestCallBack());
}
public void testMySql() throws Exception {
@@ -46,12 +54,15 @@ public void testMySql() throws Exception {
Assert.assertEquals(0, checkResult.getViolations().size());
String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
- Assert.assertEquals("SELECT a.*, b.name" + //
- "\nFROM vote_info a" + //
- "\n\tLEFT JOIN vote_item b ON b.tenant = 123" + //
- "\n\t\tAND a.item_id = b.id" + //
- "\nWHERE a.tenant = 123" + //
- "\n\tAND 1 = 1" + //
- "\nLIMIT 1, 10", resultSql);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ public void testMySql2() throws Exception {
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
}
}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantTestCallBack.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantTestCallBack.java
new file mode 100644
index 0000000000..7debb18557
--- /dev/null
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantTestCallBack.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright 1999-2011 Alibaba Group Holding Ltd.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.alibaba.druid.bvt.filter.wall;
+
+import com.alibaba.druid.wall.WallConfig.TenantCallBack;
+
+public class TenantTestCallBack implements TenantCallBack {
+
+ @Override
+ public Object getTenantValue(StatementType statementType, String tableName) {
+ return 123;
+ }
+
+ @Override
+ public String getTenantColumn(StatementType statementType, String tableName) {
+ return "tenant";
+ }
+
+ @Override
+ public String getHiddenColumn(String tableName) {
+ return "tenant";
+ }
+
+ @Override
+ public void resultset_hiddenColumn(Object value) {
+ System.out.println(value);
+ }
+}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java
index 307d033d9d..02b7e04485 100644
--- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java
@@ -27,27 +27,39 @@
import com.alibaba.druid.wall.spi.MySqlWallProvider;
public class TenantUpdateTest extends TestCase {
- private String sql = "UPDATE T_USER SET FNAME = ? WHERE FID = ?";
-
- private WallConfig config = new WallConfig();
-
- protected void setUp() throws Exception {
- config.setDeleteAllow(false);
- config.setTenantTablePattern("*");
- config.setTenantColumn("tenant");
- }
-
- public void testMySql() throws Exception {
- WallProvider.setTenantValue(123);
- MySqlWallProvider provider = new MySqlWallProvider(config);
- WallCheckResult checkResult = provider.check(sql);
- Assert.assertEquals(0, checkResult.getViolations().size());
-
- String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(),
- JdbcConstants.MYSQL);
- Assert.assertEquals("UPDATE T_USER" + //
- "\nSET FNAME = ?" + //
- "\nWHERE tenant = 123" + //
- "\n\tAND FID = ?", resultSql);
- }
+
+ private String sql = "UPDATE T_USER SET FNAME = ? WHERE FID = ?";
+ private String expect_sql = "UPDATE T_USER" + //
+ "\nSET FNAME = ?, tenant = 123" + //
+ "\nWHERE FID = ?";
+
+ private WallConfig config = new WallConfig();
+ private WallConfig config_callback = new WallConfig();
+
+ protected void setUp() throws Exception {
+ config.setTenantTablePattern("*");
+ config.setTenantColumn("tenant");
+
+ config_callback.setTenantCallBack(new TenantTestCallBack());
+ }
+
+ public void testMySql() throws Exception {
+ WallProvider.setTenantValue(123);
+ MySqlWallProvider provider = new MySqlWallProvider(config);
+ WallCheckResult checkResult = provider.check(sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
+
+ public void testMySql2() throws Exception {
+ WallProvider.setTenantValue(123);
+ MySqlWallProvider provider = new MySqlWallProvider(config_callback);
+ WallCheckResult checkResult = provider.check(sql);
+ Assert.assertEquals(0, checkResult.getViolations().size());
+
+ String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL);
+ Assert.assertEquals(expect_sql, resultSql);
+ }
}
diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/WallFilterTest3.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/WallFilterTest3.java
new file mode 100644
index 0000000000..94c876e1a1
--- /dev/null
+++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/WallFilterTest3.java
@@ -0,0 +1,400 @@
+package com.alibaba.druid.bvt.filter.wall;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.Statement;
+import java.util.LinkedList;
+import java.util.List;
+
+import junit.framework.TestCase;
+
+import org.junit.Assert;
+
+import com.alibaba.druid.filter.Filter;
+import com.alibaba.druid.pool.DruidDataSource;
+import com.alibaba.druid.util.JdbcConstants;
+import com.alibaba.druid.wall.WallConfig;
+import com.alibaba.druid.wall.WallFilter;
+
+public class WallFilterTest3 extends TestCase {
+
+ private DruidDataSource dataSource;
+ private WallFilter wallFilter;
+
+ protected void setUp() throws Exception {
+ dataSource = new DruidDataSource();
+
+ dataSource.setUrl("jdbc:h2:mem:wall_test;");
+ // dataSource.setFilters("wall");
+ dataSource.setDbType(JdbcConstants.MARIADB);
+
+ WallConfig config = new WallConfig();
+ config.setTenantCallBack(new TenantTestCallBack());
+
+ wallFilter = new WallFilter();
+ wallFilter.setConfig(config);
+ wallFilter.setDbType(JdbcConstants.MARIADB);
+ List filters = new LinkedList();
+ filters.add(wallFilter);
+ dataSource.setProxyFilters(filters);
+
+ dataSource.init();
+
+ }
+
+ protected void tearDown() throws Exception {
+ dataSource.close();
+ }
+
+ public void test_wallFilter() throws Exception {
+ Assert.assertEquals(JdbcConstants.MARIADB, wallFilter.getDbType());
+ Assert.assertFalse(wallFilter.isLogViolation());
+ wallFilter.setLogViolation(true);
+ Assert.assertTrue(wallFilter.isLogViolation());
+ wallFilter.setLogViolation(false);
+ Assert.assertFalse(wallFilter.isLogViolation());
+
+ Assert.assertTrue(wallFilter.isThrowException());
+ wallFilter.setThrowException(false);
+ Assert.assertFalse(wallFilter.isThrowException());
+ wallFilter.setThrowException(true);
+ Assert.assertTrue(wallFilter.isThrowException());
+
+ wallFilter.clearProviderCache();
+ wallFilter.getProviderWhiteList();
+ Assert.assertTrue(wallFilter.isInited());
+
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.execute("CREATE TABLE t (FID INTEGER, FNAME VARCHAR(50), TENANT VARCHAR(32))");
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getCreateCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ String sql = "INSERT INTO t (FID, FNAME) VALUES (?, ?)";
+
+ for (int i = 0; i < 10; ++i) {
+ PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS);
+ stmt.setInt(1, i + 10);
+ stmt.setString(2, "a" + (i + 10));
+ stmt.execute();
+ stmt.close();
+ }
+
+ conn.close();
+ }
+ Assert.assertEquals(10, wallFilter.getProvider().getTableStat("t").getInsertCount());
+ Assert.assertEquals(10, wallFilter.getProvider().getTableStat("t").getInsertDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ String sql = "INSERT INTO t (FID, FNAME) VALUES (?, ?)";
+
+ PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS);
+ for (int i = 0; i < 10; ++i) {
+ stmt.setInt(1, i + 20);
+ stmt.setString(2, "a" + (i + 20));
+ stmt.addBatch();
+ }
+ stmt.executeBatch();
+ stmt.close();
+
+ conn.close();
+ }
+ Assert.assertEquals(11, wallFilter.getProvider().getTableStat("t").getInsertCount());
+ Assert.assertEquals(20, wallFilter.getProvider().getTableStat("t").getInsertDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ for (int i = 0; i < 10; ++i) {
+ stmt.addBatch("INSERT INTO t (FID, FNAME) VALUES (" + i + ", 'a" + i + "')");
+ }
+ stmt.executeBatch();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(21, wallFilter.getProvider().getTableStat("t").getInsertCount());
+ Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+ {
+ String sql = "SELECT * FROM T";
+
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+ ResultSet rs = stmt.executeQuery();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(30, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T";
+
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY,
+ ResultSet.CONCUR_READ_ONLY,
+ ResultSet.HOLD_CURSORS_OVER_COMMIT);
+ ResultSet rs = stmt.executeQuery();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(60, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T LIMIT 10";
+
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement(sql, new int[0]);
+ ResultSet rs = stmt.executeQuery();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(70, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T LIMIT 10";
+
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement(sql, new String[0]);
+ ResultSet rs = stmt.executeQuery();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(80, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T LIMIT 10";
+
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareCall(sql);
+ ResultSet rs = stmt.executeQuery();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(90, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T LIMIT 10";
+
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareCall(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+ ResultSet rs = stmt.executeQuery();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(100, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T";
+
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareCall(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY,
+ ResultSet.HOLD_CURSORS_OVER_COMMIT);
+ ResultSet rs = stmt.executeQuery();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(130, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T LIMIT 10";
+
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY);
+ stmt.execute(sql, Statement.NO_GENERATED_KEYS);
+ ResultSet rs = stmt.getResultSet();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(140, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T LIMIT 10";
+
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY,
+ ResultSet.HOLD_CURSORS_OVER_COMMIT);
+ stmt.execute(sql, new int[0]);
+ ResultSet rs = stmt.getResultSet();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(150, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ String sql = "SELECT * FROM T LIMIT 10";
+
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY,
+ ResultSet.HOLD_CURSORS_OVER_COMMIT);
+ stmt.execute(sql, new String[0]);
+ ResultSet rs = stmt.getResultSet();
+ while (rs.next()) {
+
+ }
+ rs.close();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(160, wallFilter.getProvider().getTableStat("t").getFetchRowCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.executeUpdate("DELETE from t where FID = 0");
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
+ Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.executeUpdate("DELETE from t where FID = 1 OR FID = 2", Statement.NO_GENERATED_KEYS);
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(3, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
+ Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.executeUpdate("DELETE from t where FID = 3", new int[0]);
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(4, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
+ Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.executeUpdate("DELETE from t where FID = 4", new String[0]);
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(5, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
+ Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement("DELETE from t where FID = ?");
+ stmt.setInt(1, 5);
+ stmt.executeUpdate();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(6, wallFilter.getProvider().getTableStat("t").getDeleteDataCount());
+ Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.execute("update t SET fname = 'xx' where FID = 13 OR FID = 14");
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(2, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ? OR FID = ?");
+ stmt.setInt(1, 13);
+ stmt.setInt(2, 14);
+ stmt.execute();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(4, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ? OR FID = ?");
+ stmt.setInt(1, 13);
+ stmt.setInt(2, 14);
+ stmt.execute();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(6, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+ {
+ Connection conn = dataSource.getConnection();
+ PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ?");
+
+ stmt.setInt(1, 13);
+ stmt.addBatch();
+
+ stmt.setInt(1, 14);
+ stmt.addBatch();
+
+ stmt.executeBatch();
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(8, wallFilter.getProvider().getTableStat("t").getUpdateDataCount());
+
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.execute("truncate table t");
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getTruncateCount());
+ {
+ Connection conn = dataSource.getConnection();
+ Statement stmt = conn.createStatement();
+ stmt.execute("drop table t");
+ stmt.close();
+ conn.close();
+ }
+ Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getDropCount());
+
+ Assert.assertEquals(0, wallFilter.getViolationCount());
+ wallFilter.resetViolationCount();
+ wallFilter.checkValid("select 1");
+ Assert.assertEquals(0, wallFilter.getViolationCount());
+ }
+}