Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
287 changes: 122 additions & 165 deletions rewrite-java/src/main/java/org/openrewrite/java/InlineMethodCalls.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@
*/
package org.openrewrite.java;

import lombok.AccessLevel;
import lombok.Getter;
import lombok.EqualsAndHashCode;
import lombok.Value;
import org.jspecify.annotations.Nullable;
import org.openrewrite.Cursor;
import org.openrewrite.ExecutionContext;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.*;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.*;

import java.util.*;
Expand All @@ -32,89 +29,81 @@
import static java.lang.String.format;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toMap;
import static java.util.stream.Collectors.toSet;

@Incubating(since = "8.63.0")
@EqualsAndHashCode(callSuper = false)
@Value
public class InlineMethodCalls extends Recipe {
private static final Pattern TEMPLATE_IDENTIFIER = Pattern.compile("#\\{(\\p{javaJavaIdentifierStart}\\p{javaJavaIdentifierPart}*):any\\(.*?\\)}");

private static final String INLINE_ME = "InlineMe";
@Option(displayName = "Method pattern",
description = "A method pattern that is used to find matching method invocations.",
example = "com.google.common.base.Preconditions checkNotNull(..)")
String methodPattern;

@Option(displayName = "Replacement template",
description = "The replacement template for the method invocation. Parameters can be referenced using their names from the original method.",
example = "java.util.Objects.requireNonNull(#{p0})")
String replacement;

@Option(displayName = "Imports",
description = "List of regular imports to add when the replacement is made.",
required = false,
example = "[\"java.util.Objects\"]")
@Nullable
Set<String> imports;

@Option(displayName = "Static imports",
description = "List of static imports to add when the replacement is made.",
required = false,
example = "[\"java.util.Collections.emptyList\"]")
@Nullable
Set<String> staticImports;

@Option(displayName = "Classpath from resources",
description = "List of paths to JAR files on the classpath for parsing the replacement template.",
required = false,
example = "[\"guava-33.4.8-jre\"]")
@Nullable
Set<String> classpathFromResources;

@Override
public String getDisplayName() {
return "Inline methods annotated with `@InlineMe`";
return "Inline method calls";
}

@Override
public String getDescription() {
return "Apply inlinings as defined by Error Prone's [`@InlineMe` annotation](https://errorprone.info/docs/inlineme), " +
"or compatible annotations. Uses the template and method arguments to replace method calls. " +
"Supports both methods invocations and constructor calls, with optional new imports.";
return "Inline method calls using a template replacement pattern. " +
"Supports both method invocations and constructor calls, with optional imports.";
}

@Override
public TreeVisitor<?, ExecutionContext> getVisitor() {
// XXX Preconditions can not yet pick up the `@InlineMe` annotation on methods used
return new JavaVisitor<ExecutionContext>() {
MethodMatcher matcher = new MethodMatcher(methodPattern, true);
return Preconditions.check(new UsesMethod<>(methodPattern), new JavaVisitor<ExecutionContext>() {
@Override
public J visitMethodInvocation(J.MethodInvocation method, ExecutionContext ctx) {
J.MethodInvocation mi = (J.MethodInvocation) super.visitMethodInvocation(method, ctx);
InlineMeValues values = findInlineMeValues(mi.getMethodType());
if (values == null) {
return mi;
if (matcher.matches(method)) {
return replaceMethodCall(method, ctx);
}
Template template = values.template(mi);
if (template == null) {
return mi;
}
removeAndAddImports(method, values.getImports(), values.getStaticImports());
J replacement = JavaTemplate.builder(template.getString())
.contextSensitive()
.imports(values.getImports().toArray(new String[0]))
.staticImports(values.getStaticImports().toArray(new String[0]))
.javaParser(JavaParser.fromJavaVersion().classpath(JavaParser.runtimeClasspath()))
.build()
.apply(updateCursor(mi), mi.getCoordinates().replace(), template.getParameters());
return avoidMethodSelfReferences(mi, replacement);
return super.visitMethodInvocation(method, ctx);
}

@Override
public J visitNewClass(J.NewClass newClass, ExecutionContext ctx) {
J.NewClass nc = (J.NewClass) super.visitNewClass(newClass, ctx);
InlineMeValues values = findInlineMeValues(nc.getConstructorType());
if (values == null) {
return nc;
}
Template template = values.template(nc);
if (template == null) {
return nc;
if (matcher.matches(newClass)) {
return replaceMethodCall(newClass, ctx);
}
removeAndAddImports(newClass, values.getImports(), values.getStaticImports());
J replacement = JavaTemplate.builder(template.getString())
.contextSensitive()
.imports(values.getImports().toArray(new String[0]))
.staticImports(values.getStaticImports().toArray(new String[0]))
.javaParser(JavaParser.fromJavaVersion().classpath(JavaParser.runtimeClasspath()))
.build()
.apply(updateCursor(nc), nc.getCoordinates().replace(), template.getParameters());
return avoidMethodSelfReferences(nc, replacement);
return super.visitNewClass(newClass, ctx);
}

private @Nullable InlineMeValues findInlineMeValues(JavaType.@Nullable Method methodType) {
if (methodType == null) {
return null;
}
List<String> parameterNames = methodType.getParameterNames();
if (!parameterNames.isEmpty() && "arg0".equals(parameterNames.get(0))) {
return null; // We need `-parameters` before we're able to substitute parameters in the template
}

List<JavaType.FullyQualified> annotations = methodType.getAnnotations();
for (JavaType.FullyQualified annotation : annotations) {
if (INLINE_ME.equals(annotation.getClassName())) {
return InlineMeValues.parse((JavaType.Annotation) annotation);
}
}
return null;
private J replaceMethodCall(MethodCall methodCall, ExecutionContext ctx) {
Set<String> importsSet = imports != null ? imports : emptySet();
Set<String> staticImportsSet = staticImports != null ? staticImports : emptySet();
removeAndAddImports(methodCall, importsSet, staticImportsSet);
J applied = applyJavaTemplate(methodCall, getCursor(), importsSet, staticImportsSet, ctx);
return avoidMethodSelfReferences(methodCall, applied);
}

private void removeAndAddImports(MethodCall method, Set<String> templateImports, Set<String> templateStaticImports) {
Expand Down Expand Up @@ -152,10 +141,10 @@ private Set<String> findOriginalImports(MethodCall method) {
// Collect all regular and static imports used in the original method call
return new JavaVisitor<Set<String>>() {
@Override
public @Nullable JavaType visitType(@Nullable JavaType javaType, Set<String> strings) {
JavaType jt = super.visitType(javaType, strings);
public @Nullable JavaType visitType(@Nullable JavaType javaType, Set<String> imports) {
JavaType jt = super.visitType(javaType, imports);
if (jt instanceof JavaType.FullyQualified) {
strings.add(((JavaType.FullyQualified) jt).getFullyQualifiedName());
imports.add(((JavaType.FullyQualified) jt).getFullyQualifiedName());
}
return jt;
}
Expand Down Expand Up @@ -190,6 +179,70 @@ public J visitIdentifier(J.Identifier identifier, Set<String> staticImports) {
}.reduce(method, new HashSet<>());
}

J applyJavaTemplate(MethodCall methodCall, Cursor cursor, Set<String> importsSet, Set<String> staticImportsSet, ExecutionContext ctx) {
JavaType.Method methodType = requireNonNull(methodCall.getMethodType());
String string = createTemplateString(methodCall, methodType);
Object[] parameters = createParameters(string, methodCall);

JavaTemplate.Builder templateBuilder = JavaTemplate.builder(string)
.contextSensitive()
.imports(importsSet.toArray(new String[0]))
.staticImports(staticImportsSet.toArray(new String[0]));
if (classpathFromResources != null && !classpathFromResources.isEmpty()) {
templateBuilder.javaParser(JavaParser.fromJavaVersion()
.classpathFromResources(ctx, classpathFromResources.toArray(new String[0])));
}
return templateBuilder.build()
.apply(cursor, methodCall.getCoordinates().replace(), parameters);
}

private String createTemplateString(MethodCall original, JavaType.Method methodType) {
String templateString;
if (original instanceof J.NewClass && replacement.startsWith("this(")) {
// For constructor-to-constructor replacement, replace "this" with "new ClassName"
templateString = "new " + methodType.getDeclaringType().getClassName() + replacement.substring(4);
} else if (original instanceof J.MethodInvocation &&
((J.MethodInvocation) original).getSelect() == null &&
replacement.startsWith("this.")) {
templateString = replacement.substring(5);
} else {
templateString = replacement.replaceAll("\\bthis\\b", "#{this:any()}");
}
List<String> originalParameterNames = methodType.getParameterNames();
for (String parameterName : originalParameterNames) {
// Replace parameter names with their values in the templateString
templateString = templateString
.replaceFirst(format("\\b%s\\b", parameterName), format("#{%s:any()}", parameterName))
.replaceAll(format("(?<!\\{)\\b%s\\b", parameterName), format("#{%s}", parameterName));
}
return templateString;
}

private Object[] createParameters(String templateString, MethodCall original) {
Map<String, Expression> lookup = new HashMap<>();
if (original instanceof J.MethodInvocation) {
Expression select = ((J.MethodInvocation) original).getSelect();
if (select != null) {
lookup.put("this", select);
}
}
List<String> originalParameterNames = requireNonNull(original.getMethodType()).getParameterNames();
for (int i = 0; i < originalParameterNames.size(); i++) {
String originalName = originalParameterNames.get(i);
Expression originalValue = original.getArguments().get(i);
lookup.put(originalName, originalValue);
}
List<Object> parameters = new ArrayList<>();
Matcher matcher = TEMPLATE_IDENTIFIER.matcher(templateString);
while (matcher.find()) {
Expression o = lookup.get(matcher.group(1));
if (o != null) {
parameters.add(o);
}
}
return parameters.toArray();
}

private J avoidMethodSelfReferences(MethodCall original, J replacement) {
JavaType.Method replacementMethodType = replacement instanceof MethodCall ?
((MethodCall) replacement).getMethodType() : null;
Expand All @@ -215,102 +268,6 @@ private J avoidMethodSelfReferences(MethodCall original, J replacement) {
}
return replacement;
}
};
}

@Value
private static class InlineMeValues {
private static final Pattern TEMPLATE_IDENTIFIER = Pattern.compile("#\\{(\\p{javaJavaIdentifierStart}\\p{javaJavaIdentifierPart}*):any\\(.*?\\)}");

@Getter(AccessLevel.NONE)
String replacement;

Set<String> imports;
Set<String> staticImports;

static InlineMeValues parse(JavaType.Annotation annotation) {
Map<String, Object> collect = annotation.getValues().stream().collect(toMap(
e -> ((JavaType.Method) e.getElement()).getName(),
JavaType.Annotation.ElementValue::getValue
));
// Parse imports and static imports from the annotation values
return new InlineMeValues(
(String) collect.get("replacement"),
parseImports(collect.get("imports")),
parseImports(collect.get("staticImports")));
}

private static Set<String> parseImports(@Nullable Object importsValue) {
if (importsValue instanceof List) {
return ((List<?>) importsValue).stream()
.map(Object::toString)
.collect(toSet());
}
return emptySet();
}

@Nullable
Template template(MethodCall original) {
JavaType.Method methodType = original.getMethodType();
if (methodType == null) {
return null;
}
String templateString = createTemplateString(original, replacement, methodType);
List<Object> parameters = createParameters(templateString, original);
return new Template(templateString, parameters.toArray(new Object[0]));
}

private static String createTemplateString(MethodCall original, String replacement, JavaType.Method methodType) {
String templateString;
if (original instanceof J.NewClass && replacement.startsWith("this(")) {
// For constructor-to-constructor replacement, replace "this" with "new ClassName"
templateString = "new " + methodType.getDeclaringType().getClassName() + replacement.substring(4);
} else if (original instanceof J.MethodInvocation &&
((J.MethodInvocation) original).getSelect() == null &&
replacement.startsWith("this.")) {
templateString = replacement.substring(5);
} else {
templateString = replacement.replaceAll("\\bthis\\b", "#{this:any()}");
}
List<String> originalParameterNames = methodType.getParameterNames();
for (String parameterName : originalParameterNames) {
// Replace parameter names with their values in the templateString
templateString = templateString
.replaceFirst(format("\\b%s\\b", parameterName), format("#{%s:any()}", parameterName))
.replaceAll(format("(?<!\\{)\\b%s\\b", parameterName), format("#{%s}", parameterName));
}
return templateString;
}

private static List<Object> createParameters(String templateString, MethodCall original) {
Map<String, Expression> lookup = new HashMap<>();
if (original instanceof J.MethodInvocation) {
Expression select = ((J.MethodInvocation) original).getSelect();
if (select != null) {
lookup.put("this", select);
}
}
List<String> originalParameterNames = requireNonNull(original.getMethodType()).getParameterNames();
for (int i = 0; i < originalParameterNames.size(); i++) {
String originalName = originalParameterNames.get(i);
Expression originalValue = original.getArguments().get(i);
lookup.put(originalName, originalValue);
}
List<Object> parameters = new ArrayList<>();
Matcher matcher = TEMPLATE_IDENTIFIER.matcher(templateString);
while (matcher.find()) {
Expression o = lookup.get(matcher.group(1));
if (o != null) {
parameters.add(o);
}
}
return parameters;
}
}

@Value
private static class Template {
String string;
Object[] parameters;
});
}
}
Loading