Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/shield-controller/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Fixed and optimized shield-coverage-result polling with Cockatiel Policy from Controller-utils. ([#6847](https://github.com/MetaMask/core/pull/6847))

## [1.0.0]

### Added
Expand Down
4 changes: 3 additions & 1 deletion packages/shield-controller/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@
},
"dependencies": {
"@metamask/base-controller": "^9.0.0",
"@metamask/controller-utils": "^11.14.1",
"@metamask/messenger": "^0.3.0",
"@metamask/utils": "^11.8.1"
"@metamask/utils": "^11.8.1",
"cockatiel": "^3.1.2"
},
"devDependencies": {
"@babel/runtime": "^7.23.9",
Expand Down
36 changes: 33 additions & 3 deletions packages/shield-controller/src/backend.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ describe('ShieldRemoteBackend', () => {
expect(getAccessToken).toHaveBeenCalledTimes(1);
});

it('should throw on check coverage timeout', async () => {
it('should throw on check coverage timeout with coverage status', async () => {
const { backend, fetchMock } = setup({
getCoverageResultTimeout: 0,
getCoverageResultPollInterval: 0,
Expand All @@ -144,12 +144,42 @@ describe('ShieldRemoteBackend', () => {
// Mock get coverage result: result unavailable.
fetchMock.mockResolvedValue({
status: 404,
json: jest.fn().mockResolvedValue({ status: 'unavailable' }),
} as unknown as Response);

const txMeta = generateMockTxMeta();
await expect(backend.checkCoverage({ txMeta })).rejects.toThrow(
'Timeout waiting for coverage result',
'Failed to get coverage result: 404',
);

// Waiting here ensures coverage of the unexpected error and lets us know
// that the polling loop is exited as expected.
await new Promise((resolve) => setTimeout(resolve, 10));
});

it('should throw on check coverage timeout', async () => {
const { backend, fetchMock } = setup({
getCoverageResultTimeout: 0,
getCoverageResultPollInterval: 0,
});

// Mock init coverage check.
fetchMock.mockResolvedValueOnce({
status: 200,
json: jest.fn().mockResolvedValue({ coverageId: 'coverageId' }),
} as unknown as Response);

// Mock get coverage result: result unavailable.
fetchMock.mockResolvedValue({
status: 412,
json: jest.fn().mockResolvedValue({
message: 'Results are not available yet',
statusCode: 412,
}),
} as unknown as Response);

const txMeta = generateMockTxMeta();
await expect(backend.checkCoverage({ txMeta })).rejects.toThrow(
'Failed to get coverage result: Results are not available yet',
);

// Waiting here ensures coverage of the unexpected error and lets us know
Expand Down
145 changes: 85 additions & 60 deletions packages/shield-controller/src/backend.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import {
ConstantBackoff,
DEFAULT_MAX_RETRIES,
HttpError,
} from '@metamask/controller-utils';
import {
EthMethod,
SignatureRequestType,
Expand All @@ -7,6 +12,7 @@ import type { TransactionMeta } from '@metamask/transaction-controller';
import type { Json } from '@metamask/utils';

import { SignTypedDataVersion } from './constants';
import { PollingWithCockatielPolicy } from './polling-with-policy';
import type {
CheckCoverageRequest,
CheckSignatureCoverageRequest,
Expand Down Expand Up @@ -56,14 +62,12 @@ export type GetCoverageResultResponse = {
export class ShieldRemoteBackend implements ShieldBackend {
readonly #getAccessToken: () => Promise<string>;

readonly #getCoverageResultTimeout: number;

readonly #getCoverageResultPollInterval: number;

readonly #baseUrl: string;

readonly #fetch: typeof globalThis.fetch;

readonly #pollingPolicy: PollingWithCockatielPolicy;

constructor({
getAccessToken,
getCoverageResultTimeout = 5000, // milliseconds
Expand All @@ -78,10 +82,18 @@ export class ShieldRemoteBackend implements ShieldBackend {
fetch: typeof globalThis.fetch;
}) {
this.#getAccessToken = getAccessToken;
this.#getCoverageResultTimeout = getCoverageResultTimeout;
this.#getCoverageResultPollInterval = getCoverageResultPollInterval;
this.#baseUrl = baseUrl;
this.#fetch = fetchFn;

const { backoff, maxRetries } = computePollingIntervalAndRetryCount(
getCoverageResultTimeout,
getCoverageResultPollInterval,
);

this.#pollingPolicy = new PollingWithCockatielPolicy({
backoff,
maxRetries,
});
}

async checkCoverage(req: CheckCoverageRequest): Promise<CoverageResult> {
Expand All @@ -95,9 +107,11 @@ export class ShieldRemoteBackend implements ShieldBackend {
}

const txCoverageResultUrl = `${this.#baseUrl}/v1/transaction/coverage/result`;
const coverageResult = await this.#getCoverageResult(coverageId, {
coverageResultUrl: txCoverageResultUrl,
});
const coverageResult = await this.#getCoverageResult(
req.txMeta.id,
coverageId,
txCoverageResultUrl,
);
return {
coverageId,
message: coverageResult.message,
Expand All @@ -119,9 +133,11 @@ export class ShieldRemoteBackend implements ShieldBackend {
}

const signatureCoverageResultUrl = `${this.#baseUrl}/v1/signature/coverage/result`;
const coverageResult = await this.#getCoverageResult(coverageId, {
coverageResultUrl: signatureCoverageResultUrl,
});
const coverageResult = await this.#getCoverageResult(
req.signatureRequest.id,
coverageId,
signatureCoverageResultUrl,
);
return {
coverageId,
message: coverageResult.message,
Expand All @@ -138,6 +154,9 @@ export class ShieldRemoteBackend implements ShieldBackend {
...initBody,
};

// cancel the pending get coverage result request
this.#pollingPolicy.abortPendingRequest(req.signatureRequest.id);

const res = await this.#fetch(
`${this.#baseUrl}/v1/signature/coverage/log`,
{
Expand All @@ -159,6 +178,9 @@ export class ShieldRemoteBackend implements ShieldBackend {
...initBody,
};

// cancel the pending get coverage result request
this.#pollingPolicy.abortPendingRequest(req.txMeta.id);

const res = await this.#fetch(
`${this.#baseUrl}/v1/transaction/coverage/log`,
{
Expand Down Expand Up @@ -188,51 +210,39 @@ export class ShieldRemoteBackend implements ShieldBackend {
}

async #getCoverageResult(
requestId: string,
coverageId: string,
configs: {
coverageResultUrl: string;
timeout?: number;
pollInterval?: number;
},
coverageResultUrl: string,
): Promise<GetCoverageResultResponse> {
const reqBody: GetCoverageResultRequest = {
coverageId,
};

const timeout = configs?.timeout ?? this.#getCoverageResultTimeout;
const pollInterval =
configs?.pollInterval ?? this.#getCoverageResultPollInterval;

const headers = await this.#createHeaders();
return await new Promise((resolve, reject) => {
let timeoutReached = false;
setTimeout(() => {
timeoutReached = true;
reject(new Error('Timeout waiting for coverage result'));
}, timeout);

const poll = async (): Promise<GetCoverageResultResponse> => {
// The timeoutReached variable is modified in the timeout callback.
// eslint-disable-next-line no-unmodified-loop-condition
while (!timeoutReached) {
const startTime = Date.now();
const res = await this.#fetch(configs.coverageResultUrl, {
method: 'POST',
headers,
body: JSON.stringify(reqBody),
});
if (res.status === 200) {
return (await res.json()) as GetCoverageResultResponse;
}
await sleep(pollInterval - (Date.now() - startTime));
}
// The following line will not have an effect as the upper level promise
// will already be rejected by now.
throw new Error('unexpected error');
};

poll().then(resolve).catch(reject);
});

const getCoverageResultFn = async (signal: AbortSignal) => {
const res = await this.#fetch(coverageResultUrl, {
method: 'POST',
headers,
body: JSON.stringify(reqBody),
signal,
});
if (res.status === 200) {
return (await res.json()) as GetCoverageResultResponse;
}

// parse the error message from the response body
let errorMessage = 'Timeout waiting for coverage result';
try {
const errorJson = await res.json();
errorMessage = `Failed to get coverage result: ${errorJson.message || errorJson.status}`;
} catch {
errorMessage = `Failed to get coverage result: ${res.status}`;
}
throw new HttpError(res.status, errorMessage);
};

return this.#pollingPolicy.start(requestId, getCoverageResultFn);
}

async #createHeaders() {
Expand All @@ -244,16 +254,6 @@ export class ShieldRemoteBackend implements ShieldBackend {
}
}

/**
* Sleep for a specified amount of time.
*
* @param ms - The number of milliseconds to sleep.
* @returns A promise that resolves after the specified amount of time.
*/
async function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}

/**
* Make the body for the init coverage check request.
*
Expand Down Expand Up @@ -324,3 +324,28 @@ export function parseSignatureRequestMethod(

return signatureRequest.type;
}

/**
* Compute the polling interval and retry count for the Cockatiel policy based on the timeout and poll interval given.
*
* @param timeout - The timeout in milliseconds.
* @param pollInterval - The poll interval in milliseconds.
* @returns The polling interval and retry count.
*/
function computePollingIntervalAndRetryCount(
timeout: number,
pollInterval: number,
) {
const backoff = new ConstantBackoff(pollInterval);
const computedMaxRetries = Math.floor(timeout / pollInterval) + 1;
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Polling Timeout Exceeded Due to Retry Calculation Error

The maxRetries calculation in computePollingIntervalAndRetryCount is off by one. It currently calculates total attempts (Math.floor(timeout / pollInterval) + 1), but the Cockatiel policy expects the number of retries (attempts - 1). This leads to an extra retry, causing polling to exceed the intended timeout.

Fix in Cursor Fix in Web

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intended.


const maxRetries =
isNaN(computedMaxRetries) || !isFinite(computedMaxRetries)
? DEFAULT_MAX_RETRIES
: computedMaxRetries;

return {
backoff,
maxRetries,
};
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Polling Timeout Exceeded Due to Retry Counting Error

The maxRetries calculation in computePollingIntervalAndRetryCount has an off-by-one error. It effectively counts the initial attempt as a retry, leading to one extra poll and causing the total polling duration to exceed the specified timeout.

Fix in Cursor Fix in Web

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Edge Case Handling in Retry Calculation

The maxRetries calculation in computePollingIntervalAndRetryCount has edge cases when pollInterval is 0 or when both timeout and pollInterval are 0. Division by zero results in NaN or Infinity, causing the function to incorrectly fall back to DEFAULT_MAX_RETRIES. This prevents respecting intended behaviors like immediate polling or immediate failure, leading to unexpected retry counts.

Fix in Cursor Fix in Web

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: Polling Interval Miscalculation

The computePollingIntervalAndRetryCount function calculates maxRetries as Math.floor(timeout / pollInterval) + 1. This adds an extra retry, causing the total polling duration to exceed the specified timeout. This can lead to longer-than-expected waits, particularly with small poll intervals.

Fix in Cursor Fix in Web

Loading
Loading