Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.search.SemanticallyEqual;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.java.search.UsesJavaVersion;
import org.openrewrite.java.tree.*;
import org.openrewrite.staticanalysis.groovy.GroovyFileChecker;
import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -58,7 +60,13 @@ public Duration getEstimatedEffortPerOccurrence() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return Preconditions.check(Preconditions.not(new KotlinFileChecker<>()), new JavaVisitor<ExecutionContext>() {
TreeVisitor<?, ExecutionContext> preconditions = Preconditions.and(
new UsesJavaVersion<>(21),
Preconditions.not(new KotlinFileChecker<>()),
Preconditions.not(new GroovyFileChecker<>())
);

return Preconditions.check(preconditions, new JavaVisitor<ExecutionContext>() {
@Override
public J visitBlock(J.Block block, ExecutionContext ctx) {
AtomicReference<@Nullable NullCheck> nullCheck = new AtomicReference<>();
Expand All @@ -68,13 +76,18 @@ public J visitBlock(J.Block block, ExecutionContext ctx) {
if (nullCheckOpt.isPresent()) {
NullCheck check = nullCheckOpt.get();
J nextStatement = index + 1 < block.getStatements().size() ? block.getStatements().get(index + 1) : null;
if (!(nextStatement instanceof J.Switch) ||
hasNullCase((J.Switch) nextStatement) ||
!SemanticallyEqual.areEqual(((J.Switch) nextStatement).getSelector().getTree(), check.getNullCheckedParameter()) ||
check.returns() ||
check.couldModifyNullCheckedValue()) {
if (!(nextStatement instanceof J.Switch) || check.returns() || check.couldModifyNullCheckedValue()) {
return statement;
}
J.Switch nextSwitch = (J.Switch) nextStatement;
// Only if the switch does not have a null case and switches on the same value as the null check, we can remove the null check
// It must have all possible input values covered
if (hasNullCase(nextSwitch) ||
!SemanticallyEqual.areEqual(nextSwitch.getSelector().getTree(), check.getNullCheckedParameter()) ||
!coversAllPossibleValues(nextSwitch)) {
return statement;
}

nullCheck.set(check);
return null;
}
Expand Down Expand Up @@ -106,6 +119,16 @@ private boolean hasNullCase(J.Switch switch_) {
}

private J.Case createNullCase(J.Switch aSwitch, Statement whenNull) {
J.Case currentFirstCase = aSwitch.getCases().getStatements().isEmpty() ||
!(aSwitch.getCases().getStatements().get(0) instanceof J.Case) ?
null : (J.Case) aSwitch.getCases().getStatements().get(0);
if (currentFirstCase == null || J.Case.Type.Rule == currentFirstCase.getType()) {
return createCaseRule(aSwitch, whenNull);
}
return createCaseStatement(aSwitch, whenNull, currentFirstCase);
}

private J.Case createCaseRule(J.Switch aSwitch, Statement whenNull) {
if (whenNull instanceof J.Block && ((J.Block) whenNull).getStatements().size() == 1) {
Statement firstStatement = ((J.Block) whenNull).getStatements().get(0);
if (firstStatement instanceof Expression || firstStatement instanceof J.Throw) {
Expand All @@ -122,6 +145,60 @@ private J.Case createNullCase(J.Switch aSwitch, Statement whenNull) {
J.Case nullCase = (J.Case) switchWithNullCase.getCases().getStatements().get(0);
return nullCase.withBody(requireNonNull(nullCase.getBody()).withPrefix(Space.SINGLE_SPACE));
}

private J.Case createCaseStatement(J.Switch aSwitch, Statement whenNull, J.Case currentFirstCase) {
List<J> statements = new ArrayList<>();
statements.add(aSwitch.getSelector().getTree());
if (whenNull instanceof J.Block) {
statements.addAll(((J.Block) whenNull).getStatements());
} else {
statements.add(whenNull);
}
StringBuilder template = new StringBuilder("switch(#{any()}) {\ncase null:");
for (int i = 1; i < statements.size(); i++) {
template.append("\n#{any()};");
}
template.append("\nbreak;\n}");
J.Switch switchWithNullCase = JavaTemplate.apply(
template.toString(),
new Cursor(getCursor(), aSwitch),
aSwitch.getCoordinates().replace(),
statements.toArray());
J.Case nullCase = (J.Case) switchWithNullCase.getCases().getStatements().get(0);
Space currentFirstCaseIndentation = currentFirstCase.getStatements().stream().map(J::getPrefix).findFirst().orElse(Space.SINGLE_SPACE);

return nullCase.withStatements(ListUtils.mapFirst(nullCase.getStatements(), s -> s == null ? null : s.withPrefix(currentFirstCaseIndentation)));
}

private boolean coversAllPossibleValues(J.Switch switch_) {
List<J> labels = new ArrayList<>();
for (Statement statement : switch_.getCases().getStatements()) {
for (J j : ((J.Case) statement).getCaseLabels()) {
if (j instanceof J.Identifier && "default".equals(((J.Identifier) j).getSimpleName())) {
return true;
}
labels.add(j);
}
}
JavaType javaType = switch_.getSelector().getTree().getType();
if (javaType instanceof JavaType.Class && ((JavaType.Class) javaType).getKind() == JavaType.FullyQualified.Kind.Enum) {
// Every enum value must be present in the switch
return ((JavaType.Class) javaType).getMembers().stream().allMatch(variable ->
labels.stream().anyMatch(label -> {
if (!(label instanceof TypeTree && TypeUtils.isOfType(((TypeTree) label).getType(), javaType))) {
return false;
}
J.Identifier enumName = null;
if (label instanceof J.Identifier) {
enumName = (J.Identifier) label;
} else if (label instanceof J.FieldAccess) {
enumName = ((J.FieldAccess) label).getName();
}
return enumName != null && Objects.equals(variable.getName(), enumName.getSimpleName());
}));
}
return false;
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import org.openrewrite.*;
import org.openrewrite.internal.ListUtils;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.search.UsesJavaVersion;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.staticanalysis.groovy.GroovyFileChecker;
import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker;

import java.util.ArrayList;
Expand All @@ -50,7 +52,13 @@ public String getDescription() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return Preconditions.check(Preconditions.not(new KotlinFileChecker<>()), new JavaIsoVisitor<ExecutionContext>() {
TreeVisitor<?, ExecutionContext> preconditions = Preconditions.and(
new UsesJavaVersion<>(21),
Preconditions.not(new KotlinFileChecker<>()),
Preconditions.not(new GroovyFileChecker<>())
);

return Preconditions.check(preconditions, new JavaIsoVisitor<ExecutionContext>() {
@Override
public J.Switch visitSwitch(J.Switch sw, ExecutionContext ctx) {
J.Switch switch_ = super.visitSwitch(sw, ctx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import org.openrewrite.TreeVisitor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaVisitor;
import org.openrewrite.java.search.UsesJavaVersion;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.JavaType;
import org.openrewrite.java.tree.TypeUtils;
import org.openrewrite.staticanalysis.groovy.GroovyFileChecker;
import org.openrewrite.staticanalysis.kotlin.KotlinFileChecker;

import static java.util.Collections.singletonList;
Expand All @@ -47,7 +49,13 @@ public String getDescription() {

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
return Preconditions.check(Preconditions.not(new KotlinFileChecker<>()), new JavaIsoVisitor<ExecutionContext>() {
TreeVisitor<?, ExecutionContext> preconditions = Preconditions.and(
new UsesJavaVersion<>(21),
Preconditions.not(new KotlinFileChecker<>()),
Preconditions.not(new GroovyFileChecker<>())
);

return Preconditions.check(preconditions, new JavaIsoVisitor<ExecutionContext>() {

@Override
public J.Case visitCase(J.Case case_, ExecutionContext ctx) {
Expand Down
4 changes: 2 additions & 2 deletions src/main/resources/META-INF/rewrite/java-version-21.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ recipeList:
- org.openrewrite.java.migrate.UpgradePluginsForJava21
- org.openrewrite.java.migrate.DeleteDeprecatedFinalize
- org.openrewrite.java.migrate.RemovedSubjectMethods
# - org.openrewrite.java.migrate.SwitchPatternMatching
#- org.openrewrite.java.migrate.SwitchPatternMatching
#- org.openrewrite.java.migrate.lang.NullCheckAsSwitchCase

---
type: specs.openrewrite.org/v1beta/recipe
Expand Down Expand Up @@ -142,6 +143,5 @@ tags:
- java21
recipeList:
- org.openrewrite.java.migrate.lang.IfElseIfConstructToSwitch
- org.openrewrite.java.migrate.lang.NullCheckAsSwitchCase
- org.openrewrite.java.migrate.lang.RefineSwitchCases
- org.openrewrite.java.migrate.lang.SwitchCaseEnumGuardToLabel
Loading