Skip to content

Commit 6b94940

Browse files
authored
ReplaceStreamCollectWithToList should look at return type (#803)
* `ReplaceStreamCollectWithToList` should look at return type * Condense * Condense further
1 parent 5202c07 commit 6b94940

File tree

2 files changed

+118
-0
lines changed

2 files changed

+118
-0
lines changed

src/main/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToList.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
import org.openrewrite.java.search.UsesMethod;
2828
import org.openrewrite.java.tree.Expression;
2929
import org.openrewrite.java.tree.J;
30+
import org.openrewrite.java.tree.JavaType;
31+
import org.openrewrite.java.tree.TypeUtils;
3032

3133
import java.time.Duration;
3234
import java.util.Collections;
@@ -95,12 +97,32 @@ public J.MethodInvocation visitMethodInvocation(J.MethodInvocation method, Execu
9597
Expression command = method.getArguments().get(0);
9698
if (COLLECT_TO_UNMODIFIABLE_LIST.matches(command) ||
9799
convertToList && COLLECT_TO_LIST.matches(command)) {
100+
101+
// Check if the transformation would result in incompatible types
102+
if (!areTypesCompatible(result)) {
103+
return result;
104+
}
105+
98106
maybeRemoveImport("java.util.stream.Collectors");
99107
J.MethodInvocation toList = JavaTemplate.apply("#{any(java.util.stream.Stream)}.toList()",
100108
updateCursor(result), result.getCoordinates().replace(), result.getSelect());
101109
return toList.getPadding().withSelect(result.getPadding().getSelect());
102110
}
103111
return result;
104112
}
113+
114+
private boolean areTypesCompatible(J.MethodInvocation method) {
115+
if (method.getSelect() == null ||
116+
method.getSelect().getType() == null ||
117+
!(method.getSelect().getType() instanceof JavaType.Parameterized) ||
118+
!(method.getType() instanceof JavaType.Parameterized)) {
119+
return false;
120+
}
121+
// Check if the stream element type and expected list element type are exactly the same
122+
// If they differ (e.g., Stream<Integer> but List<Number>), don't transform
123+
return TypeUtils.isOfType(
124+
((JavaType.Parameterized) method.getSelect().getType()).getTypeParameters().get(0),
125+
((JavaType.Parameterized) method.getType()).getTypeParameters().get(0));
126+
}
105127
}
106128
}

src/test/java/org/openrewrite/java/migrate/util/ReplaceStreamCollectWithToListTest.java

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import org.junit.jupiter.api.Test;
1919
import org.openrewrite.DocumentExample;
20+
import org.openrewrite.Issue;
2021
import org.openrewrite.test.RecipeSpec;
2122
import org.openrewrite.test.RewriteTest;
2223

@@ -180,4 +181,99 @@ List<String> test(Stream<String> stream) {
180181
);
181182
}
182183

184+
@Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791")
185+
@Test
186+
void doesNotReplaceWhenReturnTypeIsIncompatible() {
187+
rewriteRun(
188+
//language=java
189+
java(
190+
"""
191+
import java.util.stream.Collectors;
192+
import java.util.stream.Stream;
193+
import java.util.List;
194+
195+
class Example {
196+
List<Number> foo() {
197+
return Stream.of(Integer.valueOf(1)).collect(Collectors.toUnmodifiableList());
198+
}
199+
}
200+
"""
201+
)
202+
);
203+
}
204+
205+
@Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791")
206+
@Test
207+
void replacesWhenTypesAreCompatible() {
208+
rewriteRun(
209+
//language=java
210+
java(
211+
"""
212+
import java.util.stream.Collectors;
213+
import java.util.stream.Stream;
214+
import java.util.List;
215+
216+
class Example {
217+
List<Integer> foo() {
218+
return Stream.of(Integer.valueOf(1)).collect(Collectors.toUnmodifiableList());
219+
}
220+
}
221+
""",
222+
"""
223+
import java.util.stream.Stream;
224+
import java.util.List;
225+
226+
class Example {
227+
List<Integer> foo() {
228+
return Stream.of(Integer.valueOf(1)).toList();
229+
}
230+
}
231+
"""
232+
)
233+
);
234+
}
235+
236+
@Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791")
237+
@Test
238+
void doesNotReplaceInVariableAssignmentWithIncompatibleTypes() {
239+
rewriteRun(
240+
//language=java
241+
java(
242+
"""
243+
import java.util.stream.Collectors;
244+
import java.util.stream.Stream;
245+
import java.util.List;
246+
247+
class Example {
248+
void foo() {
249+
List<Number> numbers = Stream.of(Integer.valueOf(1)).collect(Collectors.toUnmodifiableList());
250+
}
251+
}
252+
"""
253+
)
254+
);
255+
}
256+
257+
@Issue("https://github.com/openrewrite/rewrite-migrate-java/issues/791")
258+
@Test
259+
void doesNotReplaceWithToListWhenConvertToListFlagIsTrue() {
260+
rewriteRun(
261+
recipeSpec -> recipeSpec.recipe(new ReplaceStreamCollectWithToList(true)),
262+
//language=java
263+
java(
264+
"""
265+
import java.util.stream.Collectors;
266+
import java.util.stream.Stream;
267+
import java.util.List;
268+
269+
class Example {
270+
List<Number> foo() {
271+
return Stream.of(Integer.valueOf(1)).collect(Collectors.toList());
272+
}
273+
}
274+
"""
275+
)
276+
);
277+
}
278+
183279
}

0 commit comments

Comments
 (0)