Skip to content

Commit

Permalink
Remove ConnectionContext parameter in sql router and use queryContext…
Browse files Browse the repository at this point in the history
…#getConnectonContext (#33065)

* Remove ConnectionContext parameter in sql router and use queryContext#getConnectonContext

* fix checkstyle

* disable shadow test

* fix metrics e2e error

* fix unit test
  • Loading branch information
strongduanmu authored Sep 30, 2024
1 parent d1932da commit 2221919
Show file tree
Hide file tree
Showing 31 changed files with 112 additions and 131 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public final class SQLRouteCountAdvice extends AbstractInstanceMethodAdvice {

@Override
public void beforeMethod(final TargetAdviceObject target, final TargetAdviceMethod method, final Object[] args, final String pluginType) {
QueryContext queryContext = (QueryContext) args[1];
QueryContext queryContext = (QueryContext) args[0];
SQLStatement sqlStatement = queryContext.getSqlStatementContext().getSqlStatement();
getSQLType(sqlStatement).ifPresent(optional -> MetricsCollectorRegistry.<CounterMetricsCollector>get(config, pluginType).inc(optional));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void assertSelectRoute() {
}

void assertRoute(final QueryContext queryContext, final String expected) {
advice.beforeMethod(new TargetAdviceObjectFixture(), mock(TargetAdviceMethod.class), new Object[]{new ConnectionContext(Collections::emptySet), queryContext}, "FIXTURE");
advice.beforeMethod(new TargetAdviceObjectFixture(), mock(TargetAdviceMethod.class), new Object[]{queryContext, new ConnectionContext(Collections::emptySet)}, "FIXTURE");
assertThat(MetricsCollectorRegistry.get(config, "FIXTURE").toString(), is(expected));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
Expand Down Expand Up @@ -66,15 +65,15 @@ public final class BroadcastSQLRouter implements EntranceSQLRouter<BroadcastRule

@Override
public RouteContext createRouteContext(final QueryContext queryContext, final RuleMetaData globalRuleMetaData, final ShardingSphereDatabase database,
final BroadcastRule rule, final ConfigurationProperties props, final ConnectionContext connectionContext) {
final BroadcastRule rule, final ConfigurationProperties props) {
RouteContext result = new RouteContext();
BroadcastRouteEngineFactory.newInstance(rule, database, queryContext, connectionContext).route(result, rule);
BroadcastRouteEngineFactory.newInstance(rule, database, queryContext).route(result, rule);
return result;
}

@Override
public void decorateRouteContext(final RouteContext routeContext, final QueryContext queryContext, final ShardingSphereDatabase database, final BroadcastRule broadcastRule,
final ConfigurationProperties props, final ConnectionContext connectionContext) {
final ConfigurationProperties props) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
if (sqlStatement instanceof TCLStatement) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,17 @@ public final class BroadcastRouteEngineFactory {
* @param broadcastRule broadcast rule
* @param database database
* @param queryContext query context
* @param connectionContext connection context
* @return broadcast route engine
*/
public static BroadcastRouteEngine newInstance(final BroadcastRule broadcastRule, final ShardingSphereDatabase database, final QueryContext queryContext,
final ConnectionContext connectionContext) {
public static BroadcastRouteEngine newInstance(final BroadcastRule broadcastRule, final ShardingSphereDatabase database, final QueryContext queryContext) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
SQLStatement sqlStatement = sqlStatementContext.getSqlStatement();
if (sqlStatement instanceof TCLStatement) {
return new BroadcastDatabaseBroadcastRoutingEngine();
}
if (sqlStatement instanceof DDLStatement) {
if (sqlStatementContext instanceof CursorAvailable) {
return getCursorRouteEngine(broadcastRule, sqlStatementContext, connectionContext);
return getCursorRouteEngine(broadcastRule, sqlStatementContext, queryContext.getConnectionContext());
}
return getDDLRoutingEngine(broadcastRule, database, queryContext);
}
Expand All @@ -84,7 +82,7 @@ public static BroadcastRouteEngine newInstance(final BroadcastRule broadcastRule
if (sqlStatement instanceof DCLStatement) {
return getDCLRoutingEngine(broadcastRule, queryContext);
}
return getDQLRoutingEngine(broadcastRule, queryContext, connectionContext);
return getDQLRoutingEngine(broadcastRule, queryContext);
}

private static BroadcastRouteEngine getCursorRouteEngine(final BroadcastRule broadcastRule, final SQLStatementContext sqlStatementContext, final ConnectionContext connectionContext) {
Expand Down Expand Up @@ -159,12 +157,12 @@ private static boolean isDCLForSingleTable(final SQLStatementContext sqlStatemen
return false;
}

private static BroadcastRouteEngine getDQLRoutingEngine(final BroadcastRule broadcastRule, final QueryContext queryContext, final ConnectionContext connectionContext) {
private static BroadcastRouteEngine getDQLRoutingEngine(final BroadcastRule broadcastRule, final QueryContext queryContext) {
SQLStatementContext sqlStatementContext = queryContext.getSqlStatementContext();
Collection<String> tableNames = sqlStatementContext instanceof TableAvailable ? ((TableAvailable) sqlStatementContext).getTablesContext().getTableNames() : Collections.emptyList();
if (broadcastRule.isAllBroadcastTables(tableNames)) {
return sqlStatementContext.getSqlStatement() instanceof SelectStatement
? new BroadcastUnicastRoutingEngine(sqlStatementContext, tableNames, connectionContext)
? new BroadcastUnicastRoutingEngine(sqlStatementContext, tableNames, queryContext.getConnectionContext())
: new BroadcastDatabaseBroadcastRoutingEngine();
}
return new BroadcastIgnoreRoutingEngine();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ void assertNewInstanceWithTCLStatement() {
SQLStatementContext sqlStatementContext = mock(SQLStatementContext.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(mock(TCLStatement.class));
when(queryContext.getSqlStatementContext()).thenReturn(sqlStatementContext);
BroadcastRouteEngine engine = BroadcastRouteEngineFactory.newInstance(broadcastRule, database, queryContext, connectionContext);
when(queryContext.getConnectionContext()).thenReturn(connectionContext);
BroadcastRouteEngine engine = BroadcastRouteEngineFactory.newInstance(broadcastRule, database, queryContext);
assertThat(engine, instanceOf(BroadcastDatabaseBroadcastRoutingEngine.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void assertCreateBroadcastRouteContextWithMultiDataSource() throws SQLException
when(currentConfig.getTables()).thenReturn(Collections.singleton("t_order"));
BroadcastRule broadcastRule = new BroadcastRule(currentConfig, DefaultDatabase.LOGIC_NAME, createMultiDataSourceMap(), Collections.emptyList());
RouteContext routeContext = new BroadcastSQLRouter().createRouteContext(createQueryContext(), mock(RuleMetaData.class), mockDatabaseWithMultipleResources(), broadcastRule,
new ConfigurationProperties(new Properties()), new ConnectionContext(Collections::emptySet));
new ConfigurationProperties(new Properties()));
List<RouteUnit> routeUnits = new ArrayList<>(routeContext.getRouteUnits());
assertThat(routeContext.getRouteUnits().size(), is(2));
assertThat(routeUnits.get(0).getDataSourceMapper().getLogicName(), is(routeUnits.get(0).getDataSourceMapper().getActualName()));
Expand All @@ -96,7 +96,7 @@ void assertCreateBroadcastRouteContextWithSingleDataSource() throws SQLException
broadcastRule.getAttributes().getAttribute(DataNodeRuleAttribute.class).getAllDataNodes().put("t_order", Collections.singletonList(createDataNode("tmp_ds")));
ShardingSphereDatabase database = mockSingleDatabase();
RouteContext routeContext = new BroadcastSQLRouter().createRouteContext(
createQueryContext(), mock(RuleMetaData.class), database, broadcastRule, new ConfigurationProperties(new Properties()), new ConnectionContext(Collections::emptySet));
createQueryContext(), mock(RuleMetaData.class), database, broadcastRule, new ConfigurationProperties(new Properties()));
assertThat(routeContext.getRouteUnits().size(), is(1));
RouteUnit routeUnit = routeContext.getRouteUnits().iterator().next();
assertThat(routeUnit.getDataSourceMapper().getLogicName(), is("tmp_ds"));
Expand All @@ -113,7 +113,7 @@ void assertDecorateBroadcastRouteContextWithSingleDataSource() {
routeContext.getRouteUnits().add(new RouteUnit(new RouteMapper("foo_ds", "foo_ds"), Lists.newArrayList()));
BroadcastSQLRouter sqlRouter = (BroadcastSQLRouter) OrderedSPILoader.getServices(SQLRouter.class, Collections.singleton(broadcastRule)).get(broadcastRule);
sqlRouter.decorateRouteContext(routeContext, createQueryContext(),
mockSingleDatabase(), broadcastRule, new ConfigurationProperties(new Properties()), new ConnectionContext(Collections::emptySet));
mockSingleDatabase(), broadcastRule, new ConfigurationProperties(new Properties()));
Iterator<String> routedDataSourceNames = routeContext.getActualDataSourceNames().iterator();
assertThat(routedDataSourceNames.next(), is("foo_ds"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.readwritesplitting.constant.ReadwriteSplittingOrder;
import org.apache.shardingsphere.readwritesplitting.rule.ReadwriteSplittingRule;
Expand All @@ -40,14 +39,15 @@ public final class ReadwriteSplittingSQLRouter implements DecorateSQLRouter<Read

@Override
public void decorateRouteContext(final RouteContext routeContext, final QueryContext queryContext, final ShardingSphereDatabase database,
final ReadwriteSplittingRule rule, final ConfigurationProperties props, final ConnectionContext connectionContext) {
final ReadwriteSplittingRule rule, final ConfigurationProperties props) {
Collection<RouteUnit> toBeRemoved = new LinkedList<>();
Collection<RouteUnit> toBeAdded = new LinkedList<>();
for (RouteUnit each : routeContext.getRouteUnits()) {
String logicDataSourceName = each.getDataSourceMapper().getActualName();
rule.findDataSourceGroupRule(logicDataSourceName).ifPresent(optional -> {
toBeRemoved.add(each);
String actualDataSourceName = new ReadwriteSplittingDataSourceRouter(optional, connectionContext).route(queryContext.getSqlStatementContext(), queryContext.getHintValueContext());
String actualDataSourceName =
new ReadwriteSplittingDataSourceRouter(optional, queryContext.getConnectionContext()).route(queryContext.getSqlStatementContext(), queryContext.getHintValueContext());
toBeAdded.add(new RouteUnit(new RouteMapper(logicDataSourceName, actualDataSourceName), each.getTableMappers()));
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.shardingsphere.infra.route.context.RouteMapper;
import org.apache.shardingsphere.infra.route.context.RouteUnit;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.connection.transaction.TransactionConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.infra.spi.type.ordered.OrderedSPILoader;
import org.apache.shardingsphere.readwritesplitting.config.ReadwriteSplittingRuleConfiguration;
Expand Down Expand Up @@ -90,23 +91,25 @@ void assertDecorateRouteContextToPrimaryDataSource() {
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
sqlRouter.decorateRouteContext(actual, queryContext, database, staticRule, new ConfigurationProperties(new Properties()), new ConnectionContext(Collections::emptySet));
sqlRouter.decorateRouteContext(actual, queryContext, database, staticRule, new ConfigurationProperties(new Properties()));
Iterator<String> routedDataSourceNames = actual.getActualDataSourceNames().iterator();
assertThat(routedDataSourceNames.next(), is(NONE_READWRITE_SPLITTING_DATASOURCE_NAME));
assertThat(routedDataSourceNames.next(), is(WRITE_DATASOURCE));
}

@Test
void assertDecorateRouteContextToReplicaDataSource() {
RouteContext actual = mockRouteContext();
MySQLSelectStatement selectStatement = mock(MySQLSelectStatement.class);
when(sqlStatementContext.getSqlStatement()).thenReturn(selectStatement);
when(selectStatement.getLock()).thenReturn(Optional.empty());
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext(), mockConnectionContext(), mock(ShardingSphereMetaData.class));
ConnectionContext connectionContext = mockConnectionContext();
when(connectionContext.getTransactionContext()).thenReturn(mock(TransactionConnectionContext.class));
QueryContext queryContext = new QueryContext(sqlStatementContext, "", Collections.emptyList(), new HintValueContext(), connectionContext, mock(ShardingSphereMetaData.class));
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
sqlRouter.decorateRouteContext(actual, queryContext, database, staticRule, new ConfigurationProperties(new Properties()), new ConnectionContext(Collections::emptySet));
RouteContext actual = mockRouteContext();
sqlRouter.decorateRouteContext(actual, queryContext, database, staticRule, new ConfigurationProperties(new Properties()));
Iterator<String> routedDataSourceNames = actual.getActualDataSourceNames().iterator();
assertThat(routedDataSourceNames.next(), is(NONE_READWRITE_SPLITTING_DATASOURCE_NAME));
assertThat(routedDataSourceNames.next(), is(READ_DATASOURCE));
Expand All @@ -122,7 +125,7 @@ void assertDecorateRouteContextToPrimaryDataSourceWithLock() {
RuleMetaData ruleMetaData = new RuleMetaData(Collections.singleton(staticRule));
ShardingSphereDatabase database = new ShardingSphereDatabase(DefaultDatabase.LOGIC_NAME,
mock(DatabaseType.class), mock(ResourceMetaData.class, RETURNS_DEEP_STUBS), ruleMetaData, Collections.emptyMap());
sqlRouter.decorateRouteContext(actual, queryContext, database, staticRule, new ConfigurationProperties(new Properties()), new ConnectionContext(Collections::emptySet));
sqlRouter.decorateRouteContext(actual, queryContext, database, staticRule, new ConfigurationProperties(new Properties()));
Iterator<String> routedDataSourceNames = actual.getActualDataSourceNames().iterator();
assertThat(routedDataSourceNames.next(), is(NONE_READWRITE_SPLITTING_DATASOURCE_NAME));
assertThat(routedDataSourceNames.next(), is(WRITE_DATASOURCE));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.shardingsphere.infra.metadata.database.ShardingSphereDatabase;
import org.apache.shardingsphere.infra.route.DecorateSQLRouter;
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.session.connection.ConnectionContext;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.shadow.constant.ShadowOrder;
import org.apache.shardingsphere.shadow.route.engine.ShadowRouteEngineFactory;
Expand All @@ -36,7 +35,7 @@ public final class ShadowSQLRouter implements DecorateSQLRouter<ShadowRule> {

@Override
public void decorateRouteContext(final RouteContext routeContext, final QueryContext queryContext, final ShardingSphereDatabase database,
final ShadowRule rule, final ConfigurationProperties props, final ConnectionContext connectionContext) {
final ShadowRule rule, final ConfigurationProperties props) {
ShadowRouteEngineFactory.newInstance(queryContext).route(routeContext, rule);
}

Expand Down
Loading

0 comments on commit 2221919

Please sign in to comment.