Skip to content

Commit

Permalink
Use Multimap and CaseInsensitiveString to replace CaseInsensitiveMap …
Browse files Browse the repository at this point in the history
…for supporting mysql multi table join with same table alias (#33303)

* Use Multimap and CaseInsensitiveString to replace CaseInsensitiveMap for supporting mysql multi table join with same table alias

* remove toLowerCase call

* Add more unit test for ColumnSegmentBinderTest

* Update release note
  • Loading branch information
strongduanmu authored Oct 18, 2024
1 parent 7034564 commit 044e049
Show file tree
Hide file tree
Showing 40 changed files with 331 additions and 241 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
### Bug Fix

1. Mode: Fixes `JDBCRepository` improper handling of H2database in memory mode - [#33281](https://github.com/apache/shardingsphere/issues/33281)
1. SQL Binder: Use Multimap and CaseInsensitiveString to replace CaseInsensitiveMap for supporting mysql multi table join with same table alias - [#33303](https://github.com/apache/shardingsphere/pull/33303)

### Change Log
1. [MILESTONE](https://github.com/apache/shardingsphere/milestone/30)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.shardingsphere.infra.binder.engine.segment.assign;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
Expand All @@ -30,7 +32,6 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
Expand All @@ -49,14 +50,15 @@ public final class AssignmentSegmentBinder {
* @return bound assignment segment
*/
public static SetAssignmentSegment bind(final SetAssignmentSegment segment, final SQLStatementBinderContext binderContext,
final Map<String, TableSegmentBinderContext> tableBinderContexts, final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
return new SetAssignmentSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getAssignments().stream()
.map(each -> bindColumnAssignmentSegment(each, binderContext, tableBinderContexts, outerTableBinderContexts)).collect(Collectors.toList()));
}

private static ColumnAssignmentSegment bindColumnAssignmentSegment(final ColumnAssignmentSegment columnAssignmentSegment, final SQLStatementBinderContext binderContext,
final Map<String, TableSegmentBinderContext> tableBinderContexts,
final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
List<ColumnSegment> boundColumns = columnAssignmentSegment.getColumns().stream()
.map(each -> ColumnSegmentBinder.bind(each, SegmentType.SET_ASSIGNMENT, binderContext, tableBinderContexts, outerTableBinderContexts)).collect(Collectors.toList());
ExpressionSegment boundValue = ExpressionSegmentBinder.bind(columnAssignmentSegment.getValue(), SegmentType.SET_ASSIGNMENT, binderContext, tableBinderContexts, outerTableBinderContexts);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.shardingsphere.infra.binder.engine.segment.column;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
Expand All @@ -27,8 +30,6 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.InsertColumnsSegment;

import java.util.Collection;
import java.util.Collections;
import java.util.Map;
import java.util.stream.Collectors;

/**
Expand All @@ -45,9 +46,10 @@ public final class InsertColumnsSegmentBinder {
* @param tableBinderContexts table binder contexts
* @return bound insert columns segment
*/
public static InsertColumnsSegment bind(final InsertColumnsSegment segment, final SQLStatementBinderContext binderContext, final Map<String, TableSegmentBinderContext> tableBinderContexts) {
public static InsertColumnsSegment bind(final InsertColumnsSegment segment, final SQLStatementBinderContext binderContext,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
Collection<ColumnSegment> boundColumns = segment.getColumns().stream()
.map(each -> ColumnSegmentBinder.bind(each, SegmentType.INSERT_COLUMNS, binderContext, tableBinderContexts, Collections.emptyMap())).collect(Collectors.toList());
.map(each -> ColumnSegmentBinder.bind(each, SegmentType.INSERT_COLUMNS, binderContext, tableBinderContexts, LinkedHashMultimap.create())).collect(Collectors.toList());
return new InsertColumnsSegment(segment.getStartIndex(), segment.getStopIndex(), boundColumns);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@

package org.apache.shardingsphere.infra.binder.engine.segment.expression;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
Expand All @@ -38,9 +41,6 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.NotExpression;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;

import java.util.LinkedHashMap;
import java.util.Map;

/**
* Expression segment binder.
*/
Expand All @@ -58,15 +58,16 @@ public final class ExpressionSegmentBinder {
* @return bound expression segment
*/
public static ExpressionSegment bind(final ExpressionSegment segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext,
final Map<String, TableSegmentBinderContext> tableBinderContexts, final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (segment instanceof BinaryOperationExpression) {
return BinaryOperationExpressionBinder.bind((BinaryOperationExpression) segment, parentSegmentType, binderContext, tableBinderContexts, outerTableBinderContexts);
}
if (segment instanceof ExistsSubqueryExpression) {
return ExistsSubqueryExpressionBinder.bind((ExistsSubqueryExpression) segment, binderContext, tableBinderContexts);
}
if (segment instanceof SubqueryExpressionSegment) {
Map<String, TableSegmentBinderContext> newOuterTableBinderContexts = new LinkedHashMap<>();
Multimap<CaseInsensitiveString, TableSegmentBinderContext> newOuterTableBinderContexts = LinkedHashMultimap.create();
newOuterTableBinderContexts.putAll(outerTableBinderContexts);
newOuterTableBinderContexts.putAll(tableBinderContexts);
return new SubqueryExpressionSegment(SubquerySegmentBinder.bind(((SubqueryExpressionSegment) segment).getSubquery(), binderContext, newOuterTableBinderContexts));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.shardingsphere.infra.binder.engine.segment.expression.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.infra.binder.engine.segment.SegmentType;
Expand All @@ -26,8 +28,6 @@
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;

import java.util.Map;

/**
* Binary operation expression binder.
*/
Expand All @@ -45,7 +45,8 @@ public final class BinaryOperationExpressionBinder {
* @return bound binary operation expression segment
*/
public static BinaryOperationExpression bind(final BinaryOperationExpression segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext,
final Map<String, TableSegmentBinderContext> tableBinderContexts, final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ExpressionSegment boundLeft = ExpressionSegmentBinder.bind(segment.getLeft(), parentSegmentType, binderContext, tableBinderContexts, outerTableBinderContexts);
ExpressionSegment boundRight = ExpressionSegmentBinder.bind(segment.getRight(), parentSegmentType, binderContext, tableBinderContexts, outerTableBinderContexts);
return new BinaryOperationExpression(segment.getStartIndex(), segment.getStopIndex(), boundLeft, boundRight, segment.getOperator(), segment.getText());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.apache.shardingsphere.infra.binder.engine.segment.expression.type;

import com.cedarsoftware.util.CaseInsensitiveMap.CaseInsensitiveString;
import com.google.common.base.Strings;
import com.google.common.collect.Multimap;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.groovy.util.Maps;
Expand Down Expand Up @@ -71,7 +73,8 @@ public final class ColumnSegmentBinder {
* @return bound column segment
*/
public static ColumnSegment bind(final ColumnSegment segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext,
final Map<String, TableSegmentBinderContext> tableBinderContexts, final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (EXCLUDE_BIND_COLUMNS.contains(segment.getIdentifier().getValue().toUpperCase())) {
return segment;
}
Expand All @@ -93,41 +96,42 @@ private static ColumnSegment copy(final ColumnSegment segment) {

private static Collection<TableSegmentBinderContext> getTableSegmentBinderContexts(final ColumnSegment segment, final SegmentType parentSegmentType,
final SQLStatementBinderContext binderContext,
final Map<String, TableSegmentBinderContext> tableBinderContexts,
final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
if (segment.getOwner().isPresent()) {
String owner = segment.getOwner().get().getIdentifier().getValue().toLowerCase();
return findTableBinderContextByOwner(owner, tableBinderContexts, outerTableBinderContexts, binderContext.getExternalTableBinderContexts())
.map(Collections::singletonList).orElse(Collections.emptyList());
String owner = segment.getOwner().get().getIdentifier().getValue();
return getTableBinderContextByOwner(owner, tableBinderContexts, outerTableBinderContexts, binderContext.getExternalTableBinderContexts());
}
if (!binderContext.getJoinTableProjectionSegments().isEmpty() && isNeedUseJoinTableProjectionBind(segment, parentSegmentType, binderContext)) {
return Collections.singleton(new SimpleTableSegmentBinderContext(binderContext.getJoinTableProjectionSegments()));
}
return tableBinderContexts.values();
}

private static Optional<TableSegmentBinderContext> findTableBinderContextByOwner(final String owner, final Map<String, TableSegmentBinderContext> tableBinderContexts,
final Map<String, TableSegmentBinderContext> outerTableBinderContexts,
final Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
if (tableBinderContexts.containsKey(owner)) {
return Optional.of(tableBinderContexts.get(owner));
private static Collection<TableSegmentBinderContext> getTableBinderContextByOwner(final String owner, final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> externalTableBinderContexts) {
CaseInsensitiveString caseInsensitiveOwner = new CaseInsensitiveString(owner);
if (tableBinderContexts.containsKey(caseInsensitiveOwner)) {
return tableBinderContexts.get(caseInsensitiveOwner);
}
if (outerTableBinderContexts.containsKey(owner)) {
return Optional.of(outerTableBinderContexts.get(owner));
if (outerTableBinderContexts.containsKey(caseInsensitiveOwner)) {
return outerTableBinderContexts.get(caseInsensitiveOwner);
}
if (externalTableBinderContexts.containsKey(owner)) {
return Optional.of(externalTableBinderContexts.get(owner));
if (externalTableBinderContexts.containsKey(caseInsensitiveOwner)) {
return externalTableBinderContexts.get(caseInsensitiveOwner);
}
return Optional.empty();
return Collections.emptyList();
}

private static boolean isNeedUseJoinTableProjectionBind(final ColumnSegment segment, final SegmentType parentSegmentType, final SQLStatementBinderContext binderContext) {
return SegmentType.PROJECTION == parentSegmentType
|| SegmentType.PREDICATE == parentSegmentType && binderContext.getUsingColumnNames().contains(segment.getIdentifier().getValue().toLowerCase());
|| SegmentType.PREDICATE == parentSegmentType && binderContext.getUsingColumnNames().contains(segment.getIdentifier().getValue());
}

private static Optional<ColumnSegment> findInputColumnSegment(final ColumnSegment segment, final SegmentType parentSegmentType, final Collection<TableSegmentBinderContext> tableBinderContexts,
final Map<String, TableSegmentBinderContext> outerTableBinderContexts, final SQLStatementBinderContext binderContext) {
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts,
final SQLStatementBinderContext binderContext) {
ColumnSegment result = null;
boolean isFindInputColumn = false;
for (TableSegmentBinderContext each : tableBinderContexts) {
Expand Down Expand Up @@ -172,13 +176,14 @@ private static Optional<ColumnSegment> findInputColumnSegmentByPivotColumns(fina
if (pivotColumnNames.isEmpty()) {
return Optional.empty();
}
if (pivotColumnNames.contains(segment.getIdentifier().getValue().toLowerCase())) {
if (pivotColumnNames.contains(segment.getIdentifier().getValue())) {
return Optional.of(new ColumnSegment(0, 0, segment.getIdentifier()));
}
return Optional.empty();
}

private static Optional<ProjectionSegment> findInputColumnSegmentFromOuterTable(final ColumnSegment segment, final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
private static Optional<ProjectionSegment> findInputColumnSegmentFromOuterTable(final ColumnSegment segment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> outerTableBinderContexts) {
ListIterator<TableSegmentBinderContext> listIterator = new ArrayList<>(outerTableBinderContexts.values()).listIterator(outerTableBinderContexts.size());
while (listIterator.hasPrevious()) {
TableSegmentBinderContext each = listIterator.previous();
Expand All @@ -190,7 +195,8 @@ private static Optional<ProjectionSegment> findInputColumnSegmentFromOuterTable(
return Optional.empty();
}

private static Optional<ProjectionSegment> findInputColumnSegmentFromExternalTables(final ColumnSegment segment, final Map<String, TableSegmentBinderContext> externalTableBinderContexts) {
private static Optional<ProjectionSegment> findInputColumnSegmentFromExternalTables(final ColumnSegment segment,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> externalTableBinderContexts) {
for (TableSegmentBinderContext each : externalTableBinderContexts.values()) {
Optional<ProjectionSegment> result = each.findProjectionSegmentByColumnLabel(segment.getIdentifier().getValue());
if (result.isPresent()) {
Expand All @@ -204,7 +210,7 @@ private static Optional<ColumnSegment> findInputColumnSegmentByVariables(final C
if (variableNames.isEmpty()) {
return Optional.empty();
}
if (variableNames.contains(segment.getIdentifier().getValue().toLowerCase())) {
if (variableNames.contains(segment.getIdentifier().getValue())) {
ColumnSegment result = new ColumnSegment(0, 0, segment.getIdentifier());
result.setVariable(true);
return Optional.of(result);
Expand Down Expand Up @@ -246,7 +252,8 @@ private static ColumnSegmentBoundInfo createColumnSegmentBoundInfo(final ColumnS
* @param tableBinderContexts table binder contexts
* @return bound using column segment
*/
public static ColumnSegment bindUsingColumn(final ColumnSegment segment, final SegmentType parentSegmentType, final Map<String, TableSegmentBinderContext> tableBinderContexts) {
public static ColumnSegment bindUsingColumn(final ColumnSegment segment, final SegmentType parentSegmentType,
final Multimap<CaseInsensitiveString, TableSegmentBinderContext> tableBinderContexts) {
ColumnSegment result = copy(segment);
List<ColumnSegment> usingInputColumnSegments = findUsingInputColumnSegments(segment.getIdentifier().getValue(), tableBinderContexts.values());
ShardingSpherePreconditions.checkState(usingInputColumnSegments.size() >= 2,
Expand Down
Loading

0 comments on commit 044e049

Please sign in to comment.