7
7
8
8
import { CompilerError , ErrorSeverity } from '../CompilerError' ;
9
9
import {
10
+ BlockId ,
10
11
HIRFunction ,
11
- Identifier ,
12
12
IdentifierId ,
13
13
Place ,
14
14
SourceLocation ,
@@ -17,6 +17,7 @@ import {
17
17
isUseRefType ,
18
18
} from '../HIR' ;
19
19
import {
20
+ eachInstructionOperand ,
20
21
eachInstructionValueOperand ,
21
22
eachPatternOperand ,
22
23
eachTerminalOperand ,
@@ -44,11 +45,32 @@ import {Err, Ok, Result} from '../Utils/Result';
44
45
* or based on property name alone (`foo.current` might be a ref).
45
46
*/
46
47
47
- type RefAccessType = { kind : 'None' } | RefAccessRefType ;
48
+ const opaqueRefId = Symbol ( ) ;
49
+ type RefId = number & { [ opaqueRefId ] : 'RefId' } ;
50
+
51
+ function makeRefId ( id : number ) : RefId {
52
+ CompilerError . invariant ( id >= 0 && Number . isInteger ( id ) , {
53
+ reason : 'Expected identifier id to be a non-negative integer' ,
54
+ description : null ,
55
+ loc : null ,
56
+ suggestions : null ,
57
+ } ) ;
58
+ return id as RefId ;
59
+ }
60
+ let _refId = 0 ;
61
+ function nextRefId ( ) : RefId {
62
+ return makeRefId ( _refId ++ ) ;
63
+ }
64
+
65
+ type RefAccessType =
66
+ | { kind : 'None' }
67
+ | { kind : 'Nullable' }
68
+ | { kind : 'Guard' ; refId : RefId }
69
+ | RefAccessRefType ;
48
70
49
71
type RefAccessRefType =
50
- | { kind : 'Ref' }
51
- | { kind : 'RefValue' ; loc ?: SourceLocation }
72
+ | { kind : 'Ref' ; refId : RefId }
73
+ | { kind : 'RefValue' ; loc ?: SourceLocation ; refId ?: RefId }
52
74
| { kind : 'Structure' ; value : null | RefAccessRefType ; fn : null | RefFnType } ;
53
75
54
76
type RefFnType = { readRefEffect : boolean ; returnType : RefAccessType } ;
@@ -82,11 +104,11 @@ export function validateNoRefAccessInRender(fn: HIRFunction): void {
82
104
validateNoRefAccessInRenderImpl ( fn , env ) . unwrap ( ) ;
83
105
}
84
106
85
- function refTypeOfType ( identifier : Identifier ) : RefAccessType {
86
- if ( isRefValueType ( identifier ) ) {
107
+ function refTypeOfType ( place : Place ) : RefAccessType {
108
+ if ( isRefValueType ( place . identifier ) ) {
87
109
return { kind : 'RefValue' } ;
88
- } else if ( isUseRefType ( identifier ) ) {
89
- return { kind : 'Ref' } ;
110
+ } else if ( isUseRefType ( place . identifier ) ) {
111
+ return { kind : 'Ref' , refId : nextRefId ( ) } ;
90
112
} else {
91
113
return { kind : 'None' } ;
92
114
}
@@ -101,6 +123,14 @@ function tyEqual(a: RefAccessType, b: RefAccessType): boolean {
101
123
return true ;
102
124
case 'Ref' :
103
125
return true ;
126
+ case 'Nullable' :
127
+ return true ;
128
+ case 'Guard' :
129
+ CompilerError . invariant ( b . kind === 'Guard' , {
130
+ reason : 'Expected ref value' ,
131
+ loc : null ,
132
+ } ) ;
133
+ return a . refId === b . refId ;
104
134
case 'RefValue' :
105
135
CompilerError . invariant ( b . kind === 'RefValue' , {
106
136
reason : 'Expected ref value' ,
@@ -133,11 +163,17 @@ function joinRefAccessTypes(...types: Array<RefAccessType>): RefAccessType {
133
163
b : RefAccessRefType ,
134
164
) : RefAccessRefType {
135
165
if ( a . kind === 'RefValue' ) {
136
- return a ;
166
+ if ( b . kind === 'RefValue' && a . refId === b . refId ) {
167
+ return a ;
168
+ }
169
+ return { kind : 'RefValue' } ;
137
170
} else if ( b . kind === 'RefValue' ) {
138
171
return b ;
139
172
} else if ( a . kind === 'Ref' || b . kind === 'Ref' ) {
140
- return { kind : 'Ref' } ;
173
+ if ( a . kind === 'Ref' && b . kind === 'Ref' && a . refId === b . refId ) {
174
+ return a ;
175
+ }
176
+ return { kind : 'Ref' , refId : nextRefId ( ) } ;
141
177
} else {
142
178
CompilerError . invariant (
143
179
a . kind === 'Structure' && b . kind === 'Structure' ,
@@ -178,6 +214,16 @@ function joinRefAccessTypes(...types: Array<RefAccessType>): RefAccessType {
178
214
return b ;
179
215
} else if ( b . kind === 'None' ) {
180
216
return a ;
217
+ } else if ( a . kind === 'Guard' || b . kind === 'Guard' ) {
218
+ if ( a . kind === 'Guard' && b . kind === 'Guard' && a . refId === b . refId ) {
219
+ return a ;
220
+ }
221
+ return { kind : 'None' } ;
222
+ } else if ( a . kind === 'Nullable' || b . kind === 'Nullable' ) {
223
+ if ( a . kind === 'Nullable' && b . kind === 'Nullable' ) {
224
+ return a ;
225
+ }
226
+ return { kind : 'None' } ;
181
227
} else {
182
228
return joinRefAccessRefTypes ( a , b ) ;
183
229
}
@@ -198,13 +244,14 @@ function validateNoRefAccessInRenderImpl(
198
244
} else {
199
245
place = param . place ;
200
246
}
201
- const type = refTypeOfType ( place . identifier ) ;
247
+ const type = refTypeOfType ( place ) ;
202
248
env . set ( place . identifier . id , type ) ;
203
249
}
204
250
205
251
for ( let i = 0 ; ( i == 0 || env . hasChanged ( ) ) && i < 10 ; i ++ ) {
206
252
env . resetChanged ( ) ;
207
253
returnValues = [ ] ;
254
+ const safeBlocks = new Map < BlockId , RefId > ( ) ;
208
255
const errors = new CompilerError ( ) ;
209
256
for ( const [ , block ] of fn . body . blocks ) {
210
257
for ( const phi of block . phis ) {
@@ -238,11 +285,15 @@ function validateNoRefAccessInRenderImpl(
238
285
if ( objType ?. kind === 'Structure' ) {
239
286
lookupType = objType . value ;
240
287
} else if ( objType ?. kind === 'Ref' ) {
241
- lookupType = { kind : 'RefValue' , loc : instr . loc } ;
288
+ lookupType = {
289
+ kind : 'RefValue' ,
290
+ loc : instr . loc ,
291
+ refId : objType . refId ,
292
+ } ;
242
293
}
243
294
env . set (
244
295
instr . lvalue . identifier . id ,
245
- lookupType ?? refTypeOfType ( instr . lvalue . identifier ) ,
296
+ lookupType ?? refTypeOfType ( instr . lvalue ) ,
246
297
) ;
247
298
break ;
248
299
}
@@ -251,7 +302,7 @@ function validateNoRefAccessInRenderImpl(
251
302
env . set (
252
303
instr . lvalue . identifier . id ,
253
304
env . get ( instr . value . place . identifier . id ) ??
254
- refTypeOfType ( instr . lvalue . identifier ) ,
305
+ refTypeOfType ( instr . lvalue ) ,
255
306
) ;
256
307
break ;
257
308
}
@@ -260,12 +311,12 @@ function validateNoRefAccessInRenderImpl(
260
311
env . set (
261
312
instr . value . lvalue . place . identifier . id ,
262
313
env . get ( instr . value . value . identifier . id ) ??
263
- refTypeOfType ( instr . value . lvalue . place . identifier ) ,
314
+ refTypeOfType ( instr . value . lvalue . place ) ,
264
315
) ;
265
316
env . set (
266
317
instr . lvalue . identifier . id ,
267
318
env . get ( instr . value . value . identifier . id ) ??
268
- refTypeOfType ( instr . lvalue . identifier ) ,
319
+ refTypeOfType ( instr . lvalue ) ,
269
320
) ;
270
321
break ;
271
322
}
@@ -277,13 +328,10 @@ function validateNoRefAccessInRenderImpl(
277
328
}
278
329
env . set (
279
330
instr . lvalue . identifier . id ,
280
- lookupType ?? refTypeOfType ( instr . lvalue . identifier ) ,
331
+ lookupType ?? refTypeOfType ( instr . lvalue ) ,
281
332
) ;
282
333
for ( const lval of eachPatternOperand ( instr . value . lvalue . pattern ) ) {
283
- env . set (
284
- lval . identifier . id ,
285
- lookupType ?? refTypeOfType ( lval . identifier ) ,
286
- ) ;
334
+ env . set ( lval . identifier . id , lookupType ?? refTypeOfType ( lval ) ) ;
287
335
}
288
336
break ;
289
337
}
@@ -354,7 +402,11 @@ function validateNoRefAccessInRenderImpl(
354
402
types . push ( env . get ( operand . identifier . id ) ?? { kind : 'None' } ) ;
355
403
}
356
404
const value = joinRefAccessTypes ( ...types ) ;
357
- if ( value . kind === 'None' ) {
405
+ if (
406
+ value . kind === 'None' ||
407
+ value . kind === 'Guard' ||
408
+ value . kind === 'Nullable'
409
+ ) {
358
410
env . set ( instr . lvalue . identifier . id , { kind : 'None' } ) ;
359
411
} else {
360
412
env . set ( instr . lvalue . identifier . id , {
@@ -369,7 +421,18 @@ function validateNoRefAccessInRenderImpl(
369
421
case 'PropertyStore' :
370
422
case 'ComputedDelete' :
371
423
case 'ComputedStore' : {
372
- validateNoRefAccess ( errors , env , instr . value . object , instr . loc ) ;
424
+ const safe = safeBlocks . get ( block . id ) ;
425
+ const target = env . get ( instr . value . object . identifier . id ) ;
426
+ if (
427
+ instr . value . kind === 'PropertyStore' &&
428
+ safe != null &&
429
+ target ?. kind === 'Ref' &&
430
+ target . refId === safe
431
+ ) {
432
+ safeBlocks . delete ( block . id ) ;
433
+ } else {
434
+ validateNoRefAccess ( errors , env , instr . value . object , instr . loc ) ;
435
+ }
373
436
for ( const operand of eachInstructionValueOperand ( instr . value ) ) {
374
437
if ( operand === instr . value . object ) {
375
438
continue ;
@@ -381,23 +444,67 @@ function validateNoRefAccessInRenderImpl(
381
444
case 'StartMemoize' :
382
445
case 'FinishMemoize' :
383
446
break ;
447
+ case 'Primitive' : {
448
+ if ( instr . value . value == null ) {
449
+ env . set ( instr . lvalue . identifier . id , { kind : 'Nullable' } ) ;
450
+ }
451
+ break ;
452
+ }
453
+ case 'BinaryExpression' : {
454
+ const left = env . get ( instr . value . left . identifier . id ) ;
455
+ const right = env . get ( instr . value . right . identifier . id ) ;
456
+ let nullish : boolean = false ;
457
+ let refId : RefId | null = null ;
458
+ if ( left ?. kind === 'RefValue' && left . refId != null ) {
459
+ refId = left . refId ;
460
+ } else if ( right ?. kind === 'RefValue' && right . refId != null ) {
461
+ refId = right . refId ;
462
+ }
463
+
464
+ if ( left ?. kind === 'Nullable' ) {
465
+ nullish = true ;
466
+ } else if ( right ?. kind === 'Nullable' ) {
467
+ nullish = true ;
468
+ }
469
+
470
+ if ( refId !== null && nullish ) {
471
+ env . set ( instr . lvalue . identifier . id , { kind : 'Guard' , refId} ) ;
472
+ } else {
473
+ for ( const operand of eachInstructionValueOperand ( instr . value ) ) {
474
+ validateNoRefValueAccess ( errors , env , operand ) ;
475
+ }
476
+ }
477
+ break ;
478
+ }
384
479
default : {
385
480
for ( const operand of eachInstructionValueOperand ( instr . value ) ) {
386
481
validateNoRefValueAccess ( errors , env , operand ) ;
387
482
}
388
483
break ;
389
484
}
390
485
}
391
- if ( isUseRefType ( instr . lvalue . identifier ) ) {
486
+
487
+ // Guard values are derived from ref.current, so they can only be used in if statement targets
488
+ for ( const operand of eachInstructionOperand ( instr ) ) {
489
+ guardCheck ( errors , operand , env ) ;
490
+ }
491
+
492
+ if (
493
+ isUseRefType ( instr . lvalue . identifier ) &&
494
+ env . get ( instr . lvalue . identifier . id ) ?. kind !== 'Ref'
495
+ ) {
392
496
env . set (
393
497
instr . lvalue . identifier . id ,
394
498
joinRefAccessTypes (
395
499
env . get ( instr . lvalue . identifier . id ) ?? { kind : 'None' } ,
396
- { kind : 'Ref' } ,
500
+ { kind : 'Ref' , refId : nextRefId ( ) } ,
397
501
) ,
398
502
) ;
399
503
}
400
- if ( isRefValueType ( instr . lvalue . identifier ) ) {
504
+ if (
505
+ isRefValueType ( instr . lvalue . identifier ) &&
506
+ env . get ( instr . lvalue . identifier . id ) ?. kind !== 'RefValue'
507
+ ) {
401
508
env . set (
402
509
instr . lvalue . identifier . id ,
403
510
joinRefAccessTypes (
@@ -407,12 +514,24 @@ function validateNoRefAccessInRenderImpl(
407
514
) ;
408
515
}
409
516
}
517
+
518
+ if ( block . terminal . kind === 'if' ) {
519
+ const test = env . get ( block . terminal . test . identifier . id ) ;
520
+ if ( test ?. kind === 'Guard' ) {
521
+ safeBlocks . set ( block . terminal . consequent , test . refId ) ;
522
+ }
523
+ }
524
+
410
525
for ( const operand of eachTerminalOperand ( block . terminal ) ) {
411
526
if ( block . terminal . kind !== 'return' ) {
412
527
validateNoRefValueAccess ( errors , env , operand ) ;
528
+ if ( block . terminal . kind !== 'if' ) {
529
+ guardCheck ( errors , operand , env ) ;
530
+ }
413
531
} else {
414
532
// Allow functions containing refs to be returned, but not direct ref values
415
533
validateNoDirectRefValueAccess ( errors , operand , env ) ;
534
+ guardCheck ( errors , operand , env ) ;
416
535
returnValues . push ( env . get ( operand . identifier . id ) ) ;
417
536
}
418
537
}
@@ -444,6 +563,23 @@ function destructure(
444
563
return type ;
445
564
}
446
565
566
+ function guardCheck ( errors : CompilerError , operand : Place , env : Env ) : void {
567
+ if ( env . get ( operand . identifier . id ) ?. kind === 'Guard' ) {
568
+ errors . push ( {
569
+ severity : ErrorSeverity . InvalidReact ,
570
+ reason :
571
+ 'Ref values (the `current` property) may not be accessed during render. (https://react.dev/reference/react/useRef)' ,
572
+ loc : operand . loc ,
573
+ description :
574
+ operand . identifier . name !== null &&
575
+ operand . identifier . name . kind === 'named'
576
+ ? `Cannot access ref value \`${ operand . identifier . name . value } \``
577
+ : null ,
578
+ suggestions : null ,
579
+ } ) ;
580
+ }
581
+ }
582
+
447
583
function validateNoRefValueAccess (
448
584
errors : CompilerError ,
449
585
env : Env ,
0 commit comments