From 2cfd72981a16e1eaef88cbbe86e0fe4d1e751cff Mon Sep 17 00:00:00 2001 From: duanzhengqiang Date: Wed, 23 Oct 2024 14:23:29 +0800 Subject: [PATCH] Fix SQL performance issues caused by repeated subquery fetches --- .../EncryptProjectionTokenGenerator.java | 2 +- .../context/segment/table/TablesContext.java | 3 + .../statement/dml/SelectStatementContext.java | 2 +- .../core/util/SubqueryExtractUtils.java | 113 ++++++++++-------- .../core/util/WhereExtractUtils.java | 16 +-- .../core/util/SubqueryExtractUtilsTest.java | 14 +-- 6 files changed, 74 insertions(+), 76 deletions(-) diff --git a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java index b870cf68eb5f1..f8c595ee5943c 100644 --- a/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java +++ b/features/encrypt/core/src/main/java/org/apache/shardingsphere/encrypt/rewrite/token/generator/projection/EncryptProjectionTokenGenerator.java @@ -69,7 +69,7 @@ public final class EncryptProjectionTokenGenerator { */ public Collection generateSQLTokens(final SelectStatementContext selectStatementContext) { Collection result = new LinkedHashSet<>(generateSelectSQLTokens(selectStatementContext)); - selectStatementContext.getSubqueryContexts().values().stream().map(this::generateSelectSQLTokens).forEach(result::addAll); + selectStatementContext.getSubqueryContexts().values().stream().map(this::generateSQLTokens).forEach(result::addAll); return result; } diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/table/TablesContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/table/TablesContext.java index e656b6f476261..0d58c9bbb2afe 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/table/TablesContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/segment/table/TablesContext.java @@ -106,6 +106,9 @@ private Optional findDatabaseName(final SimpleTableSegment tableSegment, } private Map> createSubqueryTables(final Map subqueryContexts, final SubqueryTableSegment subqueryTable) { + if (!subqueryContexts.containsKey(subqueryTable.getSubquery().getStartIndex())) { + return Collections.emptyMap(); + } SelectStatementContext subqueryContext = subqueryContexts.get(subqueryTable.getSubquery().getStartIndex()); Map subqueryTableContexts = new SubqueryTableContextEngine().createSubqueryTableContexts(subqueryContext, subqueryTable.getAliasName().orElse(null)); Map> result = new HashMap<>(subqueryTableContexts.size(), 1F); diff --git a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java index e511b6af62ebf..2fe420260efe8 100644 --- a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java +++ b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/context/statement/dml/SelectStatementContext.java @@ -187,7 +187,7 @@ private Collection getTableMapperRuleAttributes(final private Map createSubqueryContexts(final ShardingSphereMetaData metaData, final List params, final String currentDatabaseName, final Collection tableSegments) { - Collection subquerySegments = SubqueryExtractUtils.getSubquerySegments(getSqlStatement()); + Collection subquerySegments = SubqueryExtractUtils.getSubquerySegments(getSqlStatement(), false); Map result = new HashMap<>(subquerySegments.size(), 1F); for (SubquerySegment each : subquerySegments) { SelectStatementContext subqueryContext = new SelectStatementContext(metaData, params, each.getSelect(), currentDatabaseName, tableSegments); diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java index ef6a46260f3b7..13716c2a2d79e 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java @@ -60,37 +60,44 @@ public final class SubqueryExtractUtils { * Get subquery segment from SelectStatement. * * @param selectStatement SelectStatement + * @param needRecursive need recursive * @return subquery segment collection */ - public static Collection getSubquerySegments(final SelectStatement selectStatement) { + public static Collection getSubquerySegments(final SelectStatement selectStatement, final boolean needRecursive) { List result = new LinkedList<>(); - extractSubquerySegments(result, selectStatement); + extractSubquerySegments(result, selectStatement, needRecursive); return result; } - private static void extractSubquerySegments(final List result, final SelectStatement selectStatement) { - extractSubquerySegmentsFromProjections(result, selectStatement.getProjections()); - selectStatement.getFrom().ifPresent(optional -> extractSubquerySegmentsFromTableSegment(result, optional)); + private static void extractSubquerySegments(final List result, final SelectStatement selectStatement, final boolean needRecursive) { + extractSubquerySegmentsFromProjections(result, selectStatement.getProjections(), needRecursive); + selectStatement.getFrom().ifPresent(optional -> extractSubquerySegmentsFromTableSegment(result, optional, needRecursive)); if (selectStatement.getWhere().isPresent()) { - extractSubquerySegmentsFromWhere(result, selectStatement.getWhere().get().getExpr()); + extractSubquerySegmentsFromWhere(result, selectStatement.getWhere().get().getExpr(), needRecursive); } if (selectStatement.getCombine().isPresent()) { - extractSubquerySegmentsFromCombine(result, selectStatement.getCombine().get()); + extractSubquerySegmentsFromCombine(result, selectStatement.getCombine().get(), needRecursive); } if (selectStatement.getWithSegment().isPresent()) { - extractSubquerySegmentsFromCTEs(result, selectStatement.getWithSegment().get().getCommonTableExpressions()); + extractSubquerySegmentsFromCTEs(result, selectStatement.getWithSegment().get().getCommonTableExpressions(), needRecursive); } } - private static void extractSubquerySegmentsFromCTEs(final List result, final Collection withSegment) { + private static void extractSubquerySegmentsFromCTEs(final List result, final Collection withSegment, final boolean needRecursive) { for (CommonTableExpressionSegment each : withSegment) { each.getSubquery().setSubqueryType(SubqueryType.WITH); result.add(each.getSubquery()); - extractSubquerySegments(result, each.getSubquery().getSelect()); + extractRecursive(needRecursive, result, each.getSubquery().getSelect()); } } - private static void extractSubquerySegmentsFromProjections(final List result, final ProjectionsSegment projections) { + private static void extractRecursive(final boolean needRecursive, final List result, final SelectStatement select) { + if (needRecursive) { + extractSubquerySegments(result, select, true); + } + } + + private static void extractSubquerySegmentsFromProjections(final List result, final ProjectionsSegment projections, final boolean needRecursive) { if (null == projections || projections.getProjections().isEmpty()) { return; } @@ -99,112 +106,114 @@ private static void extractSubquerySegmentsFromProjections(final List result, final TableSegment tableSegment) { + private static void extractSubquerySegmentsFromTableSegment(final List result, final TableSegment tableSegment, final boolean needRecursive) { if (tableSegment instanceof SubqueryTableSegment) { - extractSubquerySegmentsFromSubqueryTableSegment(result, (SubqueryTableSegment) tableSegment); + extractSubquerySegmentsFromSubqueryTableSegment(result, (SubqueryTableSegment) tableSegment, needRecursive); } if (tableSegment instanceof JoinTableSegment) { - extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getLeft()); - extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getRight()); + extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getLeft(), needRecursive); + extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getRight(), needRecursive); } } - private static void extractSubquerySegmentsFromJoinTableSegment(final List result, final TableSegment tableSegment) { + private static void extractSubquerySegmentsFromJoinTableSegment(final List result, final TableSegment tableSegment, final boolean needRecursive) { if (tableSegment instanceof SubqueryTableSegment) { SubquerySegment subquery = ((SubqueryTableSegment) tableSegment).getSubquery(); subquery.setSubqueryType(SubqueryType.JOIN); result.add(subquery); - extractSubquerySegments(result, subquery.getSelect()); + extractRecursive(needRecursive, result, subquery.getSelect()); } else if (tableSegment instanceof JoinTableSegment) { - extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getLeft()); - extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getRight()); + extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getLeft(), needRecursive); + extractSubquerySegmentsFromJoinTableSegment(result, ((JoinTableSegment) tableSegment).getRight(), needRecursive); } } - private static void extractSubquerySegmentsFromSubqueryTableSegment(final List result, final SubqueryTableSegment subqueryTableSegment) { + private static void extractSubquerySegmentsFromSubqueryTableSegment(final List result, final SubqueryTableSegment subqueryTableSegment, final boolean needRecursive) { SubquerySegment subquery = subqueryTableSegment.getSubquery(); subquery.setSubqueryType(SubqueryType.TABLE); result.add(subquery); - extractSubquerySegments(result, subquery.getSelect()); + extractRecursive(needRecursive, result, subquery.getSelect()); } - private static void extractSubquerySegmentsFromWhere(final List result, final ExpressionSegment expressionSegment) { - extractSubquerySegmentsFromExpression(result, expressionSegment, SubqueryType.PREDICATE); + private static void extractSubquerySegmentsFromWhere(final List result, final ExpressionSegment expressionSegment, final boolean needRecursive) { + extractSubquerySegmentsFromExpression(result, expressionSegment, SubqueryType.PREDICATE, needRecursive); } - private static void extractSubquerySegmentsFromExpression(final List result, final ExpressionSegment expressionSegment, final SubqueryType subqueryType) { + private static void extractSubquerySegmentsFromExpression(final List result, final ExpressionSegment expressionSegment, final SubqueryType subqueryType, + final boolean needRecursive) { if (expressionSegment instanceof SubqueryExpressionSegment) { SubquerySegment subquery = ((SubqueryExpressionSegment) expressionSegment).getSubquery(); subquery.setSubqueryType(subqueryType); result.add(subquery); - extractSubquerySegments(result, subquery.getSelect()); + extractRecursive(needRecursive, result, subquery.getSelect()); } if (expressionSegment instanceof ExistsSubqueryExpression) { SubquerySegment subquery = ((ExistsSubqueryExpression) expressionSegment).getSubquery(); subquery.setSubqueryType(subqueryType); result.add(subquery); - extractSubquerySegments(result, subquery.getSelect()); + extractRecursive(needRecursive, result, subquery.getSelect()); } if (expressionSegment instanceof ListExpression) { - ((ListExpression) expressionSegment).getItems().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType)); + ((ListExpression) expressionSegment).getItems().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType, needRecursive)); } if (expressionSegment instanceof BinaryOperationExpression) { - extractSubquerySegmentsFromExpression(result, ((BinaryOperationExpression) expressionSegment).getLeft(), subqueryType); - extractSubquerySegmentsFromExpression(result, ((BinaryOperationExpression) expressionSegment).getRight(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((BinaryOperationExpression) expressionSegment).getLeft(), subqueryType, needRecursive); + extractSubquerySegmentsFromExpression(result, ((BinaryOperationExpression) expressionSegment).getRight(), subqueryType, needRecursive); } if (expressionSegment instanceof InExpression) { - extractSubquerySegmentsFromExpression(result, ((InExpression) expressionSegment).getLeft(), subqueryType); - extractSubquerySegmentsFromExpression(result, ((InExpression) expressionSegment).getRight(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((InExpression) expressionSegment).getLeft(), subqueryType, needRecursive); + extractSubquerySegmentsFromExpression(result, ((InExpression) expressionSegment).getRight(), subqueryType, needRecursive); } if (expressionSegment instanceof BetweenExpression) { - extractSubquerySegmentsFromExpression(result, ((BetweenExpression) expressionSegment).getBetweenExpr(), subqueryType); - extractSubquerySegmentsFromExpression(result, ((BetweenExpression) expressionSegment).getAndExpr(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((BetweenExpression) expressionSegment).getBetweenExpr(), subqueryType, needRecursive); + extractSubquerySegmentsFromExpression(result, ((BetweenExpression) expressionSegment).getAndExpr(), subqueryType, needRecursive); } if (expressionSegment instanceof NotExpression) { - extractSubquerySegmentsFromExpression(result, ((NotExpression) expressionSegment).getExpression(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((NotExpression) expressionSegment).getExpression(), subqueryType, needRecursive); } if (expressionSegment instanceof FunctionSegment) { - ((FunctionSegment) expressionSegment).getParameters().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType)); + ((FunctionSegment) expressionSegment).getParameters().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType, needRecursive)); } if (expressionSegment instanceof MatchAgainstExpression) { - extractSubquerySegmentsFromExpression(result, ((MatchAgainstExpression) expressionSegment).getExpr(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((MatchAgainstExpression) expressionSegment).getExpr(), subqueryType, needRecursive); } if (expressionSegment instanceof CaseWhenExpression) { - extractSubquerySegmentsFromCaseWhenExpression(result, (CaseWhenExpression) expressionSegment, subqueryType); + extractSubquerySegmentsFromCaseWhenExpression(result, (CaseWhenExpression) expressionSegment, subqueryType, needRecursive); } if (expressionSegment instanceof CollateExpression) { - extractSubquerySegmentsFromExpression(result, ((CollateExpression) expressionSegment).getCollateName(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((CollateExpression) expressionSegment).getCollateName(), subqueryType, needRecursive); } if (expressionSegment instanceof DatetimeExpression) { - extractSubquerySegmentsFromExpression(result, ((DatetimeExpression) expressionSegment).getLeft(), subqueryType); - extractSubquerySegmentsFromExpression(result, ((DatetimeExpression) expressionSegment).getRight(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((DatetimeExpression) expressionSegment).getLeft(), subqueryType, needRecursive); + extractSubquerySegmentsFromExpression(result, ((DatetimeExpression) expressionSegment).getRight(), subqueryType, needRecursive); } if (expressionSegment instanceof NotExpression) { - extractSubquerySegmentsFromExpression(result, ((NotExpression) expressionSegment).getExpression(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((NotExpression) expressionSegment).getExpression(), subqueryType, needRecursive); } if (expressionSegment instanceof TypeCastExpression) { - extractSubquerySegmentsFromExpression(result, ((TypeCastExpression) expressionSegment).getExpression(), subqueryType); + extractSubquerySegmentsFromExpression(result, ((TypeCastExpression) expressionSegment).getExpression(), subqueryType, needRecursive); } } - private static void extractSubquerySegmentsFromCaseWhenExpression(final List result, final CaseWhenExpression expressionSegment, final SubqueryType subqueryType) { - extractSubquerySegmentsFromExpression(result, expressionSegment.getCaseExpr(), subqueryType); - expressionSegment.getWhenExprs().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType)); - expressionSegment.getThenExprs().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType)); - extractSubquerySegmentsFromExpression(result, expressionSegment.getElseExpr(), subqueryType); + private static void extractSubquerySegmentsFromCaseWhenExpression(final List result, final CaseWhenExpression expressionSegment, final SubqueryType subqueryType, + final boolean needRecursive) { + extractSubquerySegmentsFromExpression(result, expressionSegment.getCaseExpr(), subqueryType, needRecursive); + expressionSegment.getWhenExprs().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType, needRecursive)); + expressionSegment.getThenExprs().forEach(each -> extractSubquerySegmentsFromExpression(result, each, subqueryType, needRecursive)); + extractSubquerySegmentsFromExpression(result, expressionSegment.getElseExpr(), subqueryType, needRecursive); } - private static void extractSubquerySegmentsFromCombine(final List result, final CombineSegment combineSegment) { + private static void extractSubquerySegmentsFromCombine(final List result, final CombineSegment combineSegment, final boolean needRecursive) { result.add(combineSegment.getLeft()); result.add(combineSegment.getRight()); - extractSubquerySegments(result, combineSegment.getLeft().getSelect()); - extractSubquerySegments(result, combineSegment.getRight().getSelect()); + extractRecursive(needRecursive, result, combineSegment.getLeft().getSelect()); + extractRecursive(needRecursive, result, combineSegment.getRight().getSelect()); } } diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/WhereExtractUtils.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/WhereExtractUtils.java index 7bb931b867875..7046dffeaa460 100644 --- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/WhereExtractUtils.java +++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/WhereExtractUtils.java @@ -71,24 +71,10 @@ private static WhereSegment generateWhereSegment(final JoinTableSegment joinTabl */ public static Collection getSubqueryWhereSegments(final SelectStatement selectStatement) { Collection result = new LinkedList<>(); - for (SubquerySegment each : SubqueryExtractUtils.getSubquerySegments(selectStatement)) { + for (SubquerySegment each : SubqueryExtractUtils.getSubquerySegments(selectStatement, false)) { each.getSelect().getWhere().ifPresent(result::add); result.addAll(getJoinWhereSegments(each.getSelect())); } return result; } - - /** - * Get subquery where segment without join conditions from SelectStatement. - * - * @param selectStatement SelectStatement - * @return subquery where segment collection - */ - public static Collection getSubqueryWhereSegmentsWithoutJoinConditions(final SelectStatement selectStatement) { - Collection result = new LinkedList<>(); - for (SubquerySegment each : SubqueryExtractUtils.getSubquerySegments(selectStatement)) { - each.getSelect().getWhere().ifPresent(result::add); - } - return result; - } } diff --git a/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtilsTest.java b/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtilsTest.java index 9ac7bc44264f2..a2a3eabe7eaf1 100644 --- a/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtilsTest.java +++ b/parser/sql/statement/core/src/test/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtilsTest.java @@ -72,7 +72,7 @@ void assertGetSubquerySegmentsInWhere() { SubqueryExpressionSegment right = new SubqueryExpressionSegment(new SubquerySegment(51, 100, subquerySelectStatement, "")); WhereSegment whereSegment = new WhereSegment(34, 100, new BinaryOperationExpression(40, 100, left, right, "=", "order_id = (SELECT order_id FROM t_order WHERE status = 'OK')")); when(selectStatement.getWhere()).thenReturn(Optional.of(whereSegment)); - Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement); + Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement, true); assertThat(actual.size(), is(1)); assertThat(actual.iterator().next(), is(right.getSubquery())); } @@ -89,7 +89,7 @@ void assertGetSubquerySegmentsInProjection() { ProjectionsSegment projections = new ProjectionsSegment(7, 79); when(selectStatement.getProjections()).thenReturn(projections); projections.getProjections().add(subqueryProjectionSegment); - Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement); + Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement, true); assertThat(actual.size(), is(1)); assertThat(actual.iterator().next(), is(subquerySegment)); } @@ -110,7 +110,7 @@ void assertGetSubquerySegmentsInFrom1() { projections.getProjections().add(new ColumnProjectionSegment(new ColumnSegment(7, 16, new IdentifierValue("order_id")))); SubqueryTableSegment subqueryTableSegment = new SubqueryTableSegment(0, 0, new SubquerySegment(23, 71, subquery, "")); when(selectStatement.getFrom()).thenReturn(Optional.of(subqueryTableSegment)); - Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement); + Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement, true); assertThat(actual.size(), is(1)); assertThat(actual.iterator().next(), is(subqueryTableSegment.getSubquery())); } @@ -154,7 +154,7 @@ void assertGetSubquerySegmentsInFrom2() { from.setLeft(leftSubquerySegment); from.setRight(rightSubquerySegment); when(selectStatement.getFrom()).thenReturn(Optional.of(from)); - Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement); + Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement, true); assertThat(actual.size(), is(2)); Iterator iterator = actual.iterator(); assertThat(iterator.next(), is(leftSubquerySegment.getSubquery())); @@ -166,7 +166,7 @@ void assertGetSubquerySegmentsWithMultiNestedSubquery() { SelectStatement selectStatement = mock(SelectStatement.class); SubquerySegment subquerySelect = createSubquerySegment(); when(selectStatement.getFrom()).thenReturn(Optional.of(new SubqueryTableSegment(0, 0, subquerySelect))); - Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement); + Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement, true); assertThat(actual.size(), is(2)); } @@ -176,7 +176,7 @@ void assertGetSubquerySegmentsWithCombineSegment() { SubquerySegment left = new SubquerySegment(0, 0, mock(SelectStatement.class), ""); SubquerySegment right = createSubquerySegment(); when(selectStatement.getCombine()).thenReturn(Optional.of(new CombineSegment(0, 0, left, CombineType.UNION, right))); - Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement); + Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement, true); assertThat(actual.size(), is(3)); } @@ -197,7 +197,7 @@ void assertGetSubquerySegmentsFromProjectionFunctionParams() { functionSegment.getParameters().add(new SubqueryExpressionSegment(new SubquerySegment(0, 0, mock(SelectStatement.class), ""))); ExpressionProjectionSegment expressionProjectionSegment = new ExpressionProjectionSegment(0, 0, "", functionSegment); projections.getProjections().add(expressionProjectionSegment); - Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement); + Collection actual = SubqueryExtractUtils.getSubquerySegments(selectStatement, true); assertThat(actual.size(), is(1)); } }