diff --git a/packages/react-debug-tools/src/ReactDebugHooks.js b/packages/react-debug-tools/src/ReactDebugHooks.js index 1ecc02aa3e652..f7815a36f2893 100644 --- a/packages/react-debug-tools/src/ReactDebugHooks.js +++ b/packages/react-debug-tools/src/ReactDebugHooks.js @@ -37,6 +37,7 @@ import { REACT_CONTEXT_TYPE, } from 'shared/ReactSymbols'; import hasOwnProperty from 'shared/hasOwnProperty'; +import type {ContextDependencyWithCompare} from '../../react-reconciler/src/ReactInternalTypes'; type CurrentDispatcherRef = typeof ReactSharedInternals; @@ -155,7 +156,10 @@ function getPrimitiveStackCache(): Map> { let currentFiber: null | Fiber = null; let currentHook: null | Hook = null; -let currentContextDependency: null | ContextDependency = null; +let currentContextDependency: + | null + | ContextDependency + | ContextDependencyWithCompare = null; function nextHook(): null | Hook { const hook = currentHook; diff --git a/packages/react-reconciler/src/ReactFiberHooks.js b/packages/react-reconciler/src/ReactFiberHooks.js index a2567bd64118e..27c91bf84c41a 100644 --- a/packages/react-reconciler/src/ReactFiberHooks.js +++ b/packages/react-reconciler/src/ReactFiberHooks.js @@ -1060,7 +1060,7 @@ function updateWorkInProgressHook(): Hook { function unstable_useContextWithBailout( context: ReactContext, - compare: void | (T => mixed), + compare: (T => mixed) | null, ): T { return readContextAndCompare(context, compare); } @@ -4049,7 +4049,7 @@ if (__DEV__) { } if (enableContextProfiling) { (HooksDispatcherOnMountInDEV: Dispatcher).unstable_useContextWithBailout = - function (context: ReactContext, compare: void | (T => mixed)): T { + function (context: ReactContext, compare: (T => mixed) | null): T { currentHookNameInDev = 'useContext'; mountHookTypesDev(); return unstable_useContextWithBailout(context, compare); @@ -4238,7 +4238,7 @@ if (__DEV__) { } if (enableContextProfiling) { (HooksDispatcherOnMountWithHookTypesInDEV: Dispatcher).unstable_useContextWithBailout = - function (context: ReactContext, compare: void | (T => mixed)): T { + function (context: ReactContext, compare: (T => mixed) | null): T { currentHookNameInDev = 'useContext'; updateHookTypesDev(); return unstable_useContextWithBailout(context, compare); @@ -4426,7 +4426,7 @@ if (__DEV__) { } if (enableContextProfiling) { (HooksDispatcherOnUpdateInDEV: Dispatcher).unstable_useContextWithBailout = - function (context: ReactContext, compare: void | (T => mixed)): T { + function (context: ReactContext, compare: (T => mixed) | null): T { currentHookNameInDev = 'useContext'; updateHookTypesDev(); return unstable_useContextWithBailout(context, compare); @@ -4614,7 +4614,7 @@ if (__DEV__) { } if (enableContextProfiling) { (HooksDispatcherOnUpdateInDEV: Dispatcher).unstable_useContextWithBailout = - function (context: ReactContext, compare: void | (T => mixed)): T { + function (context: ReactContext, compare: (T => mixed) | null): T { currentHookNameInDev = 'useContext'; updateHookTypesDev(); return unstable_useContextWithBailout(context, compare); @@ -4828,7 +4828,7 @@ if (__DEV__) { } if (enableContextProfiling) { (HooksDispatcherOnUpdateInDEV: Dispatcher).unstable_useContextWithBailout = - function (context: ReactContext, compare: void | (T => mixed)): T { + function (context: ReactContext, compare: (T => mixed) | null): T { currentHookNameInDev = 'useContext'; warnInvalidHookAccess(); mountHookTypesDev(); @@ -5043,7 +5043,7 @@ if (__DEV__) { } if (enableContextProfiling) { (InvalidNestedHooksDispatcherOnUpdateInDEV: Dispatcher).unstable_useContextWithBailout = - function (context: ReactContext, compare: void | (T => mixed)): T { + function (context: ReactContext, compare: (T => mixed) | null): T { currentHookNameInDev = 'useContext'; warnInvalidHookAccess(); updateHookTypesDev(); @@ -5258,7 +5258,7 @@ if (__DEV__) { } if (enableContextProfiling) { (InvalidNestedHooksDispatcherOnRerenderInDEV: Dispatcher).unstable_useContextWithBailout = - function (context: ReactContext, compare: void | (T => mixed)): T { + function (context: ReactContext, compare: (T => mixed) | null): T { currentHookNameInDev = 'useContext'; warnInvalidHookAccess(); updateHookTypesDev(); diff --git a/packages/react-reconciler/src/ReactFiberNewContext.js b/packages/react-reconciler/src/ReactFiberNewContext.js index bec21a08686fe..404295b4e0d65 100644 --- a/packages/react-reconciler/src/ReactFiberNewContext.js +++ b/packages/react-reconciler/src/ReactFiberNewContext.js @@ -12,6 +12,7 @@ import type { Fiber, ContextDependency, Dependencies, + ContextDependencyWithCompare, } from './ReactInternalTypes'; import type {StackCursor} from './ReactFiberStack'; import type {Lanes} from './ReactFiberLane'; @@ -72,7 +73,10 @@ if (__DEV__) { } let currentlyRenderingFiber: Fiber | null = null; -let lastContextDependency: ContextDependency | null = null; +let lastContextDependency: + | ContextDependency + | ContextDependencyWithCompare + | null = null; let lastFullyObservedContext: ReactContext | null = null; let isDisallowedContextReadInDEV: boolean = false; @@ -403,19 +407,21 @@ function propagateContextChanges( const context: ReactContext = contexts[i]; // Check if the context matches. if (dependency.context === context) { - const compare = dependency.compare; - if (enableContextProfiling && compare != null) { - const newValue = isPrimaryRenderer - ? dependency.context._currentValue - : dependency.context._currentValue2; - if ( - !checkIfComparedContextValuesChanged( - dependency.lastComparedValue, - compare(newValue), - ) - ) { - // Compared value hasn't changed. Bail out early. - continue findContext; + if (enableContextProfiling) { + const compare = dependency.compare; + if (compare != null) { + const newValue = isPrimaryRenderer + ? dependency.context._currentValue + : dependency.context._currentValue2; + if ( + !checkIfComparedContextValuesChanged( + dependency.lastComparedValue, + compare(newValue), + ) + ) { + // Compared value hasn't changed. Bail out early. + continue findContext; + } } } // Match! Schedule an update on this fiber. @@ -746,13 +752,17 @@ export function prepareToReadContext( export function readContextAndCompare( context: ReactContext, - compare: void | (C => mixed), + compare: (C => mixed) | null, ): C { if (!enableLazyContextPropagation) { return readContext(context); } - return readContextForConsumer(currentlyRenderingFiber, context, compare); + return readContextForConsumer_withCompare( + currentlyRenderingFiber, + context, + compare, + ); } export function readContext(context: ReactContext): T { @@ -782,12 +792,12 @@ export function readContextDuringReconciliation( return readContextForConsumer(consumer, context); } -type ContextCompare = C => S; +type ContextCompare = C => V | null; -function readContextForConsumer( +function readContextForConsumer_withCompare( consumer: Fiber | null, context: ReactContext, - compare?: void | (C => S), + compare: (C => S) | null, ): C { const value = isPrimaryRenderer ? context._currentValue @@ -800,7 +810,7 @@ function readContextForConsumer( context: ((context: any): ReactContext), memoizedValue: value, next: null, - compare: ((compare: any): ContextCompare | null), + compare: compare ? ((compare: any): ContextCompare) : null, lastComparedValue: compare != null ? compare(value) : null, }; @@ -830,3 +840,47 @@ function readContextForConsumer( } return value; } + +function readContextForConsumer( + consumer: Fiber | null, + context: ReactContext, +): C { + const value = isPrimaryRenderer + ? context._currentValue + : context._currentValue2; + + if (lastFullyObservedContext === context) { + // Nothing to do. We already observe everything in this context. + } else { + const contextItem = { + context: ((context: any): ReactContext), + memoizedValue: value, + next: null, + }; + + if (lastContextDependency === null) { + if (consumer === null) { + throw new Error( + 'Context can only be read while React is rendering. ' + + 'In classes, you can read it in the render method or getDerivedStateFromProps. ' + + 'In function components, you can read it directly in the function body, but not ' + + 'inside Hooks like useReducer() or useMemo().', + ); + } + + // This is the first dependency for this component. Create a new list. + lastContextDependency = contextItem; + consumer.dependencies = { + lanes: NoLanes, + firstContext: contextItem, + }; + if (enableLazyContextPropagation) { + consumer.flags |= NeedsPropagation; + } + } else { + // Append a new context item. + lastContextDependency = lastContextDependency.next = contextItem; + } + } + return value; +} diff --git a/packages/react-reconciler/src/ReactInternalTypes.js b/packages/react-reconciler/src/ReactInternalTypes.js index 67c260d78a0e9..1e033e8395a81 100644 --- a/packages/react-reconciler/src/ReactInternalTypes.js +++ b/packages/react-reconciler/src/ReactInternalTypes.js @@ -61,18 +61,32 @@ export type HookType = | 'useFormState' | 'useActionState'; -export type ContextDependency = { +export type ContextDependency = { context: ReactContext, - next: ContextDependency | null, + next: + | ContextDependency + | ContextDependencyWithCompare + | null, + memoizedValue: C, +}; + +export type ContextDependencyWithCompare = { + context: ReactContext, + next: + | ContextDependency + | ContextDependencyWithCompare + | null, memoizedValue: C, compare: (C => S) | null, - lastComparedValue: S | null, - ... + lastComparedValue?: S | null, }; export type Dependencies = { lanes: Lanes, - firstContext: ContextDependency | null, + firstContext: + | ContextDependency + | ContextDependencyWithCompare + | null, ... }; @@ -388,7 +402,7 @@ export type Dispatcher = { ): [S, Dispatch], unstable_useContextWithBailout?: ( context: ReactContext, - compare: void | (T => mixed), + compare: (T => mixed) | null, ) => T, useContext(context: ReactContext): T, useRef(initialValue: T): {current: T}, diff --git a/packages/react/src/ReactHooks.js b/packages/react/src/ReactHooks.js index 697752ef9698e..f9e243c397519 100644 --- a/packages/react/src/ReactHooks.js +++ b/packages/react/src/ReactHooks.js @@ -71,7 +71,7 @@ export function useContext(Context: ReactContext): T { export function unstable_useContextWithBailout( context: ReactContext, - compare: void | (T => mixed), + compare: (T => mixed) | null, ): T { if (!(enableLazyContextPropagation && enableContextProfiling)) { throw new Error('Not implemented.');