diff --git a/pom.xml b/pom.xml index e039c3ac0c..6723aa5873 100644 --- a/pom.xml +++ b/pom.xml @@ -11,7 +11,7 @@ com.alibaba druid - 1.0.1 + 1.0.2 jar druid diff --git a/src/main/java/com/alibaba/druid/VERSION.java b/src/main/java/com/alibaba/druid/VERSION.java index cac253bb26..f08b14b2a0 100644 --- a/src/main/java/com/alibaba/druid/VERSION.java +++ b/src/main/java/com/alibaba/druid/VERSION.java @@ -19,7 +19,7 @@ public final class VERSION { public final static int MajorVersion = 1; public final static int MinorVersion = 0; - public final static int RevisionVersion = 1; + public final static int RevisionVersion = 2; public static String getVersionNumber() { return VERSION.MajorVersion + "." + VERSION.MinorVersion + "." + VERSION.RevisionVersion; diff --git a/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java b/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java index 0964b714b3..3e1b126513 100644 --- a/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java +++ b/src/main/java/com/alibaba/druid/pool/DruidAbstractDataSource.java @@ -61,9 +61,9 @@ import com.alibaba.druid.support.logging.LogFactory; import com.alibaba.druid.util.DruidPasswordCallback; import com.alibaba.druid.util.Histogram; -import com.alibaba.druid.util.Utils; import com.alibaba.druid.util.JdbcUtils; import com.alibaba.druid.util.StringUtils; +import com.alibaba.druid.util.Utils; /** * @author wenshao @@ -1314,6 +1314,7 @@ public void setClearFiltersEnable(boolean clearFiltersEnable) { protected final AtomicLong statementIdSeed = new AtomicLong(20000); protected final AtomicLong resultSetIdSeed = new AtomicLong(50000); protected final AtomicLong transactionIdSeed = new AtomicLong(60000); + protected final AtomicLong metaDataIdSeed = new AtomicLong(80000); public long createConnectionId() { return connectionIdSeed.incrementAndGet(); @@ -1323,6 +1324,10 @@ public long createStatementId() { return statementIdSeed.getAndIncrement(); } + public long createMetaDataId() { + return metaDataIdSeed.getAndIncrement(); + } + public long createResultSetId() { return resultSetIdSeed.getAndIncrement(); } diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java index 960763c774..9c9c6c085c 100644 --- a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java +++ b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxy.java @@ -47,6 +47,8 @@ public interface DataSourceProxy { long createResultSetId(); + long createMetaDataId(); + long createTransactionId(); Properties getConnectProperties(); diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java index 041f998f2b..01d1d20021 100644 --- a/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java +++ b/src/main/java/com/alibaba/druid/proxy/jdbc/DataSourceProxyImpl.java @@ -35,8 +35,8 @@ import com.alibaba.druid.filter.FilterChainImpl; import com.alibaba.druid.stat.JdbcDataSourceStat; import com.alibaba.druid.stat.JdbcStatManager; -import com.alibaba.druid.util.Utils; import com.alibaba.druid.util.JdbcUtils; +import com.alibaba.druid.util.Utils; /** * @author wenshao @@ -58,6 +58,7 @@ public class DataSourceProxyImpl implements DataSourceProxy, DataSourceProxyImpl private final AtomicLong connectionIdSeed = new AtomicLong(10000); private final AtomicLong statementIdSeed = new AtomicLong(20000); private final AtomicLong resultSetIdSeed = new AtomicLong(50000); + private final AtomicLong metaDataIdSeed = new AtomicLong(100000); private final AtomicLong transactionIdSeed = new AtomicLong(0); private final JdbcDataSourceStat dataSourceStat; @@ -364,6 +365,10 @@ public long createResultSetId() { return resultSetIdSeed.getAndIncrement(); } + public long createMetaDataId() { + return metaDataIdSeed.getAndIncrement(); + } + @Override public long createTransactionId() { return transactionIdSeed.getAndIncrement(); diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxy.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxy.java new file mode 100644 index 0000000000..df2be10011 --- /dev/null +++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxy.java @@ -0,0 +1,28 @@ +/* + * Copyright 1999-2011 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.druid.proxy.jdbc; + +import java.sql.ResultSetMetaData; + +/** + * @author kiki + */ +public interface ResultSetMetaDataProxy extends ResultSetMetaData, WrapperProxy { + + ResultSetMetaData getResultSetMetaDataRaw(); + +} diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxyImpl.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxyImpl.java new file mode 100644 index 0000000000..dac5320593 --- /dev/null +++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetMetaDataProxyImpl.java @@ -0,0 +1,154 @@ +/* + * Copyright 1999-2011 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.alibaba.druid.proxy.jdbc; + +import java.sql.ResultSetMetaData; +import java.sql.SQLException; + +import com.alibaba.druid.filter.FilterChain; + +/** + * @author kiki + */ +public class ResultSetMetaDataProxyImpl extends WrapperProxyImpl implements ResultSetMetaDataProxy { + + private final ResultSetMetaData metaData; + + private final ResultSetProxy resultSetProxy; + + public ResultSetMetaDataProxyImpl(ResultSetMetaData metaData, long id, ResultSetProxy resultSetProxy){ + super(metaData, id); + this.metaData = metaData; + this.resultSetProxy = resultSetProxy; + } + + @Override + public int getColumnCount() throws SQLException { + return metaData.getColumnCount() - resultSetProxy.getHiddenColumnCount(); + } + + @Override + public boolean isAutoIncrement(int column) throws SQLException { + return metaData.isAutoIncrement(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public boolean isCaseSensitive(int column) throws SQLException { + return metaData.isCaseSensitive(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public boolean isSearchable(int column) throws SQLException { + return metaData.isSearchable(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public boolean isCurrency(int column) throws SQLException { + return metaData.isCurrency(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public int isNullable(int column) throws SQLException { + return metaData.isNullable(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public boolean isSigned(int column) throws SQLException { + return metaData.isSigned(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public int getColumnDisplaySize(int column) throws SQLException { + return metaData.getColumnDisplaySize(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public String getColumnLabel(int column) throws SQLException { + return metaData.getColumnLabel(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public String getColumnName(int column) throws SQLException { + return metaData.getColumnName(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public String getSchemaName(int column) throws SQLException { + return metaData.getSchemaName(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public int getPrecision(int column) throws SQLException { + return metaData.getPrecision(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public int getScale(int column) throws SQLException { + return metaData.getScale(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public String getTableName(int column) throws SQLException { + return metaData.getTableName(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public String getCatalogName(int column) throws SQLException { + return metaData.getCatalogName(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public int getColumnType(int column) throws SQLException { + return metaData.getColumnType(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public String getColumnTypeName(int column) throws SQLException { + return metaData.getColumnTypeName(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public boolean isReadOnly(int column) throws SQLException { + return metaData.isReadOnly(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public boolean isWritable(int column) throws SQLException { + return metaData.isWritable(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public boolean isDefinitelyWritable(int column) throws SQLException { + return metaData.isDefinitelyWritable(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public String getColumnClassName(int column) throws SQLException { + return metaData.getColumnClassName(resultSetProxy.getPhysicalColumn(column)); + } + + @Override + public ResultSetMetaData getResultSetMetaDataRaw() { + return metaData; + } + + @Override + public FilterChain createChain() { + return null; + } + +} diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java index 4b4e327d8d..19881dc44b 100644 --- a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java +++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxy.java @@ -16,6 +16,8 @@ package com.alibaba.druid.proxy.jdbc; import java.sql.ResultSet; +import java.util.List; +import java.util.Map; import com.alibaba.druid.stat.JdbcSqlStat; @@ -29,7 +31,7 @@ public interface ResultSetProxy extends ResultSet, WrapperProxy { StatementProxy getStatementProxy(); String getSql(); - + JdbcSqlStat getSqlStat(); int getCursorIndex(); @@ -41,22 +43,35 @@ public interface ResultSetProxy extends ResultSet, WrapperProxy { void setConstructNano(long constructNano); void setConstructNano(); - + int getCloseCount(); - + void addReadStringLength(int length); - + long getReadStringLength(); - + void addReadBytesLength(int length); - + long getReadBytesLength(); - + void incrementOpenInputStreamCount(); - + int getOpenInputStreamCount(); - + void incrementOpenReaderCount(); - + int getOpenReaderCount(); + + int getPhysicalColumn(int logicColumn); + + List getHiddenColumns(); + + int getHiddenColumnCount(); + + void setLogicColumnMap(Map logicColumnMap); + + void setPhysicalColumnMap(Map physicalColumnMap); + + void setHiddenColumns(List hiddenColumns); + } diff --git a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java index 295e56d634..91d58af6fb 100644 --- a/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java +++ b/src/main/java/com/alibaba/druid/proxy/jdbc/ResultSetProxyImpl.java @@ -36,6 +36,7 @@ import java.sql.Time; import java.sql.Timestamp; import java.util.Calendar; +import java.util.List; import java.util.Map; import com.alibaba.druid.filter.FilterChainImpl; @@ -46,23 +47,27 @@ */ public class ResultSetProxyImpl extends WrapperProxyImpl implements ResultSetProxy { - private final ResultSet resultSet; - private final StatementProxy statement; - private final String sql; + private final ResultSet resultSet; + private final StatementProxy statement; + private final String sql; - protected int cursorIndex = 0; - protected int fetchRowCount = 0; - protected long constructNano; - protected final JdbcSqlStat sqlStat; - private int closeCount = 0; + protected int cursorIndex = 0; + protected int fetchRowCount = 0; + protected long constructNano; + protected final JdbcSqlStat sqlStat; + private int closeCount = 0; - private long readStringLength = 0; - private long readBytesLength = 0; + private long readStringLength = 0; + private long readBytesLength = 0; - private int openInputStreamCount = 0; - private int openReaderCount = 0; + private int openInputStreamCount = 0; + private int openReaderCount = 0; - private FilterChainImpl filterChain = null; + private Map logicColumnMap = null; + private Map physicalColumnMap = null; + private List hiddenColumns = null; + + private FilterChainImpl filterChain = null; public ResultSetProxyImpl(StatementProxy statement, ResultSet resultSet, long id, String sql){ super(resultSet, id); @@ -117,10 +122,10 @@ public FilterChainImpl createChain() { } else { this.filterChain = null; } - + return chain; } - + public void recycleFilterChain(FilterChainImpl chain) { chain.reset(); this.filterChain = chain; @@ -180,9 +185,10 @@ public void deleteRow() throws SQLException { @Override public int findColumn(String columnLabel) throws SQLException { FilterChainImpl chain = createChain(); - int value = chain.resultSet_findColumn(this, columnLabel); + int physicalColumn = chain.resultSet_findColumn(this, columnLabel); recycleFilterChain(chain); - return value; + + return getLogicColumn(physicalColumn); } @Override @@ -196,7 +202,7 @@ public boolean first() throws SQLException { @Override public Array getArray(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Array value = chain.resultSet_getArray(this, columnIndex); + Array value = chain.resultSet_getArray(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -212,7 +218,7 @@ public Array getArray(String columnLabel) throws SQLException { @Override public InputStream getAsciiStream(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - InputStream value = chain.resultSet_getAsciiStream(this, columnIndex); + InputStream value = chain.resultSet_getAsciiStream(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -228,7 +234,7 @@ public InputStream getAsciiStream(String columnLabel) throws SQLException { @Override public BigDecimal getBigDecimal(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - BigDecimal value = chain.resultSet_getBigDecimal(this, columnIndex); + BigDecimal value = chain.resultSet_getBigDecimal(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -244,7 +250,7 @@ public BigDecimal getBigDecimal(String columnLabel) throws SQLException { @Override public BigDecimal getBigDecimal(int columnIndex, int scale) throws SQLException { FilterChainImpl chain = createChain(); - BigDecimal value = chain.resultSet_getBigDecimal(this, columnIndex, scale); + BigDecimal value = chain.resultSet_getBigDecimal(this, getPhysicalColumn(columnIndex), scale); recycleFilterChain(chain); return value; } @@ -260,7 +266,7 @@ public BigDecimal getBigDecimal(String columnLabel, int scale) throws SQLExcepti @Override public InputStream getBinaryStream(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - InputStream value = chain.resultSet_getBinaryStream(this, columnIndex); + InputStream value = chain.resultSet_getBinaryStream(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -276,7 +282,7 @@ public InputStream getBinaryStream(String columnLabel) throws SQLException { @Override public Blob getBlob(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Blob value = chain.resultSet_getBlob(this, columnIndex); + Blob value = chain.resultSet_getBlob(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -292,7 +298,7 @@ public Blob getBlob(String columnLabel) throws SQLException { @Override public boolean getBoolean(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - boolean value = chain.resultSet_getBoolean(this, columnIndex); + boolean value = chain.resultSet_getBoolean(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -308,7 +314,7 @@ public boolean getBoolean(String columnLabel) throws SQLException { @Override public byte getByte(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - byte value = chain.resultSet_getByte(this, columnIndex); + byte value = chain.resultSet_getByte(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -324,7 +330,7 @@ public byte getByte(String columnLabel) throws SQLException { @Override public byte[] getBytes(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - byte[] value = chain.resultSet_getBytes(this, columnIndex); + byte[] value = chain.resultSet_getBytes(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -340,7 +346,7 @@ public byte[] getBytes(String columnLabel) throws SQLException { @Override public Reader getCharacterStream(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Reader value = chain.resultSet_getCharacterStream(this, columnIndex); + Reader value = chain.resultSet_getCharacterStream(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -356,7 +362,7 @@ public Reader getCharacterStream(String columnLabel) throws SQLException { @Override public Clob getClob(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Clob value = chain.resultSet_getClob(this, columnIndex); + Clob value = chain.resultSet_getClob(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -388,7 +394,7 @@ public String getCursorName() throws SQLException { @Override public Date getDate(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Date value = chain.resultSet_getDate(this, columnIndex); + Date value = chain.resultSet_getDate(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -404,7 +410,7 @@ public Date getDate(String columnLabel) throws SQLException { @Override public Date getDate(int columnIndex, Calendar cal) throws SQLException { FilterChainImpl chain = createChain(); - Date value = chain.resultSet_getDate(this, columnIndex, cal); + Date value = chain.resultSet_getDate(this, getPhysicalColumn(columnIndex), cal); recycleFilterChain(chain); return value; } @@ -420,7 +426,7 @@ public Date getDate(String columnLabel, Calendar cal) throws SQLException { @Override public double getDouble(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - double value = chain.resultSet_getDouble(this, columnIndex); + double value = chain.resultSet_getDouble(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -444,7 +450,7 @@ public int getFetchDirection() throws SQLException { @Override public int getFetchSize() throws SQLException { FilterChainImpl chain = createChain(); - int value = chain.resultSet_getFetchSize(this); + int value = chain.resultSet_getFetchSize(this); recycleFilterChain(chain); return value; } @@ -452,7 +458,7 @@ public int getFetchSize() throws SQLException { @Override public float getFloat(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - float value = chain.resultSet_getFloat(this, columnIndex); + float value = chain.resultSet_getFloat(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -476,7 +482,7 @@ public int getHoldability() throws SQLException { @Override public int getInt(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - int value = chain.resultSet_getInt(this, columnIndex); + int value = chain.resultSet_getInt(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -492,7 +498,7 @@ public int getInt(String columnLabel) throws SQLException { @Override public long getLong(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - long value = chain.resultSet_getLong(this, columnIndex); + long value = chain.resultSet_getLong(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -516,7 +522,7 @@ public ResultSetMetaData getMetaData() throws SQLException { @Override public Reader getNCharacterStream(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Reader value = chain.resultSet_getNCharacterStream(this, columnIndex); + Reader value = chain.resultSet_getNCharacterStream(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -532,7 +538,7 @@ public Reader getNCharacterStream(String columnLabel) throws SQLException { @Override public NClob getNClob(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - NClob value = chain.resultSet_getNClob(this, columnIndex); + NClob value = chain.resultSet_getNClob(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -548,7 +554,7 @@ public NClob getNClob(String columnLabel) throws SQLException { @Override public String getNString(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - String value = chain.resultSet_getNString(this, columnIndex); + String value = chain.resultSet_getNString(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -564,7 +570,7 @@ public String getNString(String columnLabel) throws SQLException { @Override public Object getObject(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Object value = chain.resultSet_getObject(this, columnIndex); + Object value = chain.resultSet_getObject(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -580,7 +586,7 @@ public Object getObject(String columnLabel) throws SQLException { @Override public Object getObject(int columnIndex, Map> map) throws SQLException { FilterChainImpl chain = createChain(); - Object value = chain.resultSet_getObject(this, columnIndex, map); + Object value = chain.resultSet_getObject(this, getPhysicalColumn(columnIndex), map); recycleFilterChain(chain); return value; } @@ -596,7 +602,7 @@ public Object getObject(String columnLabel, Map> map) throws SQ @Override public Ref getRef(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Ref value = chain.resultSet_getRef(this, columnIndex); + Ref value = chain.resultSet_getRef(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -620,7 +626,7 @@ public int getRow() throws SQLException { @Override public RowId getRowId(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - RowId value = chain.resultSet_getRowId(this, columnIndex); + RowId value = chain.resultSet_getRowId(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -636,7 +642,7 @@ public RowId getRowId(String columnLabel) throws SQLException { @Override public SQLXML getSQLXML(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - SQLXML value = chain.resultSet_getSQLXML(this, columnIndex); + SQLXML value = chain.resultSet_getSQLXML(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -652,7 +658,7 @@ public SQLXML getSQLXML(String columnLabel) throws SQLException { @Override public short getShort(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - short value = chain.resultSet_getShort(this, columnIndex); + short value = chain.resultSet_getShort(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -676,7 +682,7 @@ public Statement getStatement() throws SQLException { @Override public String getString(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - String value = chain.resultSet_getString(this, columnIndex); + String value = chain.resultSet_getString(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -692,7 +698,7 @@ public String getString(String columnLabel) throws SQLException { @Override public Time getTime(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Time value = chain.resultSet_getTime(this, columnIndex); + Time value = chain.resultSet_getTime(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -708,7 +714,7 @@ public Time getTime(String columnLabel) throws SQLException { @Override public Time getTime(int columnIndex, Calendar cal) throws SQLException { FilterChainImpl chain = createChain(); - Time value = chain.resultSet_getTime(this, columnIndex, cal); + Time value = chain.resultSet_getTime(this, getPhysicalColumn(columnIndex), cal); recycleFilterChain(chain); return value; } @@ -724,7 +730,7 @@ public Time getTime(String columnLabel, Calendar cal) throws SQLException { @Override public Timestamp getTimestamp(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - Timestamp value = chain.resultSet_getTimestamp(this, columnIndex); + Timestamp value = chain.resultSet_getTimestamp(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -740,7 +746,7 @@ public Timestamp getTimestamp(String columnLabel) throws SQLException { @Override public Timestamp getTimestamp(int columnIndex, Calendar cal) throws SQLException { FilterChainImpl chain = createChain(); - Timestamp value = chain.resultSet_getTimestamp(this, columnIndex, cal); + Timestamp value = chain.resultSet_getTimestamp(this, getPhysicalColumn(columnIndex), cal); recycleFilterChain(chain); return value; } @@ -764,7 +770,7 @@ public int getType() throws SQLException { @Override public URL getURL(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - URL value = chain.resultSet_getURL(this, columnIndex); + URL value = chain.resultSet_getURL(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -780,7 +786,7 @@ public URL getURL(String columnLabel) throws SQLException { @Override public InputStream getUnicodeStream(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - InputStream value = chain.resultSet_getUnicodeStream(this, columnIndex); + InputStream value = chain.resultSet_getUnicodeStream(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); return value; } @@ -881,7 +887,7 @@ public boolean next() throws SQLException { fetchRowCount = cursorIndex; } } - + recycleFilterChain(chain); return moreRows; } @@ -955,7 +961,7 @@ public void setFetchSize(int rows) throws SQLException { @Override public void updateArray(int columnIndex, Array x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateArray(this, columnIndex, x); + chain.resultSet_updateArray(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -969,7 +975,7 @@ public void updateArray(String columnLabel, Array x) throws SQLException { @Override public void updateAsciiStream(int columnIndex, InputStream x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateAsciiStream(this, columnIndex, x); + chain.resultSet_updateAsciiStream(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -983,7 +989,7 @@ public void updateAsciiStream(String columnLabel, InputStream x) throws SQLExcep @Override public void updateAsciiStream(int columnIndex, InputStream x, int length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateAsciiStream(this, columnIndex, x, length); + chain.resultSet_updateAsciiStream(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -997,7 +1003,7 @@ public void updateAsciiStream(String columnLabel, InputStream x, int length) thr @Override public void updateAsciiStream(int columnIndex, InputStream x, long length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateAsciiStream(this, columnIndex, x, length); + chain.resultSet_updateAsciiStream(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1011,7 +1017,7 @@ public void updateAsciiStream(String columnLabel, InputStream x, long length) th @Override public void updateBigDecimal(int columnIndex, BigDecimal x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBigDecimal(this, columnIndex, x); + chain.resultSet_updateBigDecimal(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1025,7 +1031,7 @@ public void updateBigDecimal(String columnLabel, BigDecimal x) throws SQLExcepti @Override public void updateBinaryStream(int columnIndex, InputStream x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBinaryStream(this, columnIndex, x); + chain.resultSet_updateBinaryStream(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1039,7 +1045,7 @@ public void updateBinaryStream(String columnLabel, InputStream x) throws SQLExce @Override public void updateBinaryStream(int columnIndex, InputStream x, int length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBinaryStream(this, columnIndex, x, length); + chain.resultSet_updateBinaryStream(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1053,7 +1059,7 @@ public void updateBinaryStream(String columnLabel, InputStream x, int length) th @Override public void updateBinaryStream(int columnIndex, InputStream x, long length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBinaryStream(this, columnIndex, x, length); + chain.resultSet_updateBinaryStream(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1081,7 +1087,7 @@ public void updateBlob(String columnLabel, Blob x) throws SQLException { @Override public void updateBlob(int columnIndex, InputStream x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBlob(this, columnIndex, x); + chain.resultSet_updateBlob(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1095,7 +1101,7 @@ public void updateBlob(String columnLabel, InputStream x) throws SQLException { @Override public void updateBlob(int columnIndex, InputStream x, long length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBlob(this, columnIndex, x, length); + chain.resultSet_updateBlob(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1109,7 +1115,7 @@ public void updateBlob(String columnLabel, InputStream x, long length) throws SQ @Override public void updateBoolean(int columnIndex, boolean x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBoolean(this, columnIndex, x); + chain.resultSet_updateBoolean(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1123,7 +1129,7 @@ public void updateBoolean(String columnLabel, boolean x) throws SQLException { @Override public void updateByte(int columnIndex, byte x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateByte(this, columnIndex, x); + chain.resultSet_updateByte(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1137,7 +1143,7 @@ public void updateByte(String columnLabel, byte x) throws SQLException { @Override public void updateBytes(int columnIndex, byte[] x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateBytes(this, columnIndex, x); + chain.resultSet_updateBytes(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1151,7 +1157,7 @@ public void updateBytes(String columnLabel, byte[] x) throws SQLException { @Override public void updateCharacterStream(int columnIndex, Reader x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateCharacterStream(this, columnIndex, x); + chain.resultSet_updateCharacterStream(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1165,7 +1171,7 @@ public void updateCharacterStream(String columnLabel, Reader x) throws SQLExcept @Override public void updateCharacterStream(int columnIndex, Reader x, int length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateCharacterStream(this, columnIndex, x, length); + chain.resultSet_updateCharacterStream(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1179,7 +1185,7 @@ public void updateCharacterStream(String columnLabel, Reader x, int length) thro @Override public void updateCharacterStream(int columnIndex, Reader x, long length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateCharacterStream(this, columnIndex, x, length); + chain.resultSet_updateCharacterStream(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1193,7 +1199,7 @@ public void updateCharacterStream(String columnLabel, Reader x, long length) thr @Override public void updateClob(int columnIndex, Clob x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateClob(this, columnIndex, x); + chain.resultSet_updateClob(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1207,7 +1213,7 @@ public void updateClob(String columnLabel, Clob x) throws SQLException { @Override public void updateClob(int columnIndex, Reader x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateClob(this, columnIndex, x); + chain.resultSet_updateClob(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1221,7 +1227,7 @@ public void updateClob(String columnLabel, Reader x) throws SQLException { @Override public void updateClob(int columnIndex, Reader x, long length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateClob(this, columnIndex, x, length); + chain.resultSet_updateClob(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1235,7 +1241,7 @@ public void updateClob(String columnLabel, Reader x, long length) throws SQLExce @Override public void updateDate(int columnIndex, Date x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateDate(this, columnIndex, x); + chain.resultSet_updateDate(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1249,7 +1255,7 @@ public void updateDate(String columnLabel, Date x) throws SQLException { @Override public void updateDouble(int columnIndex, double x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateDouble(this, columnIndex, x); + chain.resultSet_updateDouble(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1263,7 +1269,7 @@ public void updateDouble(String columnLabel, double x) throws SQLException { @Override public void updateFloat(int columnIndex, float x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateFloat(this, columnIndex, x); + chain.resultSet_updateFloat(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1277,7 +1283,7 @@ public void updateFloat(String columnLabel, float x) throws SQLException { @Override public void updateInt(int columnIndex, int x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateInt(this, columnIndex, x); + chain.resultSet_updateInt(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1291,7 +1297,7 @@ public void updateInt(String columnLabel, int x) throws SQLException { @Override public void updateLong(int columnIndex, long x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateLong(this, columnIndex, x); + chain.resultSet_updateLong(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1305,7 +1311,7 @@ public void updateLong(String columnLabel, long x) throws SQLException { @Override public void updateNCharacterStream(int columnIndex, Reader x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateNCharacterStream(this, columnIndex, x); + chain.resultSet_updateNCharacterStream(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1319,7 +1325,7 @@ public void updateNCharacterStream(String columnLabel, Reader x) throws SQLExcep @Override public void updateNCharacterStream(int columnIndex, Reader x, long length) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateNCharacterStream(this, columnIndex, x, length); + chain.resultSet_updateNCharacterStream(this, getPhysicalColumn(columnIndex), x, length); recycleFilterChain(chain); } @@ -1333,7 +1339,7 @@ public void updateNCharacterStream(String columnLabel, Reader x, long length) th @Override public void updateNClob(int columnIndex, NClob x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateNClob(this, columnIndex, x); + chain.resultSet_updateNClob(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1347,7 +1353,7 @@ public void updateNClob(String columnLabel, NClob x) throws SQLException { @Override public void updateNClob(int columnIndex, Reader x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateNClob(this, columnIndex, x); + chain.resultSet_updateNClob(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1389,7 +1395,7 @@ public void updateNString(String columnLabel, String x) throws SQLException { @Override public void updateNull(int columnIndex) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateNull(this, columnIndex); + chain.resultSet_updateNull(this, getPhysicalColumn(columnIndex)); recycleFilterChain(chain); } @@ -1403,7 +1409,7 @@ public void updateNull(String columnLabel) throws SQLException { @Override public void updateObject(int columnIndex, Object x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateObject(this, columnIndex, x); + chain.resultSet_updateObject(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1417,7 +1423,7 @@ public void updateObject(String columnLabel, Object x) throws SQLException { @Override public void updateObject(int columnIndex, Object x, int scaleOrLength) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateObject(this, columnIndex, x, scaleOrLength); + chain.resultSet_updateObject(this, getPhysicalColumn(columnIndex), x, scaleOrLength); recycleFilterChain(chain); } @@ -1431,7 +1437,7 @@ public void updateObject(String columnLabel, Object x, int scaleOrLength) throws @Override public void updateRef(int columnIndex, Ref x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateRef(this, columnIndex, x); + chain.resultSet_updateRef(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1452,7 +1458,7 @@ public void updateRow() throws SQLException { @Override public void updateRowId(int columnIndex, RowId x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateRowId(this, columnIndex, x); + chain.resultSet_updateRowId(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1466,7 +1472,7 @@ public void updateRowId(String columnLabel, RowId x) throws SQLException { @Override public void updateSQLXML(int columnIndex, SQLXML x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateSQLXML(this, columnIndex, x); + chain.resultSet_updateSQLXML(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1480,7 +1486,7 @@ public void updateSQLXML(String columnLabel, SQLXML x) throws SQLException { @Override public void updateShort(int columnIndex, short x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateShort(this, columnIndex, x); + chain.resultSet_updateShort(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1494,7 +1500,7 @@ public void updateShort(String columnLabel, short x) throws SQLException { @Override public void updateString(int columnIndex, String x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateString(this, columnIndex, x); + chain.resultSet_updateString(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1508,7 +1514,7 @@ public void updateString(String columnLabel, String x) throws SQLException { @Override public void updateTime(int columnIndex, Time x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateTime(this, columnIndex, x); + chain.resultSet_updateTime(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1522,7 +1528,7 @@ public void updateTime(String columnLabel, Time x) throws SQLException { @Override public void updateTimestamp(int columnIndex, Timestamp x) throws SQLException { FilterChainImpl chain = createChain(); - chain.resultSet_updateTimestamp(this, columnIndex, x); + chain.resultSet_updateTimestamp(this, getPhysicalColumn(columnIndex), x); recycleFilterChain(chain); } @@ -1537,7 +1543,7 @@ public void updateTimestamp(String columnLabel, Timestamp x) throws SQLException public boolean wasNull() throws SQLException { FilterChainImpl chain = createChain(); boolean result = chain.resultSet_wasNull(this); - + recycleFilterChain(chain); return result; } @@ -1599,15 +1605,59 @@ public T unwrap(Class iface) throws SQLException { if (iface == ResultSetProxy.class || iface == ResultSetProxyImpl.class) { return (T) this; } - + return super.unwrap(iface); } - + public boolean isWrapperFor(Class iface) throws SQLException { if (iface == ResultSetProxy.class || iface == ResultSetProxyImpl.class) { return true; } - + return super.isWrapperFor(iface); } + + @Override + public int getPhysicalColumn(int logicColumn) { + if (logicColumnMap == null) { + return logicColumn; + } + return logicColumnMap.get(logicColumn); + } + + private int getLogicColumn(int physicalColumn) { + if (physicalColumnMap == null) { + return physicalColumn; + } + return physicalColumnMap.get(physicalColumn); + } + + @Override + public int getHiddenColumnCount() { + if (hiddenColumns == null) { + return 0; + } + return hiddenColumns.size(); + } + + @Override + public List getHiddenColumns() { + return this.hiddenColumns; + } + + @Override + public void setLogicColumnMap(Map logicColumnMap) { + this.logicColumnMap = logicColumnMap; + } + + @Override + public void setPhysicalColumnMap(Map physicalColumnMap) { + this.physicalColumnMap = physicalColumnMap; + } + + @Override + public void setHiddenColumns(List hiddenColumns) { + this.hiddenColumns = hiddenColumns; + } + } diff --git a/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java b/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java index fc4ed2d06a..3391b4e856 100644 --- a/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java +++ b/src/main/java/com/alibaba/druid/sql/ast/statement/SQLInsertStatement.java @@ -56,6 +56,11 @@ public ValuesClause(List values){ } } + public void addValue(SQLExpr value) { + value.setParent(this); + values.add(value); + } + public List getValues() { return values; } diff --git a/src/main/java/com/alibaba/druid/wall/WallConfig.java b/src/main/java/com/alibaba/druid/wall/WallConfig.java index d6d975936a..942e444367 100644 --- a/src/main/java/com/alibaba/druid/wall/WallConfig.java +++ b/src/main/java/com/alibaba/druid/wall/WallConfig.java @@ -106,6 +106,7 @@ public class WallConfig implements WallConfigMBean { private String tenantTablePattern; private String tenantColumn; + private TenantCallBack tenantCallBack; private boolean wrapAllow = true; private boolean metadataAllow = true; @@ -239,6 +240,14 @@ public void setTenantColumn(String tenantColumn) { this.tenantColumn = tenantColumn; } + public TenantCallBack getTenantCallBack() { + return tenantCallBack; + } + + public void setTenantCallBack(TenantCallBack tenantCallBack) { + this.tenantCallBack = tenantCallBack; + } + public boolean isMetadataAllow() { return metadataAllow; } @@ -695,4 +704,30 @@ public void setCallAllow(boolean callAllow) { this.callAllow = callAllow; } + public static abstract interface TenantCallBack { + + public static enum StatementType { + SELECT, UPDATE, INSERT, DELETE + } + + Object getTenantValue(StatementType statementType, String tableName); + + String getTenantColumn(StatementType statementType, String tableName); + + /** + * 返回resultset隐藏列名 + * + * @param tableName + * @return + */ + String getHiddenColumn(String tableName); + + /** + * resultset返回值中如果包含hiddenColumn的回调函数 + * + * @param value hiddenColumn对应的值 + */ + void resultset_hiddenColumn(Object value); + } + } diff --git a/src/main/java/com/alibaba/druid/wall/WallFilter.java b/src/main/java/com/alibaba/druid/wall/WallFilter.java index 279032cc58..982b19385c 100644 --- a/src/main/java/com/alibaba/druid/wall/WallFilter.java +++ b/src/main/java/com/alibaba/druid/wall/WallFilter.java @@ -18,10 +18,14 @@ import static com.alibaba.druid.util.Utils.getBoolean; import java.sql.DatabaseMetaData; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Wrapper; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Properties; import java.util.Set; @@ -31,11 +35,15 @@ import com.alibaba.druid.proxy.jdbc.ConnectionProxy; import com.alibaba.druid.proxy.jdbc.DataSourceProxy; import com.alibaba.druid.proxy.jdbc.PreparedStatementProxy; +import com.alibaba.druid.proxy.jdbc.ResultSetMetaDataProxyImpl; import com.alibaba.druid.proxy.jdbc.ResultSetProxy; import com.alibaba.druid.proxy.jdbc.StatementProxy; import com.alibaba.druid.support.logging.Log; import com.alibaba.druid.support.logging.LogFactory; import com.alibaba.druid.util.JdbcUtils; +import com.alibaba.druid.util.ServletPathMatcher; +import com.alibaba.druid.util.StringUtils; +import com.alibaba.druid.wall.WallConfig.TenantCallBack; import com.alibaba.druid.wall.spi.DB2WallProvider; import com.alibaba.druid.wall.spi.MySqlWallProvider; import com.alibaba.druid.wall.spi.OracleWallProvider; @@ -469,7 +477,9 @@ public ResultSetProxy statement_executeQuery(FilterChain chain, StatementProxy s createWallContext(statement); try { sql = check(sql); - return chain.statement_executeQuery(statement, sql); + ResultSetProxy resultSetProxy = chain.statement_executeQuery(statement, sql); + preprocessResultSet(resultSetProxy); + return resultSetProxy; } catch (SQLException ex) { incrementExecuteErrorCount(); throw ex; @@ -579,7 +589,9 @@ public boolean preparedStatement_execute(FilterChain chain, PreparedStatementPro public ResultSetProxy preparedStatement_executeQuery(FilterChain chain, PreparedStatementProxy statement) throws SQLException { try { - return chain.preparedStatement_executeQuery(statement); + ResultSetProxy resultSetProxy = chain.preparedStatement_executeQuery(statement); + preprocessResultSet(resultSetProxy); + return resultSetProxy; } catch (SQLException ex) { incrementExecuteErrorCount(statement); throw ex; @@ -601,6 +613,20 @@ public int preparedStatement_executeUpdate(FilterChain chain, PreparedStatementP } } + @Override + public ResultSetProxy statement_getResultSet(FilterChain chain, StatementProxy statement) throws SQLException { + ResultSetProxy resultSetProxy = chain.statement_getResultSet(statement); + preprocessResultSet(resultSetProxy); + return resultSetProxy; + } + + @Override + public ResultSetProxy statement_getGeneratedKeys(FilterChain chain, StatementProxy statement) throws SQLException { + ResultSetProxy resultSetProxy = chain.statement_getGeneratedKeys(statement); + preprocessResultSet(resultSetProxy); + return resultSetProxy; + } + public void setSqlStatAttribute(StatementProxy stmt) { WallContext context = WallContext.current(); if (context == null) { @@ -630,7 +656,7 @@ public void statExecuteUpdate(int updateCount) { provider.addUpdateCount(sqlStat, updateCount); } } - + public void incrementExecuteErrorCount(PreparedStatementProxy statement) { WallSqlStat sqlStat = (WallSqlStat) statement.getAttribute(ATTR_SQL_STAT); if (sqlStat != null) { @@ -736,6 +762,34 @@ public void resultSet_close(FilterChain chain, ResultSetProxy resultSet) throws provider.addFetchRowCount(sqlStat, fetchRowCount); } + // //////////////// + + @Override + public boolean resultSet_next(FilterChain chain, ResultSetProxy resultSet) throws SQLException { + boolean hasNext = chain.resultSet_next(resultSet); + TenantCallBack callback = provider.getConfig().getTenantCallBack(); + if (callback != null && hasNext) { + List hiddenColumns = resultSet.getHiddenColumns(); + if (hiddenColumns != null && hiddenColumns.size() > 0) { + for (Integer columnIndex : hiddenColumns) { + Object value = resultSet.getResultSetRaw().getObject(columnIndex); + callback.resultset_hiddenColumn(value); + } + } + } + return hasNext; + } + + @Override + public ResultSetMetaData resultSet_getMetaData(FilterChain chain, ResultSetProxy resultSet) throws SQLException { + ResultSetMetaData metaData = chain.resultSet_getMetaData(resultSet); + if (metaData == null) { + return null; + } + + return new ResultSetMetaDataProxyImpl(metaData, chain.getDataSource().createMetaDataId(), resultSet); + } + public long getViolationCount() { return this.provider.getViolationCount(); } @@ -751,4 +805,57 @@ public void clearWhiteList() { public boolean checkValid(String sql) { return provider.checkValid(sql); } + + private void preprocessResultSet(ResultSetProxy resultSet) throws SQLException { + if (resultSet == null) { + return; + } + + ResultSetMetaData metaData = resultSet.getResultSetRaw().getMetaData(); + if (metaData == null) { + return; + } + + TenantCallBack tenantCallBack = provider.getConfig().getTenantCallBack(); + String tenantTablePattern = provider.getConfig().getTenantTablePattern(); + if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) { + return; + } + + Map logicColumnMap = new HashMap(); + Map physicalColumnMap = new HashMap(); + List hiddenColumns = new ArrayList(); + for (int physicalColumn = 1, logicColumn = 1; physicalColumn <= metaData.getColumnCount(); physicalColumn++) { + boolean isHidden = false; + String tableName = metaData.getTableName(physicalColumn); + + String hiddenColumn = null; + if (tenantCallBack != null) { + hiddenColumn = tenantCallBack.getHiddenColumn(tableName); + } + if (StringUtils.isEmpty(hiddenColumn) + && (tableName == null || ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName))) { + hiddenColumn = provider.getConfig().getTenantColumn(); + } + + if (!StringUtils.isEmpty(hiddenColumn)) { + String columnName = metaData.getColumnName(physicalColumn); + if (hiddenColumn.equalsIgnoreCase(columnName)) { + hiddenColumns.add(physicalColumn); + isHidden = true; + } + } + if (!isHidden) { + logicColumnMap.put(logicColumn, physicalColumn); + physicalColumnMap.put(physicalColumn, logicColumn); + logicColumn++; + } + } + + if (hiddenColumns.size() > 0) { + resultSet.setLogicColumnMap(logicColumnMap); + resultSet.setPhysicalColumnMap(physicalColumnMap); + resultSet.setHiddenColumns(hiddenColumns); + } + } } diff --git a/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java b/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java index 3803409e69..0e6c74905d 100644 --- a/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java +++ b/src/main/java/com/alibaba/druid/wall/spi/WallVisitorUtils.java @@ -27,6 +27,7 @@ import java.util.Enumeration; import java.util.List; import java.util.Set; +import java.util.Stack; import com.alibaba.druid.sql.SQLUtils; import com.alibaba.druid.sql.ast.SQLExpr; @@ -71,6 +72,7 @@ import com.alibaba.druid.sql.ast.statement.SQLExprTableSource; import com.alibaba.druid.sql.ast.statement.SQLInsertInto; import com.alibaba.druid.sql.ast.statement.SQLInsertStatement; +import com.alibaba.druid.sql.ast.statement.SQLInsertStatement.ValuesClause; import com.alibaba.druid.sql.ast.statement.SQLJoinTableSource; import com.alibaba.druid.sql.ast.statement.SQLRollbackStatement; import com.alibaba.druid.sql.ast.statement.SQLSelect; @@ -85,6 +87,7 @@ import com.alibaba.druid.sql.ast.statement.SQLTruncateStatement; import com.alibaba.druid.sql.ast.statement.SQLUnionOperator; import com.alibaba.druid.sql.ast.statement.SQLUnionQuery; +import com.alibaba.druid.sql.ast.statement.SQLUpdateSetItem; import com.alibaba.druid.sql.ast.statement.SQLUpdateStatement; import com.alibaba.druid.sql.ast.statement.SQLUseStatement; import com.alibaba.druid.sql.dialect.mysql.ast.expr.MySqlBooleanExpr; @@ -92,6 +95,7 @@ import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlCommitStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDeleteStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlDescribeStatement; +import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlInsertStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlReplaceStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSetCharSetStatement; import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSetNamesStatement; @@ -102,6 +106,7 @@ import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleMergeStatement; import com.alibaba.druid.sql.dialect.oracle.ast.stmt.OracleMultiInsertStatement; import com.alibaba.druid.sql.dialect.sqlserver.ast.stmt.SQLServerExecStatement; +import com.alibaba.druid.sql.dialect.sqlserver.ast.stmt.SQLServerInsertStatement; import com.alibaba.druid.sql.visitor.ExportParameterVisitor; import com.alibaba.druid.sql.visitor.SQLEvalVisitor; import com.alibaba.druid.sql.visitor.SQLEvalVisitorUtils; @@ -110,7 +115,10 @@ import com.alibaba.druid.support.logging.LogFactory; import com.alibaba.druid.util.JdbcUtils; import com.alibaba.druid.util.ServletPathMatcher; +import com.alibaba.druid.util.StringUtils; import com.alibaba.druid.wall.WallConfig; +import com.alibaba.druid.wall.WallConfig.TenantCallBack; +import com.alibaba.druid.wall.WallConfig.TenantCallBack.StatementType; import com.alibaba.druid.wall.WallContext; import com.alibaba.druid.wall.WallProvider; import com.alibaba.druid.wall.WallSqlTableStat; @@ -235,6 +243,8 @@ public static void checkInsert(WallVisitor visitor, SQLInsertInto x) { if (!visitor.getConfig().isInsertAllow()) { addViolation(visitor, ErrorCode.INSERT_NOT_ALLOW, "insert not allow", x); } + + checkInsertForMultiTenant(visitor, x); } public static void checkSelelct(WallVisitor visitor, SQLSelectQueryBlock x) { @@ -273,7 +283,8 @@ public static void checkSelelct(WallVisitor visitor, SQLSelectQueryBlock x) { } } } - checkConditionForMultiTenant(visitor, x.getWhere(), x); + checkSelectForMultiTenant(visitor, x); + // checkConditionForMultiTenant(visitor, x.getWhere(), x); } public static void checkHaving(WallVisitor visitor, SQLExpr x) { @@ -327,7 +338,7 @@ public static void checkDelete(WallVisitor visitor, SQLDeleteStatement x) { } } - checkConditionForMultiTenant(visitor, x.getWhere(), x); + // checkConditionForMultiTenant(visitor, x.getWhere(), x); } private static boolean isSimpleConstExpr(SQLExpr sqlExpr) { @@ -377,6 +388,316 @@ private static void checkCondition(WallVisitor visitor, SQLExpr x) { } + private static void checkJoinSelectForMultiTenant(WallVisitor visitor, SQLJoinTableSource join, + SQLSelectQueryBlock x) { + TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack(); + String tenantTablePattern = visitor.getConfig().getTenantTablePattern(); + if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) { + return; + } + + SQLTableSource right = join.getRight(); + if (right instanceof SQLExprTableSource) { + SQLExpr tableExpr = ((SQLExprTableSource) right).getExpr(); + + if (tableExpr instanceof SQLIdentifierExpr) { + String tableName = ((SQLIdentifierExpr) tableExpr).getName(); + + String alias = null; + String tenantColumn = null; + if (tenantCallBack != null) { + tenantColumn = tenantCallBack.getTenantColumn(StatementType.SELECT, tableName); + } + + if (StringUtils.isEmpty(tenantColumn) + && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) { + tenantColumn = visitor.getConfig().getTenantColumn(); + } + + if (!StringUtils.isEmpty(tenantColumn)) { + alias = right.getAlias(); + if (alias == null) { + alias = tableName; + } + + SQLExpr item = null; + if (alias != null) { + item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn); + } else { + item = new SQLIdentifierExpr(tenantColumn); + } + SQLSelectItem selectItem = new SQLSelectItem(item); + x.getSelectList().add(selectItem); + visitor.setSqlModified(true); + } + } + } + } + + private static boolean isSelectStatmentForMultiTenant(SQLSelectQueryBlock queryBlock) { + + SQLObject parent = queryBlock.getParent(); + while (parent != null) { + + if (parent instanceof SQLUnionQuery) { + SQLObject x = parent; + parent = x.getParent(); + } else { + break; + } + } + + if (!(parent instanceof SQLSelect)) { + return false; + } + + parent = ((SQLSelect) parent).getParent(); + if (parent instanceof SQLSelectStatement) { + return true; + } + + return false; + } + + private static void checkSelectForMultiTenant(WallVisitor visitor, SQLSelectQueryBlock x) { + TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack(); + String tenantTablePattern = visitor.getConfig().getTenantTablePattern(); + if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) { + return; + } + + if (x == null) { + throw new IllegalStateException("x is null"); + } + + if (!isSelectStatmentForMultiTenant(x)) { + return; + } + + SQLTableSource tableSource = x.getFrom(); + String alias = null; + String matchTableName = null; + String tenantColumn = null; + if (tableSource instanceof SQLExprTableSource) { + SQLExpr tableExpr = ((SQLExprTableSource) tableSource).getExpr(); + + if (tableExpr instanceof SQLIdentifierExpr) { + String tableName = ((SQLIdentifierExpr) tableExpr).getName(); + + if (tenantCallBack != null) { + tenantColumn = tenantCallBack.getTenantColumn(StatementType.SELECT, tableName); + } + + if (StringUtils.isEmpty(tenantColumn) + && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) { + tenantColumn = visitor.getConfig().getTenantColumn(); + } + + if (!StringUtils.isEmpty(tenantColumn)) { + matchTableName = tableName; + alias = tableSource.getAlias(); + } + } + } else if (tableSource instanceof SQLJoinTableSource) { + SQLJoinTableSource join = (SQLJoinTableSource) tableSource; + if (join.getLeft() instanceof SQLExprTableSource) { + SQLExpr tableExpr = ((SQLExprTableSource) join.getLeft()).getExpr(); + + if (tableExpr instanceof SQLIdentifierExpr) { + String tableName = ((SQLIdentifierExpr) tableExpr).getName(); + + if (tenantCallBack != null) { + tenantColumn = tenantCallBack.getTenantColumn(StatementType.SELECT, tableName); + } + + if (StringUtils.isEmpty(tenantColumn) + && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) { + tenantColumn = visitor.getConfig().getTenantColumn(); + } + + if (!StringUtils.isEmpty(tenantColumn)) { + matchTableName = tableName; + alias = join.getLeft().getAlias(); + + if (alias == null) { + alias = tableName; + } + } + } + checkJoinSelectForMultiTenant(visitor, join, x); + } else { + checkJoinSelectForMultiTenant(visitor, join, x); + } + } + + if (matchTableName == null) { + return; + } + + SQLExpr item = null; + if (alias != null) { + item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn); + } else { + item = new SQLIdentifierExpr(tenantColumn); + } + SQLSelectItem selectItem = new SQLSelectItem(item); + x.getSelectList().add(selectItem); + visitor.setSqlModified(true); + } + + private static void checkUpdateForMultiTenant(WallVisitor visitor, SQLUpdateStatement x) { + TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack(); + String tenantTablePattern = visitor.getConfig().getTenantTablePattern(); + if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) { + return; + } + + if (x == null) { + throw new IllegalStateException("x is null"); + } + + SQLTableSource tableSource = x.getTableSource(); + String alias = null; + String matchTableName = null; + String tenantColumn = null; + if (tableSource instanceof SQLExprTableSource) { + SQLExpr tableExpr = ((SQLExprTableSource) tableSource).getExpr(); + if (tableExpr instanceof SQLIdentifierExpr) { + String tableName = ((SQLIdentifierExpr) tableExpr).getName(); + + if (tenantCallBack != null) { + tenantColumn = tenantCallBack.getTenantColumn(StatementType.UPDATE, tableName); + } + if (StringUtils.isEmpty(tenantColumn) + && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) { + tenantColumn = visitor.getConfig().getTenantColumn(); + } + + if (!StringUtils.isEmpty(tenantColumn)) { + matchTableName = tableName; + alias = tableSource.getAlias(); + } + } + } + + if (matchTableName == null) { + return; + } + + SQLExpr item = null; + if (alias != null) { + item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn); + } else { + item = new SQLIdentifierExpr(tenantColumn); + } + SQLExpr value = generateTenantValue(visitor, alias, StatementType.UPDATE, matchTableName); + + SQLUpdateSetItem updateSetItem = new SQLUpdateSetItem(); + updateSetItem.setColumn(item); + updateSetItem.setValue(value); + + x.getItems().add(updateSetItem); + visitor.setSqlModified(true); + } + + private static void checkInsertForMultiTenant(WallVisitor visitor, SQLInsertInto x) { + TenantCallBack tenantCallBack = visitor.getConfig().getTenantCallBack(); + String tenantTablePattern = visitor.getConfig().getTenantTablePattern(); + if (tenantCallBack == null && (tenantTablePattern == null || tenantTablePattern.length() == 0)) { + return; + } + + if (x == null) { + throw new IllegalStateException("x is null"); + } + + SQLExprTableSource tableSource = x.getTableSource(); + String alias = null; + String matchTableName = null; + String tenantColumn = null; + SQLExpr tableExpr = tableSource.getExpr(); + if (tableExpr instanceof SQLIdentifierExpr) { + String tableName = ((SQLIdentifierExpr) tableExpr).getName(); + + if (tenantCallBack != null) { + tenantColumn = tenantCallBack.getTenantColumn(StatementType.INSERT, tableName); + } + if (StringUtils.isEmpty(tenantColumn) + && ServletPathMatcher.getInstance().matches(tenantTablePattern, tableName)) { + tenantColumn = visitor.getConfig().getTenantColumn(); + } + + if (!StringUtils.isEmpty(tenantColumn)) { + matchTableName = tableName; + alias = tableSource.getAlias(); + } + } + + if (matchTableName == null) { + return; + } + + SQLExpr item = null; + if (alias != null) { + item = new SQLPropertyExpr(new SQLIdentifierExpr(alias), tenantColumn); + } else { + item = new SQLIdentifierExpr(tenantColumn); + } + SQLExpr value = generateTenantValue(visitor, alias, StatementType.INSERT, matchTableName); + + // add insert item and value + x.getColumns().add(item); + + List valuesClauses = null; + ValuesClause valuesClause = null; + if (x instanceof MySqlInsertStatement) { + valuesClauses = ((MySqlInsertStatement) x).getValuesList(); + } else if (x instanceof SQLServerInsertStatement) { + valuesClauses = ((MySqlInsertStatement) x).getValuesList(); + } else { + valuesClause = x.getValues(); + } + + if (valuesClauses != null && valuesClauses.size() > 0) { + for (ValuesClause clause : valuesClauses) { + clause.addValue(value); + } + } + if (valuesClause != null) { + valuesClause.addValue(value); + } + + // insert .. select + SQLSelect select = x.getQuery(); + if (select != null) { + List queryBlocks = splitSQLSelectQuery(select.getQuery()); + for (SQLSelectQueryBlock queryBlock : queryBlocks) { + queryBlock.getSelectList().add(new SQLSelectItem(value)); + } + } + + visitor.setSqlModified(true); + } + + private static List splitSQLSelectQuery(SQLSelectQuery x) { + List groupList = new ArrayList(); + Stack stack = new Stack(); + + stack.push(x); + do { + SQLSelectQuery query = stack.pop(); + if (query instanceof SQLSelectQueryBlock) { + groupList.add((SQLSelectQueryBlock) query); + } else if (query instanceof SQLUnionQuery) { + SQLUnionQuery unionQuery = (SQLUnionQuery) query; + stack.push(unionQuery.getLeft()); + stack.push(unionQuery.getRight()); + } + } while (!stack.empty()); + return groupList; + } + + @Deprecated public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x, SQLObject parent) { String tenantTablePattern = visitor.getConfig().getTenantTablePattern(); if (tenantTablePattern == null || tenantTablePattern.length() == 0) { @@ -389,12 +710,16 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x, String alias = null; SQLTableSource tableSource; + StatementType statementType = null; if (parent instanceof SQLDeleteStatement) { tableSource = ((SQLDeleteStatement) parent).getTableSource(); + statementType = StatementType.DELETE; } else if (parent instanceof SQLUpdateStatement) { tableSource = ((SQLUpdateStatement) parent).getTableSource(); + statementType = StatementType.UPDATE; } else if (parent instanceof SQLSelectQueryBlock) { tableSource = ((SQLSelectQueryBlock) parent).getFrom(); + statementType = StatementType.SELECT; } else { throw new IllegalStateException("not support parent : " + parent.getClass()); } @@ -423,9 +748,9 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x, } } - checkJoinConditionForMultiTenant(visitor, join, false); + checkJoinConditionForMultiTenant(visitor, join, false, statementType); } else { - checkJoinConditionForMultiTenant(visitor, join, true); + checkJoinConditionForMultiTenant(visitor, join, true, statementType); } } @@ -433,7 +758,7 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x, return; } - SQLBinaryOpExpr tenantCondition = cretateTenantCondition(visitor, alias); + SQLBinaryOpExpr tenantCondition = createTenantCondition(visitor, alias, statementType, matchTableName); SQLExpr condition; if (x == null) { @@ -457,7 +782,9 @@ public static void checkConditionForMultiTenant(WallVisitor visitor, SQLExpr x, } } - public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoinTableSource join, boolean checkLeft) { + @Deprecated + public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoinTableSource join, + boolean checkLeft, StatementType statementType) { String tenantTablePattern = visitor.getConfig().getTenantTablePattern(); if (tenantTablePattern == null || tenantTablePattern.length() == 0) { return; @@ -476,7 +803,7 @@ public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoin if (alias == null) { alias = tableName; } - SQLBinaryOpExpr tenantCondition = cretateTenantCondition(visitor, alias); + SQLBinaryOpExpr tenantCondition = createTenantCondition(visitor, alias, statementType, tableName); if (condition == null) { condition = tenantCondition; @@ -493,25 +820,39 @@ public static void checkJoinConditionForMultiTenant(WallVisitor visitor, SQLJoin } } - private static SQLBinaryOpExpr cretateTenantCondition(WallVisitor visitor, String alias) { + @Deprecated + private static SQLBinaryOpExpr createTenantCondition(WallVisitor visitor, String alias, + StatementType statementType, String tableName) { SQLExpr left, right; if (alias != null) { left = new SQLPropertyExpr(new SQLIdentifierExpr(alias), visitor.getConfig().getTenantColumn()); } else { left = new SQLIdentifierExpr(visitor.getConfig().getTenantColumn()); } + right = generateTenantValue(visitor, alias, statementType, tableName); + + SQLBinaryOpExpr tenantCondition = new SQLBinaryOpExpr(left, SQLBinaryOperator.Equality, right); + return tenantCondition; + } + + private static SQLExpr generateTenantValue(WallVisitor visitor, String alias, StatementType statementType, + String tableName) { + SQLExpr value; + TenantCallBack callBack = visitor.getConfig().getTenantCallBack(); + if (callBack != null) { + WallProvider.setTenantValue(callBack.getTenantValue(statementType, tableName)); + } Object tenantValue = WallProvider.getTenantValue(); if (tenantValue instanceof Number) { - right = new SQLNumberExpr((Number) tenantValue); + value = new SQLNumberExpr((Number) tenantValue); } else if (tenantValue instanceof String) { - right = new SQLCharExpr((String) tenantValue); + value = new SQLCharExpr((String) tenantValue); } else { throw new IllegalStateException("tenant value not support type " + tenantValue); } - SQLBinaryOpExpr tenantCondition = new SQLBinaryOpExpr(left, SQLBinaryOperator.Equality, right); - return tenantCondition; + return value; } public static void checkReadOnly(WallVisitor visitor, SQLTableSource tableSource) { @@ -573,7 +914,7 @@ public static void checkUpdate(WallVisitor visitor, SQLUpdateStatement x) { } } - checkConditionForMultiTenant(visitor, where, x); + checkUpdateForMultiTenant(visitor, x); } public static Object getValue(WallVisitor visitor, SQLBinaryOpExpr x) { diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java index b368d1a0b5..e684e2fc13 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest.java @@ -41,7 +41,7 @@ public void testWall() throws Exception { Assert.assertFalse(WallUtils.isValidateMySql("SELECT *FROM T UNION select 1 from mysql.user")); Assert.assertTrue(WallUtils.isValidateMySql("select 'mysql.user'")); - Assert.assertFalse(WallUtils.isValidateMySql("select 0x3C3F706870206576616C28245F504F53545B2763275D293F3E into outfile '\\www\\edu\\1.php'")); + Assert.assertFalse(WallUtils.isValidateMySql("select * FROM T WHERE id = 1 AND select 0x3C3F706870206576616C28245F504F53545B2763275D293F3E into outfile '\\www\\edu\\1.php'")); Assert.assertTrue(WallUtils.isValidateMySql("select 'outfile'")); Assert.assertFalse(WallUtils.isValidateMySql("select f1, f2 from t union select 1, 2")); diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java index 61510e02d7..37475e1942 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/MySqlWallTest91.java @@ -139,22 +139,22 @@ public void test_true2() throws Exception { } } - public void test_true3() throws Exception { + public void test_false9() throws Exception { WallProvider provider = initWallProvider(); { String sql = "select PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, CURRENT_TIMESTAMP() start_time from (SELECT PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, @rank := @rank + 1 AS rank FROM ( SELECT PROJECT_NAME, TABLE_NAME, ( SELECT CASE WHEN GROUP_CONCAT(COLUMN_name) LIKE 'ID,%' THEN SUBSTR( GROUP_CONCAT(COLUMN_name), 4 ) ELSE GROUP_CONCAT(COLUMN_name) END FROM Information_schema.`COLUMNS` A WHERE A.table_name = B.TABLE_NAME ORDER BY ORDINAL_POSITION ) EXPORT_COLUMNS FROM ETL_EXPORT b ORDER BY PROJECT_NAME, TABLE_NAME ) tmp, (SELECT @rank := 0) a) b WHERE rank='2';"; - Assert.assertTrue(provider.checkValid(sql)); + Assert.assertFalse(provider.checkValid(sql)); } { String sql = "select PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, CURRENT_TIMESTAMP() start_time, case when type=1 then ' where day_id = 20130101 ' when type=2 and substr('20130101','7,2')='01' then 'where month_id=201301 ' else 'where 3=5 ' end export_where_data from (SELECT PROJECT_NAME, TABLE_NAME, EXPORT_COLUMNS, type, @rank := @rank + 1 AS rank FROM ( SELECT PROJECT_NAME, TABLE_NAME, type, ( SELECT CASE WHEN GROUP_CONCAT(COLUMN_name) LIKE 'ID,%' THEN SUBSTR( GROUP_CONCAT(COLUMN_name), 4 ) ELSE GROUP_CONCAT(COLUMN_name) END FROM Information_schema.`COLUMNS` A WHERE A.table_name = concat(B.TABLE_NAME,'_','201301') ORDER BY ORDINAL_POSITION ) EXPORT_COLUMNS FROM ETL_EXPORT b where project_name in ('acc') ORDER BY PROJECT_NAME, TABLE_NAME ) tmp, (SELECT @rank := 0) a) b WHERE rank='3';"; - Assert.assertTrue(provider.checkValid(sql)); + Assert.assertFalse(provider.checkValid(sql)); } } public void test_true4() { WallProvider provider = initWallProvider(); { - String sql = "SELECT 10006,@"; + String sql = "SELECT 10006, @"; Assert.assertTrue(provider.checkValid(sql)); } } diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java index 7b0d2f5216..977d2e76c5 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantDeleteTest.java @@ -45,7 +45,6 @@ public void testMySql() throws Exception { String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); Assert.assertEquals("DELETE FROM orders" + // - "\nWHERE tenant = 123" + // - "\n\tAND FID = ?", resultSql); + "\nWHERE FID = ?", resultSql); } } diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantInsertTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantInsertTest.java new file mode 100644 index 0000000000..13fbaf0d3c --- /dev/null +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantInsertTest.java @@ -0,0 +1,159 @@ +/* + * Copyright 2013 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Created on 2013-10-17 +// $Id$ + +package com.alibaba.druid.bvt.filter.wall; + +import junit.framework.TestCase; + +import org.junit.Assert; + +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.util.JdbcConstants; +import com.alibaba.druid.wall.WallCheckResult; +import com.alibaba.druid.wall.WallConfig; +import com.alibaba.druid.wall.WallProvider; +import com.alibaba.druid.wall.spi.MySqlWallProvider; + +/** + * @author kiki + */ +public class TenantInsertTest extends TestCase { + + private WallConfig config = new WallConfig(); + private WallConfig config_callback = new WallConfig(); + + protected void setUp() throws Exception { + config.setTenantTablePattern("*"); + config.setTenantColumn("tenant"); + + config_callback.setTenantCallBack(new TenantTestCallBack()); + } + + public void testMySql3() throws Exception { + String insert_sql = "INSERT INTO orders (ID, NAME) VALUES (1, \"KIKI\")"; + String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + // + "\nVALUES (1, 'KIKI', 123)"; + { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + + { + WallProvider.setTenantValue(123); + MySqlWallProvider provider = new MySqlWallProvider(config); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + + } + + public void testMySql4() throws Exception { + String insert_sql = "INSERT INTO orders (ID, NAME) VALUES (1, \"KIKI\"), (1, \"CICI\")"; + String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + // + "\nVALUES (1, 'KIKI', 123)," + // + "\n\t(1, 'CICI', 123)"; + + { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + + { + WallProvider.setTenantValue(123); + MySqlWallProvider provider = new MySqlWallProvider(config); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + + } + + public void testMySql5() throws Exception { + String insert_sql = "INSERT INTO orders (ID, NAME) SELECT ID, NAME FROM temp WHERE age = 18"; + String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + // + "\nSELECT ID, NAME, 123" + // + "\nFROM temp" + // + "\nWHERE age = 18"; + + { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + + { + WallProvider.setTenantValue(123); + MySqlWallProvider provider = new MySqlWallProvider(config); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + } + + public void testMySql6() throws Exception { + String insert_sql = "INSERT INTO orders (ID, NAME) SELECT ID, NAME FROM temp1 WHERE age = 18 UNION SELECT ID, NAME FROM temp2 UNION ALL SELECT ID, NAME FROM temp3"; + String expect_sql = "INSERT INTO orders (ID, NAME, tenant)" + // + "\nSELECT ID, NAME, 123" + // + "\nFROM temp1" + // + "\nWHERE age = 18" + // + "\nUNION" + // + "\nSELECT ID, NAME, 123" + // + "\nFROM temp2" + // + "\nUNION ALL" + // + "\nSELECT ID, NAME, 123" + // + "\nFROM temp3"; + + { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + + { + WallProvider.setTenantValue(123); + MySqlWallProvider provider = new MySqlWallProvider(config); + WallCheckResult checkResult = provider.check(insert_sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + } + +} diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java index 0b0bc4d051..8e82976717 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest.java @@ -28,13 +28,19 @@ public class TenantSelectTest extends TestCase { - private String sql = "SELECT ID, NAME FROM orders WHERE FID = ?"; + private String sql = "SELECT ID, NAME FROM orders WHERE FID = ?"; + private String expect_sql = "SELECT ID, NAME, tenant" + // + "\nFROM orders" + // + "\nWHERE FID = ?"; - private WallConfig config = new WallConfig(); + private WallConfig config = new WallConfig(); + private WallConfig config_callback = new WallConfig(); protected void setUp() throws Exception { config.setTenantTablePattern("*"); config.setTenantColumn("tenant"); + + config_callback.setTenantCallBack(new TenantTestCallBack()); } public void testMySql() throws Exception { @@ -44,9 +50,15 @@ public void testMySql() throws Exception { Assert.assertEquals(0, checkResult.getViolations().size()); String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); - Assert.assertEquals("SELECT ID, NAME" + // - "\nFROM orders" + // - "\nWHERE tenant = 123" + // - "\n\tAND FID = ?", resultSql); + Assert.assertEquals(expect_sql, resultSql); + } + + public void testMySql2() throws Exception { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); } } diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java index 475d251040..4580d5b70d 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest2.java @@ -28,13 +28,20 @@ public class TenantSelectTest2 extends TestCase { - private String sql = "SELECT ID, NAME FROM orders WHERE FID = ? OR FID = ?"; + private String sql = "SELECT ID, NAME FROM orders WHERE FID = ? OR FID = ?"; + private String expect_sql = "SELECT ID, NAME, tenant" + // + "\nFROM orders" + // + "\nWHERE FID = ?" + // + "\n\tOR FID = ?"; - private WallConfig config = new WallConfig(); + private WallConfig config = new WallConfig(); + private WallConfig config_callback = new WallConfig(); protected void setUp() throws Exception { config.setTenantTablePattern("*"); config.setTenantColumn("tenant"); + + config_callback.setTenantCallBack(new TenantTestCallBack()); } public void testMySql() throws Exception { @@ -44,10 +51,15 @@ public void testMySql() throws Exception { Assert.assertEquals(0, checkResult.getViolations().size()); String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); - Assert.assertEquals("SELECT ID, NAME" + // - "\nFROM orders" + // - "\nWHERE tenant = 123" + // - "\n\tAND (FID = ?" + - "\n\t\tOR FID = ?)", resultSql); + Assert.assertEquals(expect_sql, resultSql); + } + + public void testMySql2() throws Exception { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); } } diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java index 412def124d..a01167bf0e 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest3.java @@ -28,15 +28,23 @@ public class TenantSelectTest3 extends TestCase { - private String sql = "SELECT ID, NAME " + // - "FROM orders o inner join users u ON o.userid = u.id " + // - "WHERE FID = ? OR FID = ?"; + private String sql = "SELECT ID, NAME " + // + "FROM orders o inner join users u ON o.userid = u.id " + // + "WHERE FID = ? OR FID = ?"; + private String expect_sql = "SELECT ID, NAME, u.tenant, o.tenant" + // + "\nFROM orders o" + // + "\n\tINNER JOIN users u ON o.userid = u.id" + // + "\nWHERE FID = ?" + // + "\n\tOR FID = ?"; - private WallConfig config = new WallConfig(); + private WallConfig config = new WallConfig(); + private WallConfig config_callback = new WallConfig(); protected void setUp() throws Exception { config.setTenantTablePattern("*"); config.setTenantColumn("tenant"); + + config_callback.setTenantCallBack(new TenantTestCallBack()); } public void testMySql() throws Exception { @@ -46,12 +54,15 @@ public void testMySql() throws Exception { Assert.assertEquals(0, checkResult.getViolations().size()); String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); - Assert.assertEquals("SELECT ID, NAME" + // - "\nFROM orders o" + // - "\n\tINNER JOIN users u ON u.tenant = 123" + // - "\n\t\tAND o.userid = u.id" + // - "\nWHERE o.tenant = 123" + // - "\n\tAND (FID = ?" + // - "\n\t\tOR FID = ?)", resultSql); + Assert.assertEquals(expect_sql, resultSql); + } + + public void testMySql2() throws Exception { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); } } diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java index df461be3d8..508ce48f52 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantSelectTest4.java @@ -28,15 +28,23 @@ public class TenantSelectTest4 extends TestCase { - private String sql = "SELECT a.*,b.name " + // - "FROM vote_info a left join vote_item b on a.item_id=b.id " + // - "where 1=1 limit 1,10"; + private String sql = "SELECT a.*,b.name " + // + "FROM vote_info a left join vote_item b on a.item_id=b.id " + // + "where 1=1 limit 1,10"; + private String expect_sql = "SELECT a.*, b.name, b.tenant, a.tenant" + // + "\nFROM vote_info a" + // + "\n\tLEFT JOIN vote_item b ON a.item_id = b.id" + // + "\nWHERE 1 = 1" + // + "\nLIMIT 1, 10"; - private WallConfig config = new WallConfig(); + private WallConfig config = new WallConfig(); + private WallConfig config_callback = new WallConfig(); protected void setUp() throws Exception { config.setTenantTablePattern("*"); config.setTenantColumn("tenant"); + + config_callback.setTenantCallBack(new TenantTestCallBack()); } public void testMySql() throws Exception { @@ -46,12 +54,15 @@ public void testMySql() throws Exception { Assert.assertEquals(0, checkResult.getViolations().size()); String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); - Assert.assertEquals("SELECT a.*, b.name" + // - "\nFROM vote_info a" + // - "\n\tLEFT JOIN vote_item b ON b.tenant = 123" + // - "\n\t\tAND a.item_id = b.id" + // - "\nWHERE a.tenant = 123" + // - "\n\tAND 1 = 1" + // - "\nLIMIT 1, 10", resultSql); + Assert.assertEquals(expect_sql, resultSql); + } + + public void testMySql2() throws Exception { + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); } } diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantTestCallBack.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantTestCallBack.java new file mode 100644 index 0000000000..7debb18557 --- /dev/null +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantTestCallBack.java @@ -0,0 +1,41 @@ +/* + * Copyright 1999-2011 Alibaba Group Holding Ltd. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.alibaba.druid.bvt.filter.wall; + +import com.alibaba.druid.wall.WallConfig.TenantCallBack; + +public class TenantTestCallBack implements TenantCallBack { + + @Override + public Object getTenantValue(StatementType statementType, String tableName) { + return 123; + } + + @Override + public String getTenantColumn(StatementType statementType, String tableName) { + return "tenant"; + } + + @Override + public String getHiddenColumn(String tableName) { + return "tenant"; + } + + @Override + public void resultset_hiddenColumn(Object value) { + System.out.println(value); + } +} diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java index 307d033d9d..02b7e04485 100644 --- a/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/TenantUpdateTest.java @@ -27,27 +27,39 @@ import com.alibaba.druid.wall.spi.MySqlWallProvider; public class TenantUpdateTest extends TestCase { - private String sql = "UPDATE T_USER SET FNAME = ? WHERE FID = ?"; - - private WallConfig config = new WallConfig(); - - protected void setUp() throws Exception { - config.setDeleteAllow(false); - config.setTenantTablePattern("*"); - config.setTenantColumn("tenant"); - } - - public void testMySql() throws Exception { - WallProvider.setTenantValue(123); - MySqlWallProvider provider = new MySqlWallProvider(config); - WallCheckResult checkResult = provider.check(sql); - Assert.assertEquals(0, checkResult.getViolations().size()); - - String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), - JdbcConstants.MYSQL); - Assert.assertEquals("UPDATE T_USER" + // - "\nSET FNAME = ?" + // - "\nWHERE tenant = 123" + // - "\n\tAND FID = ?", resultSql); - } + + private String sql = "UPDATE T_USER SET FNAME = ? WHERE FID = ?"; + private String expect_sql = "UPDATE T_USER" + // + "\nSET FNAME = ?, tenant = 123" + // + "\nWHERE FID = ?"; + + private WallConfig config = new WallConfig(); + private WallConfig config_callback = new WallConfig(); + + protected void setUp() throws Exception { + config.setTenantTablePattern("*"); + config.setTenantColumn("tenant"); + + config_callback.setTenantCallBack(new TenantTestCallBack()); + } + + public void testMySql() throws Exception { + WallProvider.setTenantValue(123); + MySqlWallProvider provider = new MySqlWallProvider(config); + WallCheckResult checkResult = provider.check(sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } + + public void testMySql2() throws Exception { + WallProvider.setTenantValue(123); + MySqlWallProvider provider = new MySqlWallProvider(config_callback); + WallCheckResult checkResult = provider.check(sql); + Assert.assertEquals(0, checkResult.getViolations().size()); + + String resultSql = SQLUtils.toSQLString(checkResult.getStatementList(), JdbcConstants.MYSQL); + Assert.assertEquals(expect_sql, resultSql); + } } diff --git a/src/test/java/com/alibaba/druid/bvt/filter/wall/WallFilterTest3.java b/src/test/java/com/alibaba/druid/bvt/filter/wall/WallFilterTest3.java new file mode 100644 index 0000000000..94c876e1a1 --- /dev/null +++ b/src/test/java/com/alibaba/druid/bvt/filter/wall/WallFilterTest3.java @@ -0,0 +1,400 @@ +package com.alibaba.druid.bvt.filter.wall; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.LinkedList; +import java.util.List; + +import junit.framework.TestCase; + +import org.junit.Assert; + +import com.alibaba.druid.filter.Filter; +import com.alibaba.druid.pool.DruidDataSource; +import com.alibaba.druid.util.JdbcConstants; +import com.alibaba.druid.wall.WallConfig; +import com.alibaba.druid.wall.WallFilter; + +public class WallFilterTest3 extends TestCase { + + private DruidDataSource dataSource; + private WallFilter wallFilter; + + protected void setUp() throws Exception { + dataSource = new DruidDataSource(); + + dataSource.setUrl("jdbc:h2:mem:wall_test;"); + // dataSource.setFilters("wall"); + dataSource.setDbType(JdbcConstants.MARIADB); + + WallConfig config = new WallConfig(); + config.setTenantCallBack(new TenantTestCallBack()); + + wallFilter = new WallFilter(); + wallFilter.setConfig(config); + wallFilter.setDbType(JdbcConstants.MARIADB); + List filters = new LinkedList(); + filters.add(wallFilter); + dataSource.setProxyFilters(filters); + + dataSource.init(); + + } + + protected void tearDown() throws Exception { + dataSource.close(); + } + + public void test_wallFilter() throws Exception { + Assert.assertEquals(JdbcConstants.MARIADB, wallFilter.getDbType()); + Assert.assertFalse(wallFilter.isLogViolation()); + wallFilter.setLogViolation(true); + Assert.assertTrue(wallFilter.isLogViolation()); + wallFilter.setLogViolation(false); + Assert.assertFalse(wallFilter.isLogViolation()); + + Assert.assertTrue(wallFilter.isThrowException()); + wallFilter.setThrowException(false); + Assert.assertFalse(wallFilter.isThrowException()); + wallFilter.setThrowException(true); + Assert.assertTrue(wallFilter.isThrowException()); + + wallFilter.clearProviderCache(); + wallFilter.getProviderWhiteList(); + Assert.assertTrue(wallFilter.isInited()); + + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("CREATE TABLE t (FID INTEGER, FNAME VARCHAR(50), TENANT VARCHAR(32))"); + stmt.close(); + conn.close(); + } + Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getCreateCount()); + + { + Connection conn = dataSource.getConnection(); + String sql = "INSERT INTO t (FID, FNAME) VALUES (?, ?)"; + + for (int i = 0; i < 10; ++i) { + PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS); + stmt.setInt(1, i + 10); + stmt.setString(2, "a" + (i + 10)); + stmt.execute(); + stmt.close(); + } + + conn.close(); + } + Assert.assertEquals(10, wallFilter.getProvider().getTableStat("t").getInsertCount()); + Assert.assertEquals(10, wallFilter.getProvider().getTableStat("t").getInsertDataCount()); + + { + Connection conn = dataSource.getConnection(); + String sql = "INSERT INTO t (FID, FNAME) VALUES (?, ?)"; + + PreparedStatement stmt = conn.prepareStatement(sql, Statement.NO_GENERATED_KEYS); + for (int i = 0; i < 10; ++i) { + stmt.setInt(1, i + 20); + stmt.setString(2, "a" + (i + 20)); + stmt.addBatch(); + } + stmt.executeBatch(); + stmt.close(); + + conn.close(); + } + Assert.assertEquals(11, wallFilter.getProvider().getTableStat("t").getInsertCount()); + Assert.assertEquals(20, wallFilter.getProvider().getTableStat("t").getInsertDataCount()); + + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + for (int i = 0; i < 10; ++i) { + stmt.addBatch("INSERT INTO t (FID, FNAME) VALUES (" + i + ", 'a" + i + "')"); + } + stmt.executeBatch(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(21, wallFilter.getProvider().getTableStat("t").getInsertCount()); + Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + { + String sql = "SELECT * FROM T"; + + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(30, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T"; + + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY, + ResultSet.HOLD_CURSORS_OVER_COMMIT); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(60, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T LIMIT 10"; + + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(sql, new int[0]); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(70, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T LIMIT 10"; + + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement(sql, new String[0]); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(80, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T LIMIT 10"; + + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareCall(sql); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(90, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T LIMIT 10"; + + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareCall(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(100, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T"; + + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareCall(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, + ResultSet.HOLD_CURSORS_OVER_COMMIT); + ResultSet rs = stmt.executeQuery(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(130, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T LIMIT 10"; + + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + stmt.execute(sql, Statement.NO_GENERATED_KEYS); + ResultSet rs = stmt.getResultSet(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(140, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T LIMIT 10"; + + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, + ResultSet.HOLD_CURSORS_OVER_COMMIT); + stmt.execute(sql, new int[0]); + ResultSet rs = stmt.getResultSet(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(150, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + String sql = "SELECT * FROM T LIMIT 10"; + + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY, + ResultSet.HOLD_CURSORS_OVER_COMMIT); + stmt.execute(sql, new String[0]); + ResultSet rs = stmt.getResultSet(); + while (rs.next()) { + + } + rs.close(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(160, wallFilter.getProvider().getTableStat("t").getFetchRowCount()); + + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.executeUpdate("DELETE from t where FID = 0"); + stmt.close(); + conn.close(); + } + Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getDeleteDataCount()); + Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.executeUpdate("DELETE from t where FID = 1 OR FID = 2", Statement.NO_GENERATED_KEYS); + stmt.close(); + conn.close(); + } + Assert.assertEquals(3, wallFilter.getProvider().getTableStat("t").getDeleteDataCount()); + Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.executeUpdate("DELETE from t where FID = 3", new int[0]); + stmt.close(); + conn.close(); + } + Assert.assertEquals(4, wallFilter.getProvider().getTableStat("t").getDeleteDataCount()); + Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.executeUpdate("DELETE from t where FID = 4", new String[0]); + stmt.close(); + conn.close(); + } + Assert.assertEquals(5, wallFilter.getProvider().getTableStat("t").getDeleteDataCount()); + Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + + { + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement("DELETE from t where FID = ?"); + stmt.setInt(1, 5); + stmt.executeUpdate(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(6, wallFilter.getProvider().getTableStat("t").getDeleteDataCount()); + Assert.assertEquals(0, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("update t SET fname = 'xx' where FID = 13 OR FID = 14"); + stmt.close(); + conn.close(); + } + Assert.assertEquals(2, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + + { + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ? OR FID = ?"); + stmt.setInt(1, 13); + stmt.setInt(2, 14); + stmt.execute(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(4, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + + { + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ? OR FID = ?"); + stmt.setInt(1, 13); + stmt.setInt(2, 14); + stmt.execute(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(6, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + { + Connection conn = dataSource.getConnection(); + PreparedStatement stmt = conn.prepareStatement("update t SET fname = 'xx' where FID = ?"); + + stmt.setInt(1, 13); + stmt.addBatch(); + + stmt.setInt(1, 14); + stmt.addBatch(); + + stmt.executeBatch(); + stmt.close(); + conn.close(); + } + Assert.assertEquals(8, wallFilter.getProvider().getTableStat("t").getUpdateDataCount()); + + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("truncate table t"); + stmt.close(); + conn.close(); + } + Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getTruncateCount()); + { + Connection conn = dataSource.getConnection(); + Statement stmt = conn.createStatement(); + stmt.execute("drop table t"); + stmt.close(); + conn.close(); + } + Assert.assertEquals(1, wallFilter.getProvider().getTableStat("t").getDropCount()); + + Assert.assertEquals(0, wallFilter.getViolationCount()); + wallFilter.resetViolationCount(); + wallFilter.checkValid("select 1"); + Assert.assertEquals(0, wallFilter.getViolationCount()); + } +}