diff --git a/packages/router/src/__tests__/routeScrollReset.test.tsx b/packages/router/src/__tests__/routeScrollReset.test.tsx new file mode 100644 index 000000000000..bf9475f62b0e --- /dev/null +++ b/packages/router/src/__tests__/routeScrollReset.test.tsx @@ -0,0 +1,93 @@ +import React from 'react' + +import '@testing-library/jest-dom/extend-expect' +import { act, cleanup, render, screen } from '@testing-library/react' + +import { navigate } from '../history' +import { Route, Router, routes } from '../router' + +describe('Router scroll reset', () => { + const Page1 = () =>
Page 1
+ const Page2 = () =>
Page 2
+ const TestRouter = () => ( + + + + + ) + + // Redfine the mocks here again (already done in jest.setup) + // Otherwise the mock doesn't clear for some reason + globalThis.scrollTo = jest.fn() + + beforeEach(async () => { + ;(globalThis.scrollTo as jest.Mock).mockClear() + render() + + // Make sure we're starting on the home route + await screen.getByText('Page 1') + }) + + afterEach(async () => { + // @NOTE: for some reason, the Router state does not reset between renders + act(() => navigate('/')) + cleanup() + }) + + it('resets on location/path change', async () => { + act(() => + navigate( + // @ts-expect-error - AvailableRoutes built in project only + routes.page2() + ) + ) + + await screen.getByText('Page 2') + + expect(globalThis.scrollTo).toHaveBeenCalledTimes(1) + }) + + it('resets on location/path and queryChange change', async () => { + act(() => + navigate( + // @ts-expect-error - AvailableRoutes built in project only + routes.page2({ + tab: 'three', + }) + ) + ) + + await screen.getByText('Page 2') + + expect(globalThis.scrollTo).toHaveBeenCalledTimes(1) + }) + + it('resets scroll on query params (search) change on the same page', async () => { + act(() => + // We're staying on page 1, but changing the query params + navigate( + // @ts-expect-error - AvailableRoutes built in project only + routes.page1({ + queryParam1: 'foo', + }) + ) + ) + + await screen.getByText('Page 1') + + expect(globalThis.scrollTo).toHaveBeenCalledTimes(1) + }) + + it('does NOT reset on hash change', async () => { + await screen.getByText('Page 1') + + act(() => + // Stay on page 1, but change the hash + navigate(`#route=66`, { replace: true }) + ) + + await screen.getByText('Page 1') + + expect(globalThis.scrollTo).not.toHaveBeenCalled() + }) +}) diff --git a/packages/router/src/active-route-loader.tsx b/packages/router/src/active-route-loader.tsx index 61e6cf6190eb..cacbee622dd1 100644 --- a/packages/router/src/active-route-loader.tsx +++ b/packages/router/src/active-route-loader.tsx @@ -70,8 +70,6 @@ export const ActiveRouteLoader = ({ return } - globalThis?.scrollTo(0, 0) - if (announcementRef.current) { announcementRef.current.innerText = getAnnouncement() } diff --git a/packages/router/src/location.tsx b/packages/router/src/location.tsx index bccad5cdff37..4719abeaebeb 100644 --- a/packages/router/src/location.tsx +++ b/packages/router/src/location.tsx @@ -11,24 +11,32 @@ export interface LocationContextType { const LocationContext = createNamedContext('Location') +interface Location { + pathname: string + search?: string + hash?: string +} interface LocationProviderProps { - location?: { - pathname: string - search?: string - hash?: string - } + location?: Location trailingSlashes?: TrailingSlashesTypes children?: React.ReactNode } -class LocationProvider extends React.Component { +interface LocationProviderState { + context: Location +} + +class LocationProvider extends React.Component< + LocationProviderProps, + LocationProviderState +> { // When prerendering, there might be more than one level of location // providers. Use the values from the one above. static contextType = LocationContext declare context: React.ContextType HISTORY_LISTENER_ID: string | undefined = undefined - state = { + state: LocationProviderState = { context: this.getContext(), } @@ -82,7 +90,17 @@ class LocationProvider extends React.Component { componentDidMount() { this.HISTORY_LISTENER_ID = gHistory.listen(() => { - this.setState(() => ({ context: this.getContext() })) + const context = this.getContext() + this.setState((lastState) => { + if ( + context.pathname !== lastState.context.pathname || + context.search !== lastState.context.search + ) { + globalThis?.scrollTo(0, 0) + } + + return { context } + }) }) } diff --git a/packages/router/src/util.ts b/packages/router/src/util.ts index ebb2188e2fcf..5b22dcf56e5d 100644 --- a/packages/router/src/util.ts +++ b/packages/router/src/util.ts @@ -122,7 +122,7 @@ type SupportedRouterParamTypes = keyof typeof coreParamTypes * => { match: true, params: {} } */ export function matchPath( - route: string, + routeDefinition: string, pathname: string, { userParamTypes, @@ -138,10 +138,11 @@ export function matchPath( // Get the names and the transform types for the given route. const allParamTypes = { ...coreParamTypes, ...userParamTypes } - const { matchRegex, routeParams } = getRouteRegexAndParams(route, { - matchSubPaths, - allParamTypes, - }) + const { matchRegex, routeParams: routeParamsDefinition } = + getRouteRegexAndParams(routeDefinition, { + matchSubPaths, + allParamTypes, + }) // Does the `pathname` match the route? const matches = [...pathname.matchAll(matchRegex)] @@ -151,10 +152,12 @@ export function matchPath( } // Map extracted values to their param name, casting the value if needed const providedParams = matches[0].slice(1) - if (routeParams.length > 0) { + + // @NOTE: refers to definiton e.g. '/page/{id}', not the actual params + if (routeParamsDefinition.length > 0) { const params = providedParams.reduce>( (acc, value, index) => { - const [name, transformName] = routeParams[index] + const [name, transformName] = routeParamsDefinition[index] const typeInfo = allParamTypes[transformName as SupportedRouterParamTypes]