Skip to content

Commit

Permalink
Refactor and simplify insert optimized engine (#3037)
Browse files Browse the repository at this point in the history
* for #2567, StandardRoutingEngine.reviseInsertValue() => setDataNode()

* for #2567, change type of InsertValue.dataNodes

* for #2567, move dataNodes from InsertValue to ShardingCondition

* for #2567, change InsertParameterBuilder.insertParameterUnits from List to Collection

* for #2567, remove getter for InsertParameterBuilder.insertParameterUnits

* for #2567, InsertParameterUnit => InsertParameterGroup

* for #2567, InsertParameterUnit => InsertParameterGroup

* for #2567, move InsertParameterBuilder from org.apache.shardingsphere.core.rewrite.builder to org.apache.shardingsphere.core.rewrite.builder.insert

* for #2567, rename InsertParameterGroup to ValueParametersGroup

* for #2567, refactor ValueParametersGroup class from public to default

* for #2567, add package of rewrite.builder.sql and rewrite.builder.parameter

* for #2567, add package of rewrite.builder.parameter.single and rewrite.builder.parameter.group

* for #2567, rename InsertParameterBuilder to GroupParameterBuilder

* for #2567, generic GroupParameterBuilder

* for #2567, rename BaseParameterBuilder to StandardParameterBuilder

* for #2567, rename GroupParameterBuilder to GroupedParameterBuilder

* for #2567, remove derived encrypt column for sharding condition

* for #2567, add SQLRewriteEngine.encryptInsertValues()

* for #2567, inline SQLRewriteEngine.encryptOptimizedStatement()

* for #2567, refactor ShardingInsertOptimizeEngine.getAllColumnNames()' return type

* for #2567, refactor generated key's rewrite logic, use isGenerated instead of generateKey column judge whether to rewrite

* for #2567, add ShardingInsertOptimizeEngine.getInsertValues()

* for #2567, use statement and metadata to find generated key

* for #2567, add ShardingInsertOptimizeEngine.getColumnNames()

* for #2567, refactor GeneratedKey

* for #2567, refactor ShardingInsertOptimizedStatement's constructor

* for #2567, refactor EncryptInsertOptimizedStatement's constructor
  • Loading branch information
terrymanu authored Sep 15, 2019
1 parent 7761f5d commit 6d2a110
Show file tree
Hide file tree
Showing 50 changed files with 398 additions and 364 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public TableMetaData(final Collection<ColumnMetaData> columnMetaDataList, final
private Map<String, ColumnMetaData> getColumns(final Collection<ColumnMetaData> columnMetaDataList) {
Map<String, ColumnMetaData> columns = new LinkedHashMap<>(columnMetaDataList.size(), 1);
for (ColumnMetaData each : columnMetaDataList) {
columns.put(each.getName(), each);
columns.put(each.getName().toLowerCase(), each);
}
return Collections.synchronizedMap(columns);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.shardingsphere.core.merge.dal.DALMergeEngine;
import org.apache.shardingsphere.core.merge.dql.DQLMergeEngine;
import org.apache.shardingsphere.core.merge.fixture.TestQueryResult;
import org.apache.shardingsphere.core.optimize.api.segment.InsertValue;
import org.apache.shardingsphere.core.optimize.encrypt.condition.EncryptCondition;
import org.apache.shardingsphere.core.optimize.encrypt.statement.EncryptTransparentOptimizedStatement;
import org.apache.shardingsphere.core.optimize.sharding.segment.condition.ShardingCondition;
Expand Down Expand Up @@ -89,8 +90,9 @@ public void assertNewInstanceWithDALStatement() throws SQLException {

@Test
public void assertNewInstanceWithOtherStatement() throws SQLException {
SQLRouteResult routeResult = new SQLRouteResult(new ShardingInsertOptimizedStatement(new InsertStatement(),
Collections.<ShardingCondition>emptyList(), Collections.<String>emptyList(), null), new EncryptTransparentOptimizedStatement(new InsertStatement()));
SQLRouteResult routeResult = new SQLRouteResult(
new ShardingInsertOptimizedStatement(new InsertStatement(), Collections.<ShardingCondition>emptyList(), Collections.<String>emptyList(), null, Collections.<InsertValue>emptyList()),
new EncryptTransparentOptimizedStatement(new InsertStatement()));
assertThat(MergeEngineFactory.newInstance(DatabaseTypes.getActualDatabaseType("MySQL"), null, routeResult, null, queryResults), instanceOf(TransparentMergeEngine.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.core.rule.DataNode;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
Expand All @@ -40,7 +38,7 @@
* @author zhangliang
*/
@Getter
@ToString(exclude = "dataNodes")
@ToString
public final class InsertValue {

private final int parametersCount;
Expand All @@ -49,8 +47,6 @@ public final class InsertValue {

private final List<Object> parameters;

private final List<DataNode> dataNodes = new LinkedList<>();

public InsertValue(final Collection<ExpressionSegment> assignments, final int derivedColumnsCount, final List<Object> parameters, final int parametersOffset) {
parametersCount = calculateParametersCount(assignments);
valueExpressions = getValueExpressions(assignments, derivedColumnsCount);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.shardingsphere.core.rule.EncryptRule;

import java.util.Collection;
import java.util.LinkedList;
import java.util.List;

/**
Expand All @@ -37,13 +38,20 @@ public final class EncryptInsertOptimizeEngine implements EncryptOptimizeEngine<

@Override
public EncryptInsertOptimizedStatement optimize(final EncryptRule encryptRule, final TableMetas tableMetas, final String sql, final List<Object> parameters, final InsertStatement sqlStatement) {
String tableName = sqlStatement.getTable().getTableName();
EncryptInsertOptimizedStatement result = new EncryptInsertOptimizedStatement(sqlStatement, tableMetas);
int derivedColumnsCount = encryptRule.getAssistedQueryAndPlainColumns(tableName).size();
int derivedColumnsCount = getDerivedColumnsCount(encryptRule, sqlStatement.getTable().getTableName());
return new EncryptInsertOptimizedStatement(sqlStatement, tableMetas, getInsertValues(parameters, sqlStatement, derivedColumnsCount));
}

private int getDerivedColumnsCount(final EncryptRule encryptRule, final String tableName) {
return encryptRule.getAssistedQueryAndPlainColumns(tableName).size();
}

private List<InsertValue> getInsertValues(final List<Object> parameters, final InsertStatement sqlStatement, final int derivedColumnsCount) {
List<InsertValue> result = new LinkedList<>();
int parametersOffset = 0;
for (Collection<ExpressionSegment> each : sqlStatement.getAllValueExpressions()) {
InsertValue insertValue = new InsertValue(each, derivedColumnsCount, parameters, parametersOffset);
result.getInsertValues().add(insertValue);
result.add(insertValue);
parametersOffset += insertValue.getParametersCount();
}
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.shardingsphere.core.parse.sql.statement.SQLStatement;
import org.apache.shardingsphere.core.parse.sql.statement.dml.InsertStatement;

import java.util.LinkedList;
import java.util.List;

/**
Expand All @@ -46,12 +45,13 @@ public final class EncryptInsertOptimizedStatement implements InsertOptimizedSta

private final List<String> columnNames;

private final List<InsertValue> insertValues = new LinkedList<>();
private final List<InsertValue> insertValues;

public EncryptInsertOptimizedStatement(final InsertStatement sqlStatement, final TableMetas tableMetas) {
public EncryptInsertOptimizedStatement(final InsertStatement sqlStatement, final TableMetas tableMetas, final List<InsertValue> insertValues) {
this.sqlStatement = sqlStatement;
tables = new Tables(sqlStatement);
columnNames = sqlStatement.useDefaultColumns() ? tableMetas.getAllColumnNames(sqlStatement.getTable().getTableName()) : sqlStatement.getColumnNames();
this.insertValues = insertValues;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,41 +46,15 @@
public final class ShardingInsertOptimizeEngine implements ShardingOptimizeEngine<InsertStatement> {

@Override
public ShardingInsertOptimizedStatement optimize(final ShardingRule shardingRule,
public ShardingInsertOptimizedStatement optimize(final ShardingRule shardingRule,
final TableMetas tableMetas, final String sql, final List<Object> parameters, final InsertStatement sqlStatement) {
String tableName = sqlStatement.getTable().getTableName();
List<String> columnNames = sqlStatement.useDefaultColumns() ? tableMetas.getAllColumnNames(tableName) : sqlStatement.getColumnNames();
Optional<GeneratedKey> generatedKey = GeneratedKey.getGenerateKey(shardingRule, parameters, sqlStatement, columnNames);
boolean isGeneratedValue = generatedKey.isPresent() && generatedKey.get().isGenerated();
if (isGeneratedValue) {
columnNames.remove(generatedKey.get().getColumnName());
}
List<String> allColumnNames = getAllColumnNames(columnNames, generatedKey.orNull(), shardingRule.getEncryptRule().getAssistedQueryAndPlainColumns(tableName));
List<ShardingCondition> shardingConditions = new InsertClauseShardingConditionEngine(shardingRule).createShardingConditions(sqlStatement, parameters, allColumnNames, generatedKey.orNull());
ShardingInsertOptimizedStatement result = new ShardingInsertOptimizedStatement(sqlStatement, shardingConditions, columnNames, generatedKey.orNull());
checkDuplicateKeyForShardingKey(shardingRule, sqlStatement, tableName);
int derivedColumnsCount = getDerivedColumnsCount(shardingRule, tableName, isGeneratedValue);
int parametersOffset = 0;
for (Collection<ExpressionSegment> each : sqlStatement.getAllValueExpressions()) {
InsertValue insertValue = new InsertValue(each, derivedColumnsCount, parameters, parametersOffset);
result.getInsertValues().add(insertValue);
parametersOffset += insertValue.getParametersCount();
}
return result;
}

private int getDerivedColumnsCount(final ShardingRule shardingRule, final String tableName, final boolean isGeneratedValue) {
int encryptDerivedColumnsCount = shardingRule.getEncryptRule().getAssistedQueryAndPlainColumns(tableName).size();
return isGeneratedValue ? encryptDerivedColumnsCount + 1 : encryptDerivedColumnsCount;
}

private List<String> getAllColumnNames(final Collection<String> columnNames, final GeneratedKey generatedKey, final Collection<String> derivedColumnNames) {
List<String> result = new LinkedList<>(columnNames);
if (null != generatedKey && generatedKey.isGenerated()) {
result.add(generatedKey.getColumnName());
}
result.addAll(derivedColumnNames);
return result;
checkDuplicateKeyForShardingKey(shardingRule, sqlStatement, sqlStatement.getTable().getTableName());
Optional<GeneratedKey> generatedKey = GeneratedKey.getGenerateKey(shardingRule, tableMetas, parameters, sqlStatement);
List<String> columnNames = getColumnNames(tableMetas, sqlStatement, generatedKey.orNull());
List<ShardingCondition> shardingConditions = new InsertClauseShardingConditionEngine(shardingRule).createShardingConditions(sqlStatement, parameters, columnNames, generatedKey.orNull());
int derivedColumnsCount = getDerivedColumnsCount(shardingRule, sqlStatement.getTable().getTableName(), generatedKey.isPresent() && generatedKey.get().isGenerated());
List<InsertValue> insertValues = getInsertValues(parameters, sqlStatement, derivedColumnsCount);
return new ShardingInsertOptimizedStatement(sqlStatement, shardingConditions, columnNames, generatedKey.orNull(), insertValues);
}

private void checkDuplicateKeyForShardingKey(final ShardingRule shardingRule, final InsertStatement sqlStatement, final String tableName) {
Expand All @@ -98,4 +72,29 @@ private boolean isUpdateShardingKey(final ShardingRule shardingRule, final OnDup
}
return false;
}

private List<String> getColumnNames(final TableMetas tableMetas, final InsertStatement sqlStatement, final GeneratedKey generatedKey) {
List<String> result = sqlStatement.useDefaultColumns() ? tableMetas.getAllColumnNames(sqlStatement.getTable().getTableName()) : sqlStatement.getColumnNames();
if (null != generatedKey && generatedKey.isGenerated()) {
result.remove(generatedKey.getColumnName());
result.add(generatedKey.getColumnName());
}
return result;
}

private int getDerivedColumnsCount(final ShardingRule shardingRule, final String tableName, final boolean isGeneratedValue) {
int encryptDerivedColumnsCount = shardingRule.getEncryptRule().getAssistedQueryAndPlainColumns(tableName).size();
return isGeneratedValue ? encryptDerivedColumnsCount + 1 : encryptDerivedColumnsCount;
}

private List<InsertValue> getInsertValues(final List<Object> parameters, final InsertStatement sqlStatement, final int derivedColumnsCount) {
List<InsertValue> result = new LinkedList<>();
int parametersOffset = 0;
for (Collection<ExpressionSegment> each : sqlStatement.getAllValueExpressions()) {
InsertValue insertValue = new InsertValue(each, derivedColumnsCount, parameters, parametersOffset);
result.add(insertValue);
parametersOffset += insertValue.getParametersCount();
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

import lombok.Getter;
import lombok.ToString;
import org.apache.shardingsphere.core.rule.DataNode;
import org.apache.shardingsphere.core.strategy.route.value.RouteValue;

import java.util.Collection;
import java.util.LinkedList;
import java.util.List;

Expand All @@ -35,4 +37,6 @@
public class ShardingCondition {

private final List<RouteValue> routeValues = new LinkedList<>();

private final Collection<DataNode> dataNodes = new LinkedList<>();
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.ToString;
import org.apache.shardingsphere.core.metadata.table.TableMetas;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.ExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.LiteralExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
import org.apache.shardingsphere.core.parse.sql.statement.dml.InsertStatement;
import org.apache.shardingsphere.core.rule.ShardingRule;

import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

Expand All @@ -54,60 +54,49 @@ public final class GeneratedKey {
* Get generate key.
*
* @param shardingRule sharding rule
* @param tableMetas table metas
* @param parameters SQL parameters
* @param insertStatement insert statement
* @param columnNames column names
* @return generate key
*/
public static Optional<GeneratedKey> getGenerateKey(final ShardingRule shardingRule,
final List<Object> parameters, final InsertStatement insertStatement, final Collection<String> columnNames) {
public static Optional<GeneratedKey> getGenerateKey(final ShardingRule shardingRule, final TableMetas tableMetas, final List<Object> parameters, final InsertStatement insertStatement) {
Optional<String> generateKeyColumnName = shardingRule.findGenerateKeyColumnName(insertStatement.getTable().getTableName());
if (!generateKeyColumnName.isPresent()) {
return Optional.absent();
}
return containsGenerateKey(columnNames, insertStatement.getValueCountForPerGroup(), generateKeyColumnName.get())
? findGeneratedKey(parameters, insertStatement, columnNames, generateKeyColumnName.get())
: Optional.of(createGeneratedKey(shardingRule, insertStatement, generateKeyColumnName.get()));
return Optional.of(containsGenerateKey(tableMetas, insertStatement, generateKeyColumnName.get())
? findGeneratedKey(tableMetas, parameters, insertStatement, generateKeyColumnName.get()) : createGeneratedKey(shardingRule, insertStatement, generateKeyColumnName.get()));
}

private static boolean containsGenerateKey(final Collection<String> columnNames, final int valueCountForPerGroup, final String generateKeyColumnName) {
return columnNames.contains(generateKeyColumnName) && columnNames.size() == valueCountForPerGroup;
private static boolean containsGenerateKey(final TableMetas tableMetas, final InsertStatement insertStatement, final String generateKeyColumnName) {
return insertStatement.getColumnNames().isEmpty()
? tableMetas.getAllColumnNames(insertStatement.getTable().getTableName()).size() == insertStatement.getValueCountForPerGroup()
: insertStatement.getColumnNames().contains(generateKeyColumnName);
}

private static Optional<GeneratedKey> findGeneratedKey(
final List<Object> parameters, final InsertStatement insertStatement, final Collection<String> columnNames, final String generateKeyColumnName) {
GeneratedKey result = null;
for (ExpressionSegment each : findGenerateKeyExpressions(insertStatement, columnNames, generateKeyColumnName)) {
if (null == result) {
result = new GeneratedKey(generateKeyColumnName, false);
}
private static GeneratedKey findGeneratedKey(final TableMetas tableMetas, final List<Object> parameters, final InsertStatement insertStatement, final String generateKeyColumnName) {
GeneratedKey result = new GeneratedKey(generateKeyColumnName, false);
for (ExpressionSegment each : findGenerateKeyExpressions(tableMetas, insertStatement, generateKeyColumnName)) {
if (each instanceof ParameterMarkerExpressionSegment) {
result.getGeneratedValues().add((Comparable<?>) parameters.get(((ParameterMarkerExpressionSegment) each).getParameterMarkerIndex()));
} else if (each instanceof LiteralExpressionSegment) {
result.getGeneratedValues().add((Comparable<?>) ((LiteralExpressionSegment) each).getLiterals());
}
}
return Optional.fromNullable(result);
return result;
}

private static Collection<ExpressionSegment> findGenerateKeyExpressions(final InsertStatement insertStatement, final Collection<String> columnNames, final String generateKeyColumnName) {
private static Collection<ExpressionSegment> findGenerateKeyExpressions(final TableMetas tableMetas, final InsertStatement insertStatement, final String generateKeyColumnName) {
Collection<ExpressionSegment> result = new LinkedList<>();
for (Collection<ExpressionSegment> each : insertStatement.getAllValueExpressions()) {
Optional<ExpressionSegment> generateKeyExpression = findGenerateKeyExpression(columnNames.iterator(), generateKeyColumnName, each);
if (generateKeyExpression.isPresent()) {
result.add(generateKeyExpression.get());
}
for (List<ExpressionSegment> each : insertStatement.getAllValueExpressions()) {
result.add(each.get(findGenerateKeyIndex(tableMetas, insertStatement, generateKeyColumnName.toLowerCase())));
}
return result;
}

private static Optional<ExpressionSegment> findGenerateKeyExpression(final Iterator<String> columnNames, final String generateKeyColumnName, final Collection<ExpressionSegment> expressions) {
for (ExpressionSegment each : expressions) {
if (generateKeyColumnName.equalsIgnoreCase(columnNames.next())) {
return Optional.of(each);
}
}
return Optional.absent();
private static int findGenerateKeyIndex(final TableMetas tableMetas, final InsertStatement insertStatement, final String generateKeyColumnName) {
return insertStatement.getColumnNames().isEmpty()
? tableMetas.getAllColumnNames(insertStatement.getTable().getTableName()).indexOf(generateKeyColumnName) : insertStatement.getColumnNames().indexOf(generateKeyColumnName);
}

private static GeneratedKey createGeneratedKey(final ShardingRule shardingRule, final InsertStatement insertStatement, final String generateKeyColumnName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.apache.shardingsphere.core.parse.sql.statement.SQLStatement;

import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
Expand All @@ -49,13 +48,15 @@ public final class ShardingInsertOptimizedStatement extends ShardingConditionOpt

private final GeneratedKey generatedKey;

private final List<InsertValue> insertValues = new LinkedList<>();
private final List<InsertValue> insertValues;

public ShardingInsertOptimizedStatement(final SQLStatement sqlStatement, final List<ShardingCondition> shardingConditions, final List<String> columnNames, final GeneratedKey generatedKey) {
public ShardingInsertOptimizedStatement(final SQLStatement sqlStatement, final List<ShardingCondition> shardingConditions,
final List<String> columnNames, final GeneratedKey generatedKey, final List<InsertValue> insertValues) {
super(sqlStatement, new ShardingConditions(shardingConditions), new EncryptConditions(Collections.<EncryptCondition>emptyList()));
tables = new Tables(sqlStatement);
this.columnNames = columnNames;
this.generatedKey = generatedKey;
this.insertValues = insertValues;
}

/**
Expand Down
Loading

0 comments on commit 6d2a110

Please sign in to comment.