Skip to content

Commit a05b570

Browse files
committed
[compiler] Allow refs to be lazily initialized during render
Summary: The official guidance for useRef notes an exception to the rule that refs cannot be accessed during render: to avoid recreating the ref's contents, you can test that the ref is uninitialized and then initialize it using an if statement: ``` if (ref.current == null) { ref.current = SomeExpensiveOperation() } ``` The compiler didn't recognize this exception, however, leading to code that obeyed all the official guidance for refs being rejected by the compiler. This PR fixes that, by extending the ref validation machinery with an awareness of guard operations that allow lazy initialization. We now understand `== null` and similar operations, when applied to a ref and consumed by an if terminal, as marking the consequent of the if as a block in which the ref can be safely written to. In order to do so we need to create a notion of ref ids, which link different usages of the same ref via both the ref and the ref value. [ghstack-poisoned]
1 parent 7b7fac0 commit a05b570

19 files changed

+628
-26
lines changed

compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoRefAccesInRender.ts

Lines changed: 162 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import {CompilerError, ErrorSeverity} from '../CompilerError';
99
import {
10+
BlockId,
1011
HIRFunction,
11-
Identifier,
1212
IdentifierId,
1313
Place,
1414
SourceLocation,
@@ -17,6 +17,7 @@ import {
1717
isUseRefType,
1818
} from '../HIR';
1919
import {
20+
eachInstructionOperand,
2021
eachInstructionValueOperand,
2122
eachPatternOperand,
2223
eachTerminalOperand,
@@ -44,11 +45,32 @@ import {Err, Ok, Result} from '../Utils/Result';
4445
* or based on property name alone (`foo.current` might be a ref).
4546
*/
4647

47-
type RefAccessType = {kind: 'None'} | RefAccessRefType;
48+
const opaqueRefId = Symbol();
49+
type RefId = number & {[opaqueRefId]: 'RefId'};
50+
51+
function makeRefId(id: number): RefId {
52+
CompilerError.invariant(id >= 0 && Number.isInteger(id), {
53+
reason: 'Expected identifier id to be a non-negative integer',
54+
description: null,
55+
loc: null,
56+
suggestions: null,
57+
});
58+
return id as RefId;
59+
}
60+
let _refId = 0;
61+
function nextRefId(): RefId {
62+
return makeRefId(_refId++);
63+
}
64+
65+
type RefAccessType =
66+
| {kind: 'None'}
67+
| {kind: 'Nullable'}
68+
| {kind: 'Guard'; refId: RefId}
69+
| RefAccessRefType;
4870

4971
type RefAccessRefType =
50-
| {kind: 'Ref'}
51-
| {kind: 'RefValue'; loc?: SourceLocation}
72+
| {kind: 'Ref'; refId: RefId}
73+
| {kind: 'RefValue'; loc?: SourceLocation; refId?: RefId}
5274
| {kind: 'Structure'; value: null | RefAccessRefType; fn: null | RefFnType};
5375

5476
type RefFnType = {readRefEffect: boolean; returnType: RefAccessType};
@@ -82,11 +104,11 @@ export function validateNoRefAccessInRender(fn: HIRFunction): void {
82104
validateNoRefAccessInRenderImpl(fn, env).unwrap();
83105
}
84106

85-
function refTypeOfType(identifier: Identifier): RefAccessType {
86-
if (isRefValueType(identifier)) {
107+
function refTypeOfType(place: Place): RefAccessType {
108+
if (isRefValueType(place.identifier)) {
87109
return {kind: 'RefValue'};
88-
} else if (isUseRefType(identifier)) {
89-
return {kind: 'Ref'};
110+
} else if (isUseRefType(place.identifier)) {
111+
return {kind: 'Ref', refId: nextRefId()};
90112
} else {
91113
return {kind: 'None'};
92114
}
@@ -101,6 +123,14 @@ function tyEqual(a: RefAccessType, b: RefAccessType): boolean {
101123
return true;
102124
case 'Ref':
103125
return true;
126+
case 'Nullable':
127+
return true;
128+
case 'Guard':
129+
CompilerError.invariant(b.kind === 'Guard', {
130+
reason: 'Expected ref value',
131+
loc: null,
132+
});
133+
return a.refId === b.refId;
104134
case 'RefValue':
105135
CompilerError.invariant(b.kind === 'RefValue', {
106136
reason: 'Expected ref value',
@@ -133,11 +163,17 @@ function joinRefAccessTypes(...types: Array<RefAccessType>): RefAccessType {
133163
b: RefAccessRefType,
134164
): RefAccessRefType {
135165
if (a.kind === 'RefValue') {
136-
return a;
166+
if (b.kind === 'RefValue' && a.refId === b.refId) {
167+
return a;
168+
}
169+
return {kind: 'RefValue'};
137170
} else if (b.kind === 'RefValue') {
138171
return b;
139172
} else if (a.kind === 'Ref' || b.kind === 'Ref') {
140-
return {kind: 'Ref'};
173+
if (a.kind === 'Ref' && b.kind === 'Ref' && a.refId === b.refId) {
174+
return a;
175+
}
176+
return {kind: 'Ref', refId: nextRefId()};
141177
} else {
142178
CompilerError.invariant(
143179
a.kind === 'Structure' && b.kind === 'Structure',
@@ -178,6 +214,16 @@ function joinRefAccessTypes(...types: Array<RefAccessType>): RefAccessType {
178214
return b;
179215
} else if (b.kind === 'None') {
180216
return a;
217+
} else if (a.kind === 'Guard' || b.kind === 'Guard') {
218+
if (a.kind === 'Guard' && b.kind === 'Guard' && a.refId === b.refId) {
219+
return a;
220+
}
221+
return {kind: 'None'};
222+
} else if (a.kind === 'Nullable' || b.kind === 'Nullable') {
223+
if (a.kind === 'Nullable' && b.kind === 'Nullable') {
224+
return a;
225+
}
226+
return {kind: 'None'};
181227
} else {
182228
return joinRefAccessRefTypes(a, b);
183229
}
@@ -198,13 +244,14 @@ function validateNoRefAccessInRenderImpl(
198244
} else {
199245
place = param.place;
200246
}
201-
const type = refTypeOfType(place.identifier);
247+
const type = refTypeOfType(place);
202248
env.set(place.identifier.id, type);
203249
}
204250

205251
for (let i = 0; (i == 0 || env.hasChanged()) && i < 10; i++) {
206252
env.resetChanged();
207253
returnValues = [];
254+
const safeBlocks = new Map<BlockId, RefId>();
208255
const errors = new CompilerError();
209256
for (const [, block] of fn.body.blocks) {
210257
for (const phi of block.phis) {
@@ -238,11 +285,15 @@ function validateNoRefAccessInRenderImpl(
238285
if (objType?.kind === 'Structure') {
239286
lookupType = objType.value;
240287
} else if (objType?.kind === 'Ref') {
241-
lookupType = {kind: 'RefValue', loc: instr.loc};
288+
lookupType = {
289+
kind: 'RefValue',
290+
loc: instr.loc,
291+
refId: objType.refId,
292+
};
242293
}
243294
env.set(
244295
instr.lvalue.identifier.id,
245-
lookupType ?? refTypeOfType(instr.lvalue.identifier),
296+
lookupType ?? refTypeOfType(instr.lvalue),
246297
);
247298
break;
248299
}
@@ -251,7 +302,7 @@ function validateNoRefAccessInRenderImpl(
251302
env.set(
252303
instr.lvalue.identifier.id,
253304
env.get(instr.value.place.identifier.id) ??
254-
refTypeOfType(instr.lvalue.identifier),
305+
refTypeOfType(instr.lvalue),
255306
);
256307
break;
257308
}
@@ -260,12 +311,12 @@ function validateNoRefAccessInRenderImpl(
260311
env.set(
261312
instr.value.lvalue.place.identifier.id,
262313
env.get(instr.value.value.identifier.id) ??
263-
refTypeOfType(instr.value.lvalue.place.identifier),
314+
refTypeOfType(instr.value.lvalue.place),
264315
);
265316
env.set(
266317
instr.lvalue.identifier.id,
267318
env.get(instr.value.value.identifier.id) ??
268-
refTypeOfType(instr.lvalue.identifier),
319+
refTypeOfType(instr.lvalue),
269320
);
270321
break;
271322
}
@@ -277,13 +328,10 @@ function validateNoRefAccessInRenderImpl(
277328
}
278329
env.set(
279330
instr.lvalue.identifier.id,
280-
lookupType ?? refTypeOfType(instr.lvalue.identifier),
331+
lookupType ?? refTypeOfType(instr.lvalue),
281332
);
282333
for (const lval of eachPatternOperand(instr.value.lvalue.pattern)) {
283-
env.set(
284-
lval.identifier.id,
285-
lookupType ?? refTypeOfType(lval.identifier),
286-
);
334+
env.set(lval.identifier.id, lookupType ?? refTypeOfType(lval));
287335
}
288336
break;
289337
}
@@ -354,7 +402,11 @@ function validateNoRefAccessInRenderImpl(
354402
types.push(env.get(operand.identifier.id) ?? {kind: 'None'});
355403
}
356404
const value = joinRefAccessTypes(...types);
357-
if (value.kind === 'None') {
405+
if (
406+
value.kind === 'None' ||
407+
value.kind === 'Guard' ||
408+
value.kind === 'Nullable'
409+
) {
358410
env.set(instr.lvalue.identifier.id, {kind: 'None'});
359411
} else {
360412
env.set(instr.lvalue.identifier.id, {
@@ -369,7 +421,18 @@ function validateNoRefAccessInRenderImpl(
369421
case 'PropertyStore':
370422
case 'ComputedDelete':
371423
case 'ComputedStore': {
372-
validateNoRefAccess(errors, env, instr.value.object, instr.loc);
424+
const safe = safeBlocks.get(block.id);
425+
const target = env.get(instr.value.object.identifier.id);
426+
if (
427+
instr.value.kind === 'PropertyStore' &&
428+
safe != null &&
429+
target?.kind === 'Ref' &&
430+
target.refId === safe
431+
) {
432+
safeBlocks.delete(block.id);
433+
} else {
434+
validateNoRefAccess(errors, env, instr.value.object, instr.loc);
435+
}
373436
for (const operand of eachInstructionValueOperand(instr.value)) {
374437
if (operand === instr.value.object) {
375438
continue;
@@ -381,23 +444,67 @@ function validateNoRefAccessInRenderImpl(
381444
case 'StartMemoize':
382445
case 'FinishMemoize':
383446
break;
447+
case 'Primitive': {
448+
if (instr.value.value == null) {
449+
env.set(instr.lvalue.identifier.id, {kind: 'Nullable'});
450+
}
451+
break;
452+
}
453+
case 'BinaryExpression': {
454+
const left = env.get(instr.value.left.identifier.id);
455+
const right = env.get(instr.value.right.identifier.id);
456+
let nullish: boolean = false;
457+
let refId: RefId | null = null;
458+
if (left?.kind === 'RefValue' && left.refId != null) {
459+
refId = left.refId;
460+
} else if (right?.kind === 'RefValue' && right.refId != null) {
461+
refId = right.refId;
462+
}
463+
464+
if (left?.kind === 'Nullable') {
465+
nullish = true;
466+
} else if (right?.kind === 'Nullable') {
467+
nullish = true;
468+
}
469+
470+
if (refId !== null && nullish) {
471+
env.set(instr.lvalue.identifier.id, {kind: 'Guard', refId});
472+
} else {
473+
for (const operand of eachInstructionValueOperand(instr.value)) {
474+
validateNoRefValueAccess(errors, env, operand);
475+
}
476+
}
477+
break;
478+
}
384479
default: {
385480
for (const operand of eachInstructionValueOperand(instr.value)) {
386481
validateNoRefValueAccess(errors, env, operand);
387482
}
388483
break;
389484
}
390485
}
391-
if (isUseRefType(instr.lvalue.identifier)) {
486+
487+
// Guard values are derived from ref.current, so they can only be used in if statement targets
488+
for (const operand of eachInstructionOperand(instr)) {
489+
guardCheck(errors, operand, env);
490+
}
491+
492+
if (
493+
isUseRefType(instr.lvalue.identifier) &&
494+
env.get(instr.lvalue.identifier.id)?.kind !== 'Ref'
495+
) {
392496
env.set(
393497
instr.lvalue.identifier.id,
394498
joinRefAccessTypes(
395499
env.get(instr.lvalue.identifier.id) ?? {kind: 'None'},
396-
{kind: 'Ref'},
500+
{kind: 'Ref', refId: nextRefId()},
397501
),
398502
);
399503
}
400-
if (isRefValueType(instr.lvalue.identifier)) {
504+
if (
505+
isRefValueType(instr.lvalue.identifier) &&
506+
env.get(instr.lvalue.identifier.id)?.kind !== 'RefValue'
507+
) {
401508
env.set(
402509
instr.lvalue.identifier.id,
403510
joinRefAccessTypes(
@@ -407,12 +514,24 @@ function validateNoRefAccessInRenderImpl(
407514
);
408515
}
409516
}
517+
518+
if (block.terminal.kind === 'if') {
519+
const test = env.get(block.terminal.test.identifier.id);
520+
if (test?.kind === 'Guard') {
521+
safeBlocks.set(block.terminal.consequent, test.refId);
522+
}
523+
}
524+
410525
for (const operand of eachTerminalOperand(block.terminal)) {
411526
if (block.terminal.kind !== 'return') {
412527
validateNoRefValueAccess(errors, env, operand);
528+
if (block.terminal.kind !== 'if') {
529+
guardCheck(errors, operand, env);
530+
}
413531
} else {
414532
// Allow functions containing refs to be returned, but not direct ref values
415533
validateNoDirectRefValueAccess(errors, operand, env);
534+
guardCheck(errors, operand, env);
416535
returnValues.push(env.get(operand.identifier.id));
417536
}
418537
}
@@ -444,6 +563,23 @@ function destructure(
444563
return type;
445564
}
446565

566+
function guardCheck(errors: CompilerError, operand: Place, env: Env): void {
567+
if (env.get(operand.identifier.id)?.kind === 'Guard') {
568+
errors.push({
569+
severity: ErrorSeverity.InvalidReact,
570+
reason:
571+
'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)',
572+
loc: operand.loc,
573+
description:
574+
operand.identifier.name !== null &&
575+
operand.identifier.name.kind === 'named'
576+
? `Cannot access ref value \`${operand.identifier.name.value}\``
577+
: null,
578+
suggestions: null,
579+
});
580+
}
581+
}
582+
447583
function validateNoRefValueAccess(
448584
errors: CompilerError,
449585
env: Env,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
2+
## Input
3+
4+
```javascript
5+
//@flow
6+
import {useRef} from 'react';
7+
8+
component C() {
9+
const r = useRef(null);
10+
if (r.current == null) {
11+
r.current = 1;
12+
}
13+
}
14+
15+
export const FIXTURE_ENTRYPOINT = {
16+
fn: C,
17+
params: [{}],
18+
};
19+
20+
```
21+
22+
## Code
23+
24+
```javascript
25+
import { useRef } from "react";
26+
27+
function C() {
28+
const r = useRef(null);
29+
if (r.current == null) {
30+
r.current = 1;
31+
}
32+
}
33+
34+
export const FIXTURE_ENTRYPOINT = {
35+
fn: C,
36+
params: [{}],
37+
};
38+
39+
```
40+
41+
### Eval output
42+
(kind: ok)

0 commit comments

Comments
 (0)