Skip to content

Commit

Permalink
Merge branch 'StormAll-210'
Browse files Browse the repository at this point in the history
fixed #210 Routing to single table reomve derived SQL
  • Loading branch information
gaohongtao committed Dec 20, 2016
2 parents ab78772 + 255d422 commit f0e78f4
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
package com.dangdang.ddframe.rdb.sharding.parser.result.router;

import com.google.common.base.Joiner;
import lombok.AccessLevel;
import lombok.Getter;

import java.io.IOException;
import java.util.Collection;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
Expand All @@ -35,6 +35,8 @@
*/
public class SQLBuilder implements Appendable {

private final List<SQLBuilder> derivedSQLBuilders = new ArrayList<>();

private final List<Object> segments;

private final Map<String, StringToken> tokenMap;
Expand All @@ -46,6 +48,11 @@ public class SQLBuilder implements Appendable {
@Getter
private boolean changed;

@Getter(AccessLevel.PRIVATE)
private boolean removeDerivedSQLToken;

private boolean hasExistedDerivedSQLToken;

public SQLBuilder() {
segments = new LinkedList<>();
tokenMap = new HashMap<>();
Expand Down Expand Up @@ -83,24 +90,41 @@ public void appendToken(final String label, final String token) {
stringToken.label = label;
stringToken.value = token;
tokenMap.put(label, stringToken);
stringToken.listeners.add(this);
}
stringToken.indices.add(segments.size());
segments.add(stringToken);
currentSegment = new StringBuilder();
segments.add(currentSegment);
}

/**
* 用实际的值替代占位符.
*
* @param label 占位符
* @param token 实际的值
*/
public void buildSQL(final String label, final String token) {
if (tokenMap.containsKey(label)) {
tokenMap.get(label).setValue(token);
buildSQL(label, token, false);
}

/**
* 用实际的值替代占位符,并可以标记该SQL是否为派生SQL.
*
* @param label 占位符
* @param token 实际的值
* @param isDerived 是否是派生的SQL
*/
public void buildSQL(final String label, final String token, final boolean isDerived) {
if (!tokenMap.containsKey(label)) {
return;
}
if (isDerived) {
hasExistedDerivedSQLToken = true;
}
StringToken labelSQL = tokenMap.get(label);
labelSQL.isDerived = isDerived;
labelSQL.value = token;
changeState();
}

/**
Expand Down Expand Up @@ -134,9 +158,7 @@ public SQLBuilder buildSQLWithNewToken() {
result.segments.set(index, each);
}
}
for (StringToken each : result.tokenMap.values()) {
each.listeners.add(result);
}
derivedSQLBuilders.add(result);
newTokenList.clear();
return result;
}
Expand Down Expand Up @@ -173,15 +195,28 @@ public Appendable append(final char c) throws IOException {
changeState();
return this;
}

private void changeState() {
changed = true;
for (SQLBuilder each : derivedSQLBuilders) {
each.changeState();
}
}

private void clearState() {
changed = false;
}

/**
* 移除衍生的SQL片段.
*/
public void removeDerivedSQL() {
if (hasExistedDerivedSQLToken) {
removeDerivedSQLToken = true;
changeState();
}
}

@Override
public String toString() {
StringBuilder result = new StringBuilder();
Expand All @@ -200,29 +235,26 @@ private class StringToken {
private String label;

private String value;

private boolean isDerived;

private final List<Integer> indices = new LinkedList<>();

private final Collection<SQLBuilder> listeners = new HashSet<>();

public void setValue(final String value) {
this.value = value;
for (SQLBuilder each : listeners) {
each.changeState();
}
}

String toToken() {
if (null == value) {
if (isEmptyValueOutput()) {
return "";
}
Joiner joiner = Joiner.on("");
return label.equals(value) ? joiner.join("[Token(", value, ")]") : joiner.join("[", label, "(", value, ")]");
}

private boolean isEmptyValueOutput() {
return null == value || isDerived && isRemoveDerivedSQLToken();
}

@Override
public String toString() {
return null == value ? "" : value;
return isEmptyValueOutput() ? "" : value;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public void endVisit(final MySqlSelectQueryBlock x) {
appendSortableColumn(derivedSelectItems, getParseContext().getParsedResult().getMergeContext().getGroupByColumns());
appendSortableColumn(derivedSelectItems, getParseContext().getParsedResult().getMergeContext().getOrderByColumns());
if (0 != derivedSelectItems.length()) {
getSQLBuilder().buildSQL(getParseContext().getAutoGenTokenKey(), derivedSelectItems.toString());
getSQLBuilder().buildSQL(getParseContext().getAutoGenTokenKey(), derivedSelectItems.toString(), true);
}
super.endVisit(x);
stepOutQuery();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.dangdang.ddframe.rdb.sharding.parser.result.SQLParsedResult;
import com.dangdang.ddframe.rdb.sharding.parser.result.merger.Limit;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.ConditionContext;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.SQLBuilder;
import com.dangdang.ddframe.rdb.sharding.parser.result.router.Table;
import com.dangdang.ddframe.rdb.sharding.router.binding.BindingTablesRouter;
import com.dangdang.ddframe.rdb.sharding.router.mixed.MixedTablesRouter;
Expand Down Expand Up @@ -93,12 +94,8 @@ SQLRouteResult routeSQL(final SQLParsedResult parsedResult, final List<Object> p
RoutingResult routingResult = routeSQL(each, parsedResult);
result.getExecutionUnits().addAll(routingResult.getSQLExecutionUnits(parsedResult.getRouteContext().getSqlBuilder()));
}
amendSQLAccordingToRouteResult(parsedResult, parameters, result);
MetricsContext.stop(context);
Limit limit = result.getMergeContext().getLimit();
if (null != limit) {
limit.replaceSQL(parsedResult.getRouteContext().getSqlBuilder(), result.getExecutionUnits().size() > 1);
limit.replaceParameters(parameters, result.getExecutionUnits().size() > 1);
}
log.debug("final route result is {} target", result.getExecutionUnits().size());
for (SQLExecutionUnit each : result.getExecutionUnits()) {
log.debug("{}:{} {}", each.getDataSource(), each.getSql(), parameters);
Expand All @@ -124,4 +121,17 @@ public String apply(final Table input) {
// TODO 可配置是否执行笛卡尔积
return new MixedTablesRouter(shardingRule, logicTables, conditionContext, parsedResult.getRouteContext().getSqlStatementType()).route();
}

private void amendSQLAccordingToRouteResult(final SQLParsedResult parsedResult, final List<Object> parameters, final SQLRouteResult result) {
boolean isVarious = result.getExecutionUnits().size() > 1;
Limit limit = result.getMergeContext().getLimit();
SQLBuilder sqlBuilder = parsedResult.getRouteContext().getSqlBuilder();
if (null != limit) {
limit.replaceSQL(sqlBuilder, isVarious);
limit.replaceParameters(parameters, isVarious);
}
if (!isVarious) {
sqlBuilder.removeDerivedSQL();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,14 @@
import static org.junit.Assert.assertThat;

public final class SelectSingleTableTest extends AbstractDynamicRouteSqlTest {

@Test
public void assertGroupBy() throws SQLParserException {
assertSingleTarget("select sum(qty) from order where order_id = 1 group by tenant_id", "ds_1",
"SELECT SUM(qty) FROM order_1 WHERE order_id = 1 GROUP BY tenant_id");
assertMultipleTargets("select sum(qty) from order group by tenant_id", 4, Arrays.asList("ds_0", "ds_1"),
Arrays.asList("SELECT SUM(qty), tenant_id AS sharding_gen_1 FROM order_0 GROUP BY tenant_id"));
}

@Test
public void assertSingleSelect() throws SQLParserException {
Expand Down

0 comments on commit f0e78f4

Please sign in to comment.