Skip to content

Commit

Permalink
8322003: JShell - Incorrect type inference in lists of records implem…
Browse files Browse the repository at this point in the history
…enting interfaces

Reviewed-by: vromero
  • Loading branch information
lahodaj committed Jan 8, 2024
1 parent c90768c commit 57a65fe
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,20 @@ public JCExpression parseType(boolean allowVar, List<JCAnnotation> annotations)
return result;
}

protected JCExpression parseIntersectionType(int pos, JCExpression firstType) {
JCExpression t = firstType;
int pos1 = pos;
List<JCExpression> targets = List.of(t);
while (token.kind == AMP) {
accept(AMP);
targets = targets.prepend(parseType());
}
if (targets.length() > 1) {
t = toP(F.at(pos1).TypeIntersection(targets.reverse()));
}
return t;
}

public JCExpression unannotatedType(boolean allowVar) {
return unannotatedType(allowVar, TYPE);
}
Expand Down Expand Up @@ -1337,15 +1351,7 @@ protected JCExpression term3() {
case CAST:
accept(LPAREN);
selectTypeMode();
int pos1 = pos;
List<JCExpression> targets = List.of(t = parseType());
while (token.kind == AMP) {
accept(AMP);
targets = targets.prepend(parseType());
}
if (targets.length() > 1) {
t = toP(F.at(pos1).TypeIntersection(targets.reverse()));
}
t = parseIntersectionType(pos, parseType());
accept(RPAREN);
selectExprMode();
JCExpression t1 = term3();
Expand Down
107 changes: 83 additions & 24 deletions src/jdk.jshell/share/classes/jdk/jshell/TaskFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,13 @@
import com.sun.tools.javac.comp.Enter;
import com.sun.tools.javac.comp.Env;
import com.sun.tools.javac.comp.Resolve;
import com.sun.tools.javac.parser.JavacParser;
import com.sun.tools.javac.parser.Lexer;
import com.sun.tools.javac.parser.Parser;
import com.sun.tools.javac.parser.ParserFactory;
import com.sun.tools.javac.parser.ScannerFactory;
import static com.sun.tools.javac.parser.Tokens.TokenKind.AMP;
import com.sun.tools.javac.tree.JCTree;
import com.sun.tools.javac.tree.JCTree.JCExpression;
import com.sun.tools.javac.tree.JCTree.JCTypeCast;
import com.sun.tools.javac.tree.JCTree.JCVariableDecl;
Expand Down Expand Up @@ -363,7 +368,7 @@ private ParseTask(SourceHandler<String> sh,
JavacTaskImpl task,
DiagnosticCollector<JavaFileObject> diagnostics,
boolean forceExpression) {
super(sh, task, diagnostics);
super(sh, task, diagnostics, false);
ReplParserFactory.preRegister(context, forceExpression);
cuts = parse();
units = Util.stream(cuts)
Expand Down Expand Up @@ -402,7 +407,7 @@ class AnalyzeTask extends BaseTask<OuterWrap> {
private AnalyzeTask(SourceHandler<OuterWrap> sh,
JavacTaskImpl task,
DiagnosticCollector<JavaFileObject> diagnostics) {
super(sh, task, diagnostics);
super(sh, task, diagnostics, true);
cuts = analyze();
}

Expand Down Expand Up @@ -440,7 +445,7 @@ class CompileTask extends BaseTask<OuterWrap> {
CompileTask(SourceHandler<OuterWrap>sh,
JavacTaskImpl jti,
DiagnosticCollector<JavaFileObject> diagnostics) {
super(sh, jti, diagnostics);
super(sh, jti, diagnostics, true);
}

boolean compile() {
Expand Down Expand Up @@ -504,11 +509,15 @@ abstract class BaseTask<S> {

private BaseTask(SourceHandler<S> sh,
JavacTaskImpl task,
DiagnosticCollector<JavaFileObject> diagnostics) {
DiagnosticCollector<JavaFileObject> diagnostics,
boolean analyzeParserFactory) {
this.sourceHandler = sh;
this.task = task;
context = task.getContext();
this.diagnostics = diagnostics;
if (analyzeParserFactory) {
JShellAnalyzeParserFactory.preRegister(context);
}
}

abstract Iterable<? extends CompilationUnitTree> cuTrees();
Expand Down Expand Up @@ -693,7 +702,7 @@ private void setVariableType(VarSnippet s) {
Symtab syms = Symtab.instance(context);
Names names = Names.instance(context);
Log log = Log.instance(context);
ParserFactory parserFactory = ParserFactory.instance(context);
JShellAnalyzeParserFactory parserFactory = (JShellAnalyzeParserFactory) ParserFactory.instance(context);
Attr attr = Attr.instance(context);
Enter enter = Enter.instance(context);
DisableAccessibilityResolve rs = (DisableAccessibilityResolve) Resolve.instance(context);
Expand All @@ -709,26 +718,28 @@ private void setVariableType(VarSnippet s) {
//ignore any errors:
JavaFileObject prev = log.useSource(null);
DiscardDiagnosticHandler h = new DiscardDiagnosticHandler(log);
try {
//parse the type as a cast, i.e. "(<typeName>) x". This is to support
//intersection types:
CharBuffer buf = CharBuffer.wrap(("(" + typeName +")x\u0000").toCharArray(), 0, typeName.length() + 3);
Parser parser = parserFactory.newParser(buf, false, false, false);
JCExpression expr = parser.parseExpression();
if (expr.hasTag(Tag.TYPECAST)) {
//if parsed OK, attribute and set the type:
var2OriginalType.put(field, field.type);

JCTypeCast tree = (JCTypeCast) expr;
rs.runWithoutAccessChecks(() -> {
field.type = attr.attribType(tree.clazz,
enter.getEnvs().iterator().next().enclClass.sym);
});
parserFactory.runPermitIntersectionTypes(() -> {
try {
//parse the type as a cast, i.e. "(<typeName>) x". This is to support
//intersection types:
CharBuffer buf = CharBuffer.wrap(("(" + typeName +")x\u0000").toCharArray(), 0, typeName.length() + 3);
Parser parser = parserFactory.newParser(buf, false, false, false);
JCExpression expr = parser.parseExpression();
if (expr.hasTag(Tag.TYPECAST)) {
//if parsed OK, attribute and set the type:
var2OriginalType.put(field, field.type);

JCTypeCast tree = (JCTypeCast) expr;
rs.runWithoutAccessChecks(() -> {
field.type = attr.attribType(tree.clazz,
enter.getEnvs().iterator().next().enclClass.sym);
});
}
} finally {
log.popDiagnosticHandler(h);
log.useSource(prev);
}
} finally {
log.popDiagnosticHandler(h);
log.useSource(prev);
}
});
}
}
}
Expand Down Expand Up @@ -777,4 +788,52 @@ public boolean isAccessible(Env<AttrContext> env, Type site, Symbol sym, boolean
private static final class Marker {}
}

private static final class JShellAnalyzeParserFactory extends ParserFactory {
public static void preRegister(Context context) {
if (context.get(Marker.class) == null) {
context.put(parserFactoryKey, ((Factory<ParserFactory>) c -> new JShellAnalyzeParserFactory(c)));
context.put(Marker.class, new Marker());
}
}

private final ScannerFactory scannerFactory;
private boolean permitIntersectionTypes;

public JShellAnalyzeParserFactory(Context context) {
super(context);
this.scannerFactory = ScannerFactory.instance(context);
}

/**Run the given Runnable with intersection type permitted.
*
* @param r Runnnable to run
*/
public void runPermitIntersectionTypes(Runnable r) {
boolean prevPermitIntersectionTypes = permitIntersectionTypes;
try {
permitIntersectionTypes = true;
r.run();
} finally {
permitIntersectionTypes = prevPermitIntersectionTypes;
}
}

@Override
public JavacParser newParser(CharSequence input, boolean keepDocComments, boolean keepEndPos, boolean keepLineMap, boolean parseModuleInfo) {
com.sun.tools.javac.parser.Lexer lexer = scannerFactory.newScanner(input, keepDocComments);
return new JavacParser(this, lexer, keepDocComments, keepLineMap, keepEndPos, parseModuleInfo) {
@Override
public JCExpression parseType(boolean allowVar, com.sun.tools.javac.util.List<JCTree.JCAnnotation> annotations) {
int pos = token.pos;
JCExpression t = super.parseType(allowVar, annotations);
if (permitIntersectionTypes) {
t = parseIntersectionType(pos, t);
}
return t;
}
};
}

private static final class Marker {}
}
}
13 changes: 12 additions & 1 deletion test/langtools/jdk/jshell/VariablesTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

/*
* @test
* @bug 8144903 8177466 8191842 8211694 8213725 8239536 8257236 8252409 8294431 8322532
* @bug 8144903 8177466 8191842 8211694 8213725 8239536 8257236 8252409 8294431 8322003 8322532
* @summary Tests for EvaluationState.variables
* @library /tools/lib
* @modules jdk.compiler/com.sun.tools.javac.api
Expand Down Expand Up @@ -627,4 +627,15 @@ public void underscoreAsLambdaParameter() { //JDK-8322532
" int i;", true);
}

public void intersectionTypeAsTypeArgument() { //JDK-8322003
assertEval("interface Shape {}");
assertEval("record Square(int edge) implements Shape {}");
assertEval("record Circle(int radius) implements Shape {}");
assertEval("java.util.function.Consumer<Shape> printShape = System.out::println;");
assertEval("Square square = new Square(1);");
assertEval("Circle circle = new Circle(1);");
assertEval("var shapes = java.util.List.of(square, circle);");
assertEval("shapes.forEach(printShape);");
}

}

1 comment on commit 57a65fe

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.