Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 90 additions & 1 deletion lib/core/decision_service/cmab/cmab_service.spec.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,29 @@
import { describe, it, expect, vi, Mocked, Mock, MockInstance, beforeEach, afterEach } from 'vitest';
/**
* Copyright 2025, Optimizely
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { describe, it, expect, vi } from 'vitest';

import { DefaultCmabService } from './cmab_service';
import { getMockSyncCache } from '../../../tests/mock/mock_cache';
import { ProjectConfig } from '../../../project_config/project_config';
import { OptimizelyDecideOption, UserAttributes } from '../../../shared_types';
import OptimizelyUserContext from '../../../optimizely_user_context';
import { validate as uuidValidate } from 'uuid';
import { resolvablePromise } from '../../../utils/promise/resolvablePromise';
import { exhaustMicrotasks } from '../../../tests/testUtils';

const mockProjectConfig = (): ProjectConfig => ({
experimentIdMap: {
Expand Down Expand Up @@ -418,4 +436,75 @@ describe('DefaultCmabService', () => {

expect(mockCmabClient.fetchDecision).toHaveBeenCalledTimes(2);
});

it('should serialize concurrent calls to getDecision with the same userId and ruleId', async () => {
const nCall = 10;
let currentVar = 123;
const fetchPromises = Array.from({ length: nCall }, () => resolvablePromise());

let callCount = 0;
const mockCmabClient = {
fetchDecision: vi.fn().mockImplementation(async () => {
const variation = `${currentVar++}`;
await fetchPromises[callCount++];
return variation;
}),
};

const cmabService = new DefaultCmabService({
cmabCache: getMockSyncCache(),
cmabClient: mockCmabClient,
});

const projectConfig = mockProjectConfig();
const userContext = mockUserContext('user123', {});

const resultPromises = [];
for (let i = 0; i < nCall; i++) {
resultPromises.push(cmabService.getDecision(projectConfig, userContext, '1234', {}));
}

await exhaustMicrotasks();

expect(mockCmabClient.fetchDecision).toHaveBeenCalledTimes(1);

for(let i = 0; i < nCall; i++) {
fetchPromises[i].resolve('');
await exhaustMicrotasks();
const result = await resultPromises[i];
expect(result.variationId).toBe('123');
expect(mockCmabClient.fetchDecision).toHaveBeenCalledTimes(1);
}
});

it('should not serialize calls to getDecision with different userId or ruleId', async () => {
let currentVar = 123;
const mockCmabClient = {
fetchDecision: vi.fn().mockImplementation(() => Promise.resolve(`${currentVar++}`)),
};

const cmabService = new DefaultCmabService({
cmabCache: getMockSyncCache(),
cmabClient: mockCmabClient,
});

const projectConfig = mockProjectConfig();
const userContext1 = mockUserContext('user123', {});
const userContext2 = mockUserContext('user456', {});

const resultPromises = [];
resultPromises.push(cmabService.getDecision(projectConfig, userContext1, '1234', {}));
resultPromises.push(cmabService.getDecision(projectConfig, userContext1, '5678', {}));
resultPromises.push(cmabService.getDecision(projectConfig, userContext2, '1234', {}));
resultPromises.push(cmabService.getDecision(projectConfig, userContext2, '5678', {}));

await exhaustMicrotasks();

expect(mockCmabClient.fetchDecision).toHaveBeenCalledTimes(4);

for(let i = 0; i < resultPromises.length; i++) {
const result = await resultPromises[i];
expect(result.variationId).toBe(`${123 + i}`);
}
});
});
24 changes: 24 additions & 0 deletions lib/core/decision_service/cmab/cmab_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import { CmabClient } from "./cmab_client";
import { v4 as uuidV4 } from 'uuid';
import murmurhash from "murmurhash";
import { DecideOptionsMap } from "..";
import { SerialRunner } from "../../../utils/executor/serial_runner";

export type CmabDecision = {
variationId: string,
Expand Down Expand Up @@ -57,22 +58,45 @@ export type CmabServiceOptions = {
cmabClient: CmabClient;
}

const SERIALIZER_BUCKETS = 1000;

export class DefaultCmabService implements CmabService {
private cmabCache: CacheWithRemove<CmabCacheValue>;
private cmabClient: CmabClient;
private logger?: LoggerFacade;
private serializers: SerialRunner[] = Array.from(
{ length: SERIALIZER_BUCKETS }, () => new SerialRunner()
);

constructor(options: CmabServiceOptions) {
this.cmabCache = options.cmabCache;
this.cmabClient = options.cmabClient;
this.logger = options.logger;
}

private getSerializerIndex(userId: string, experimentId: string): number {
const key = this.getCacheKey(userId, experimentId);
const hash = murmurhash.v3(key);
return Math.abs(hash) % SERIALIZER_BUCKETS;
}

async getDecision(
projectConfig: ProjectConfig,
userContext: IOptimizelyUserContext,
ruleId: string,
options: DecideOptionsMap,
): Promise<CmabDecision> {
const serializerIndex = this.getSerializerIndex(userContext.getUserId(), ruleId);
return this.serializers[serializerIndex].run(() =>
this.getDecisionInternal(projectConfig, userContext, ruleId, options)
);
}

private async getDecisionInternal(
projectConfig: ProjectConfig,
userContext: IOptimizelyUserContext,
ruleId: string,
options: DecideOptionsMap,
): Promise<CmabDecision> {
const filteredAttributes = this.filterAttributes(projectConfig, userContext, ruleId);

Expand Down
195 changes: 195 additions & 0 deletions lib/utils/executor/serial_runner.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
/**
* Copyright 2025, Optimizely
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import { describe, it, expect, vi, beforeEach } from 'vitest';

import { SerialRunner } from './serial_runner';
import { resolvablePromise } from '../promise/resolvablePromise';
import { exhaustMicrotasks } from '../../tests/testUtils';
import { Maybe } from '../type';

describe('SerialRunner', () => {
let serialRunner: SerialRunner;

beforeEach(() => {
serialRunner = new SerialRunner();
});

it('should return result from a single async function', async () => {
const fn = () => Promise.resolve('result');

const result = await serialRunner.run(fn);

expect(result).toBe('result');
});

it('should reject with same error when the passed function rejects', async () => {
const error = new Error('test error');
const fn = () => Promise.reject(error);

await expect(serialRunner.run(fn)).rejects.toThrow(error);
});

it('should execute multiple async functions in order', async () => {
// events to track execution order
// begin_1 means call 1 started
// end_1 means call 1 ended ...
const events: string[] = [];

const nCall = 10;

const promises = Array.from({ length: nCall }, () => resolvablePromise());
const getFn = (i: number) => {
return async (): Promise<number> => {
events.push(`begin_${i}`);
await promises[i];
events.push(`end_${i}`);
return i;
}
}

const resultPromises = [];
for (let i = 0; i < nCall; i++) {
resultPromises.push(serialRunner.run(getFn(i)));
}

await exhaustMicrotasks();

const expectedEvents = ['begin_0'];

expect(events).toEqual(expectedEvents);

for(let i = 0; i < nCall - 1; i++) {
promises[i].resolve('');
await exhaustMicrotasks();

expectedEvents.push(`end_${i}`);
expectedEvents.push(`begin_${i+1}`);

expect(events).toEqual(expectedEvents);
}

promises[nCall - 1].resolve('');
await exhaustMicrotasks();

expectedEvents.push(`end_${nCall - 1}`);
expect(events).toEqual(expectedEvents);

for(let i = 0; i < nCall; i++) {
await expect(resultPromises[i]).resolves.toBe(i);
}
});

it('should continue execution even if one function throws an error', async () => {
const events: string[] = [];

const nCall = 5;
const err = [false, true, false, true, true];

const promises = Array.from({ length: nCall }, () => resolvablePromise());

const getFn = (i: number) => {
return async (): Promise<number> => {
events.push(`begin_${i}`);
let err = false;
try {
await promises[i];
} catch(e) {
err = true;
}

events.push(`end_${i}`);
if (err) {
throw new Error(`error_${i}`);
}
return i;
}
}

const resultPromises = [];
for (let i = 0; i < nCall; i++) {
resultPromises.push(serialRunner.run(getFn(i)));
}

await exhaustMicrotasks();

const expectedEvents = ['begin_0'];

expect(events).toEqual(expectedEvents);

const endFn = (i: number) => {
if (err[i]) {
promises[i].reject(new Error('error'));
} else {
promises[i].resolve('');
}
}

for(let i = 0; i < nCall - 1; i++) {
endFn(i);

await exhaustMicrotasks();

expectedEvents.push(`end_${i}`);
expectedEvents.push(`begin_${i+1}`);

expect(events).toEqual(expectedEvents);
}

endFn(nCall - 1);
await exhaustMicrotasks();

expectedEvents.push(`end_${nCall - 1}`);
expect(events).toEqual(expectedEvents);

for(let i = 0; i < nCall; i++) {
if (err[i]) {
await expect(resultPromises[i]).rejects.toThrow(`error_${i}`);
} else {
await expect(resultPromises[i]).resolves.toBe(i);
}
}
});

it('should handle functions that return different types', async () => {
const numberFn = () => Promise.resolve(42);
const stringFn = () => Promise.resolve('hello');
const objectFn = () => Promise.resolve({ key: 'value' });
const arrayFn = () => Promise.resolve([1, 2, 3]);
const booleanFn = () => Promise.resolve(true);
const nullFn = () => Promise.resolve(null);
const undefinedFn = () => Promise.resolve(undefined);

const results = await Promise.all([
serialRunner.run(numberFn),
serialRunner.run(stringFn),
serialRunner.run(objectFn),
serialRunner.run(arrayFn),
serialRunner.run(booleanFn),
serialRunner.run(nullFn),
serialRunner.run(undefinedFn),
]);

expect(results).toEqual([42, 'hello', { key: 'value' }, [1, 2, 3], true, null, undefined]);
});

it('should handle empty function that returns undefined', async () => {
const emptyFn = () => Promise.resolve(undefined);

const result = await serialRunner.run(emptyFn);

expect(result).toBeUndefined();
});
});
36 changes: 36 additions & 0 deletions lib/utils/executor/serial_runner.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/**
* Copyright 2025, Optimizely
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { AsyncProducer } from "../type";

class SerialRunner {
private waitPromise: Promise<unknown> = Promise.resolve();

// each call to serialize adds a new function to the end of the promise chain
// the function is called when the previous promise resolves
// if the function throws, the error is caught and ignored to allow the chain to continue
// the result of the function is returned as a promise
// if multiple calls to serialize are made, they will be executed in order
// even if some of them throw errors

run<T>(fn: AsyncProducer<T>): Promise<T> {
const resultPromise = this.waitPromise.then(fn);
this.waitPromise = resultPromise.catch(() => {});
return resultPromise;
}
}

export { SerialRunner };
Loading