Skip to content

Commit

Permalink
[compiler] Infer return types of function expressions
Browse files Browse the repository at this point in the history
Uses the returnIdentifier added in the previous PR to provide a stable identifier for which we can infer a return type for functions, then wires up the equations in InferTypes to infer the type.

ghstack-source-id: 22c0a9ea096daa5f72821fca2a5ff5b199f65c8b
Pull Request resolved: #30785
  • Loading branch information
josephsavona committed Aug 22, 2024
1 parent 217a0ef commit 8410c8b
Show file tree
Hide file tree
Showing 10 changed files with 74 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ export function lower(
null,
);

const returnIdentifier = builder.makeTemporary(func.node.loc ?? GeneratedSource);
const returnIdentifier = builder.makeTemporary(
func.node.loc ?? GeneratedSource,
);

return Ok({
id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ export function printFunction(fn: HIRFunction): string {
if (definition.length !== 0) {
output.push(definition);
}
output.push(printType(fn.returnIdentifier.type));
output.push(printHIR(fn.body));
output.push(...fn.directives);
return output.join('\n');
Expand Down Expand Up @@ -555,7 +556,10 @@ export function printInstructionValue(instrValue: ReactiveValue): string {
}
})
.join(', ') ?? '';
value = `${kind} ${name} @deps[${deps}] @context[${context}] @effects[${effects}]:\n${fn}`;
const type = printType(
instrValue.loweredFunc.func.returnIdentifier.type,
).trim();
value = `${kind} ${name} @deps[${deps}] @context[${context}] @effects[${effects}]${type !== '' ? ` return${type}` : ''}:\n${fn}`;
break;
}
case 'TaggedTemplateExpression': {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,15 +238,18 @@ function emitSelectorFn(env: Environment, keys: Array<string>): Instruction {
phis: new Set(),
};

const returnIdentifier = createTemporaryPlace(env, GeneratedSource).identifier;
const returnIdentifier = createTemporaryPlace(
env,
GeneratedSource,
).identifier;
const fn: HIRFunction = {
loc: GeneratedSource,
id: null,
fnType: 'Other',
env,
params: [obj],
returnType: null,
returnIdentifier,
returnIdentifier,
context: [],
effects: null,
body: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -481,14 +481,20 @@ function canMergeScopes(
}

function isAlwaysInvalidatingType(type: Type): boolean {
if (type.kind === 'Object') {
switch (type.shapeId) {
case BuiltInArrayId:
case BuiltInObjectId:
case BuiltInFunctionId:
case BuiltInJsxId: {
return true;
switch (type.kind) {
case 'Object': {
switch (type.shapeId) {
case BuiltInArrayId:
case BuiltInObjectId:
case BuiltInFunctionId:
case BuiltInJsxId: {
return true;
}
}
break;
}
case 'Function': {
return true;
}
}
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ function apply(func: HIRFunction, unifier: Unifier): void {
}
}
}
func.returnIdentifier.type = unifier.get(func.returnIdentifier.type);
}

type TypeEquation = {
Expand Down Expand Up @@ -122,6 +123,7 @@ function* generate(
}

const names = new Map();
const returnTypes: Array<Type> = [];
for (const [_, block] of func.body.blocks) {
for (const phi of block.phis) {
yield equation(phi.type, {
Expand All @@ -133,6 +135,18 @@ function* generate(
for (const instr of block.instructions) {
yield* generateInstructionTypes(func.env, names, instr);
}
const terminal = block.terminal;
if (terminal.kind === 'return') {
returnTypes.push(terminal.value.identifier.type);
}
}
if (returnTypes.length > 1) {
yield equation(func.returnIdentifier.type, {
kind: 'Phi',
operands: returnTypes,
});
} else if (returnTypes.length === 1) {
yield equation(func.returnIdentifier.type, returnTypes[0]!);
}
}

Expand Down Expand Up @@ -346,7 +360,11 @@ function* generateInstructionTypes(

case 'FunctionExpression': {
yield* generate(value.loweredFunc.func);
yield equation(left, {kind: 'Object', shapeId: BuiltInFunctionId});
yield equation(left, {
kind: 'Function',
shapeId: BuiltInFunctionId,
return: value.loweredFunc.func.returnIdentifier.type,
});
break;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ component Component() {
if (data != null) {
return true;
} else {
return false;
return {};
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ component Component() {
if (data != null) {
return true;
} else {
return false;
return {};
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,17 @@ export const FIXTURE_ENTRYPOINT = {
import { c as _c } from "react/compiler-runtime";
function hoisting() {
const $ = _c(1);
let t0;
let foo;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
const foo = () => bar + baz;
foo = () => bar + baz;

const bar = 3;
const baz = 2;
t0 = foo();
$[0] = t0;
$[0] = foo;
} else {
t0 = $[0];
foo = $[0];
}
return t0;
return foo();
}

export const FIXTURE_ENTRYPOINT = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,17 @@ export const FIXTURE_ENTRYPOINT = {
import { c as _c } from "react/compiler-runtime";
function hoisting() {
const $ = _c(1);
let t0;
let foo;
if ($[0] === Symbol.for("react.memo_cache_sentinel")) {
const foo = () => bar + baz;
foo = () => bar + baz;

let bar = 3;
let baz = 2;
t0 = foo();
$[0] = t0;
$[0] = foo;
} else {
t0 = $[0];
foo = $[0];
}
return t0;
return foo();
}

export const FIXTURE_ENTRYPOINT = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import { c as _c } from "react/compiler-runtime"; // @validatePreserveExistingMe
import { useCallback } from "react";

function Component(t0) {
const $ = _c(11);
const $ = _c(9);
const { entity, children } = t0;
let t1;
if ($[0] !== entity) {
Expand All @@ -51,46 +51,39 @@ function Component(t0) {
t1 = $[1];
}
const showMessage = t1;

const shouldShowMessage = showMessage();
let t2;
if ($[2] !== showMessage) {
t2 = showMessage();
$[2] = showMessage;
if ($[2] !== shouldShowMessage) {
t2 = <div>{shouldShowMessage}</div>;
$[2] = shouldShowMessage;
$[3] = t2;
} else {
t2 = $[3];
}
const shouldShowMessage = t2;
let t3;
if ($[4] !== shouldShowMessage) {
t3 = <div>{shouldShowMessage}</div>;
$[4] = shouldShowMessage;
if ($[4] !== children) {
t3 = <div>{children}</div>;
$[4] = children;
$[5] = t3;
} else {
t3 = $[5];
}
let t4;
if ($[6] !== children) {
t4 = <div>{children}</div>;
$[6] = children;
$[7] = t4;
} else {
t4 = $[7];
}
let t5;
if ($[8] !== t3 || $[9] !== t4) {
t5 = (
if ($[6] !== t2 || $[7] !== t3) {
t4 = (
<div>
{t2}
{t3}
{t4}
</div>
);
$[8] = t3;
$[9] = t4;
$[10] = t5;
$[6] = t2;
$[7] = t3;
$[8] = t4;
} else {
t5 = $[10];
t4 = $[8];
}
return t5;
return t4;
}

export const FIXTURE_ENTRYPOINT = {
Expand Down

0 comments on commit 8410c8b

Please sign in to comment.