Skip to content

Commit 0d8d873

Browse files
punkratz312pinguin3245678timtebeek
authored
EqualsAvoidsNull should flip arguments for constants (#398)
* replaceMethodArgs * replaceMethodArgs * replaceMethodArgs * replaceMethodArgs * replaceMethodArg * undo * undo * add String foo, String bar * multiple * add replaceMethodArg * Also place field accesses first * Check flags on fieldAccess.name.fieldType * Add test showing no change when not static & final * Also support static imports * Remove unused import --------- Co-authored-by: Vincent Potucek <vincent.potucek@sap.com> Co-authored-by: Tim te Beek <tim@moderne.io>
1 parent 0b00944 commit 0d8d873

File tree

3 files changed

+179
-37
lines changed

3 files changed

+179
-37
lines changed

src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNull.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import org.jspecify.annotations.Nullable;
1919
import org.openrewrite.*;
2020
import org.openrewrite.java.JavaIsoVisitor;
21+
import org.openrewrite.java.search.UsesMethod;
2122
import org.openrewrite.java.style.Checkstyle;
2223
import org.openrewrite.java.style.EqualsAvoidsNullStyle;
2324
import org.openrewrite.java.tree.J;
@@ -53,7 +54,7 @@ public Duration getEstimatedEffortPerOccurrence() {
5354

5455
@Override
5556
public TreeVisitor<?, ExecutionContext> getVisitor() {
56-
return new JavaIsoVisitor<ExecutionContext>() {
57+
JavaIsoVisitor<ExecutionContext> replacementVisitor = new JavaIsoVisitor<ExecutionContext>() {
5758
@Override
5859
public J visit(@Nullable Tree tree, ExecutionContext ctx) {
5960
if (tree instanceof JavaSourceFile) {
@@ -68,5 +69,12 @@ public J visit(@Nullable Tree tree, ExecutionContext ctx) {
6869
return (J) tree;
6970
}
7071
};
72+
return Preconditions.check(
73+
Preconditions.or(
74+
new UsesMethod<>("java.lang.String equals*(..)"),
75+
new UsesMethod<>("java.lang.String co*(..)")
76+
),
77+
replacementVisitor
78+
);
7179
}
7280
}

src/main/java/org/openrewrite/staticanalysis/EqualsAvoidsNullVisitor.java

Lines changed: 43 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,37 @@ public class EqualsAvoidsNullVisitor<P> extends JavaVisitor<P> {
5757
@Override
5858
public J visitMethodInvocation(J.MethodInvocation method, P p) {
5959
J.MethodInvocation m = (J.MethodInvocation) super.visitMethodInvocation(method, p);
60-
if (m.getSelect() != null &&
61-
!(m.getSelect() instanceof J.Literal) &&
62-
!m.getArguments().isEmpty() &&
63-
m.getArguments().get(0) instanceof J.Literal &&
64-
isStringComparisonMethod(m)) {
65-
return literalsFirstInComparisonsBinaryCheck(m, getCursor().getParentTreeCursor().getValue());
60+
if (m.getSelect() != null && !(m.getSelect() instanceof J.Literal) &&
61+
isStringComparisonMethod(m) && hasCompatibleArgument(m)) {
62+
63+
maybeHandleParentBinary(m);
64+
65+
Expression firstArgument = m.getArguments().get(0);
66+
return firstArgument.getType() == JavaType.Primitive.Null ?
67+
literalsFirstInComparisonsNull(m, firstArgument) :
68+
literalsFirstInComparisons(m, firstArgument);
6669
}
6770
return m;
6871
}
6972

73+
private boolean hasCompatibleArgument(J.MethodInvocation m) {
74+
if (m.getArguments().isEmpty()) {
75+
return false;
76+
}
77+
Expression firstArgument = m.getArguments().get(0);
78+
if (firstArgument instanceof J.Literal) {
79+
return true;
80+
}
81+
if (firstArgument instanceof J.FieldAccess) {
82+
firstArgument = ((J.FieldAccess) firstArgument).getName();
83+
}
84+
if (firstArgument instanceof J.Identifier) {
85+
JavaType.Variable fieldType = ((J.Identifier) firstArgument).getFieldType();
86+
return fieldType != null && fieldType.hasFlags(Flag.Static, Flag.Final);
87+
}
88+
return false;
89+
}
90+
7091
private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) {
7192
return EQUALS.matches(methodInvocation) ||
7293
!style.getIgnoreEqualsIgnoreCase() &&
@@ -76,17 +97,26 @@ private boolean isStringComparisonMethod(J.MethodInvocation methodInvocation) {
7697
CONTENT_EQUALS.matches(methodInvocation);
7798
}
7899

79-
private Expression literalsFirstInComparisonsBinaryCheck(J.MethodInvocation m, P parent) {
100+
private void maybeHandleParentBinary(J.MethodInvocation m) {
101+
P parent = getCursor().getParentTreeCursor().getValue();
80102
if (parent instanceof J.Binary) {
81-
handleBinaryExpression(m, (J.Binary) parent);
103+
if (((J.Binary) parent).getOperator() == J.Binary.Type.And && ((J.Binary) parent).getLeft() instanceof J.Binary) {
104+
J.Binary potentialNullCheck = (J.Binary) ((J.Binary) parent).getLeft();
105+
if (isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), requireNonNull(m.getSelect())) ||
106+
isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), requireNonNull(m.getSelect()))) {
107+
doAfterVisit(new RemoveUnnecessaryNullCheck<>((J.Binary) parent));
108+
}
109+
}
82110
}
83-
return getExpression(m, m.getArguments().get(0));
84111
}
85112

86-
private static Expression getExpression(J.MethodInvocation m, Expression firstArgument) {
87-
return firstArgument.getType() == JavaType.Primitive.Null ?
88-
literalsFirstInComparisonsNull(m, firstArgument) :
89-
literalsFirstInComparisons(m, firstArgument);
113+
private boolean isNullLiteral(Expression expression) {
114+
return expression instanceof J.Literal && ((J.Literal) expression).getType() == JavaType.Primitive.Null;
115+
}
116+
117+
private boolean matchesSelect(Expression expression, Expression select) {
118+
return expression.printTrimmed(getCursor()).replaceAll("\\s", "")
119+
.equals(select.printTrimmed(getCursor()).replaceAll("\\s", ""));
90120
}
91121

92122
private static J.Binary literalsFirstInComparisonsNull(J.MethodInvocation m, Expression firstArgument) {
@@ -104,25 +134,6 @@ private static J.MethodInvocation literalsFirstInComparisons(J.MethodInvocation
104134
.withArguments(singletonList(m.getSelect().withPrefix(Space.EMPTY)));
105135
}
106136

107-
private void handleBinaryExpression(J.MethodInvocation m, J.Binary binary) {
108-
if (binary.getOperator() == J.Binary.Type.And && binary.getLeft() instanceof J.Binary) {
109-
J.Binary potentialNullCheck = (J.Binary) binary.getLeft();
110-
if (isNullLiteral(potentialNullCheck.getLeft()) && matchesSelect(potentialNullCheck.getRight(), requireNonNull(m.getSelect())) ||
111-
isNullLiteral(potentialNullCheck.getRight()) && matchesSelect(potentialNullCheck.getLeft(), requireNonNull(m.getSelect()))) {
112-
doAfterVisit(new RemoveUnnecessaryNullCheck<>(binary));
113-
}
114-
}
115-
}
116-
117-
private boolean isNullLiteral(Expression expression) {
118-
return expression instanceof J.Literal && ((J.Literal) expression).getType() == JavaType.Primitive.Null;
119-
}
120-
121-
private boolean matchesSelect(Expression expression, Expression select) {
122-
return expression.printTrimmed(getCursor()).replaceAll("\\s", "")
123-
.equals(select.printTrimmed(getCursor()).replaceAll("\\s", ""));
124-
}
125-
126137
private static class RemoveUnnecessaryNullCheck<P> extends JavaVisitor<P> {
127138

128139
private final J.Binary scope;

src/test/java/org/openrewrite/staticanalysis/EqualsAvoidsNullTest.java

Lines changed: 127 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
*/
1616
package org.openrewrite.staticanalysis;
1717

18+
import org.junit.jupiter.api.Nested;
1819
import org.junit.jupiter.api.Test;
1920
import org.openrewrite.DocumentExample;
21+
import org.openrewrite.Issue;
2022
import org.openrewrite.test.RecipeSpec;
2123
import org.openrewrite.test.RewriteTest;
2224

@@ -94,17 +96,16 @@ public class A {
9496
@Test
9597
void nullLiteral() {
9698
rewriteRun(
97-
//language=java
98-
java("""
99+
//language=java
100+
java("""
99101
public class A {
100102
void foo(String s) {
101103
if(s.equals(null)) {
102104
}
103105
}
104106
}
105107
""",
106-
"""
107-
108+
"""
108109
public class A {
109110
void foo(String s) {
110111
if(s == null) {
@@ -114,4 +115,126 @@ void foo(String s) {
114115
""")
115116
);
116117
}
118+
119+
@Nested
120+
class ReplaceConstantMethodArg {
121+
122+
@Issue("https://github.com/openrewrite/rewrite-static-analysis/pull/398")
123+
@Test
124+
void one() {
125+
rewriteRun(
126+
// language=java
127+
java(
128+
"""
129+
public class Constants {
130+
public static final String FOO = "FOO";
131+
}
132+
class A {
133+
private boolean isFoo(String foo) {
134+
return foo.contentEquals(Constants.FOO);
135+
}
136+
}
137+
""",
138+
"""
139+
public class Constants {
140+
public static final String FOO = "FOO";
141+
}
142+
class A {
143+
private boolean isFoo(String foo) {
144+
return Constants.FOO.contentEquals(foo);
145+
}
146+
}
147+
"""
148+
)
149+
);
150+
}
151+
152+
@Test
153+
void staticImport() {
154+
rewriteRun(
155+
// language=java
156+
java(
157+
"""
158+
package c;
159+
public class Constants {
160+
public static final String FOO = "FOO";
161+
}
162+
"""
163+
),
164+
// language=java
165+
java(
166+
"""
167+
import static c.Constants.FOO;
168+
class A {
169+
private boolean isFoo(String foo) {
170+
return foo.contentEquals(FOO);
171+
}
172+
}
173+
""",
174+
"""
175+
import static c.Constants.FOO;
176+
class A {
177+
private boolean isFoo(String foo) {
178+
return FOO.contentEquals(foo);
179+
}
180+
}
181+
"""
182+
)
183+
);
184+
}
185+
186+
@Test
187+
void multiple() {
188+
rewriteRun(
189+
//language=java
190+
java(
191+
"""
192+
public class Constants {
193+
public static final String FOO = "FOO";
194+
}
195+
class A {
196+
private boolean isFoo(String foo, String bar) {
197+
return foo.contentEquals(Constants.FOO)
198+
|| bar.compareToIgnoreCase(Constants.FOO);
199+
}
200+
}
201+
""",
202+
"""
203+
public class Constants {
204+
public static final String FOO = "FOO";
205+
}
206+
class A {
207+
private boolean isFoo(String foo, String bar) {
208+
return Constants.FOO.contentEquals(foo)
209+
|| Constants.FOO.compareToIgnoreCase(bar);
210+
}
211+
}
212+
"""
213+
)
214+
);
215+
}
216+
217+
@Test
218+
void nonStaticNonFinalNoChange() {
219+
rewriteRun(
220+
// language=java
221+
java(
222+
"""
223+
public class Constants {
224+
public final String FOO = "FOO";
225+
public static String BAR = "BAR";
226+
}
227+
class A {
228+
private boolean isFoo(String foo) {
229+
return foo.contentEquals(new Constants().FOO);
230+
}
231+
private boolean isBar(String bar) {
232+
return bar.contentEquals(Constants.BAR);
233+
}
234+
}
235+
"""
236+
)
237+
);
238+
}
239+
}
117240
}

0 commit comments

Comments
 (0)