Skip to content

Commit 6868cbe

Browse files
vmiuraabhipatel12
andauthored
fix(a2a): Don't mutate 'replace' tool args in scheduleToolCalls (google-gemini#7369)
Co-authored-by: Abhi <43648792+abhipatel12@users.noreply.github.com>
1 parent f2bddfe commit 6868cbe

File tree

4 files changed

+121
-36
lines changed

4 files changed

+121
-36
lines changed

packages/a2a-server/src/agent.test.ts

Lines changed: 4 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import {
3333
assertTaskCreationAndWorkingStatus,
3434
createStreamMessageRequest,
3535
MockTool,
36+
createMockConfig,
3637
} from './testing_utils.js';
3738

3839
const mockToolConfirmationFn = async () =>
@@ -68,26 +69,11 @@ vi.mock('./config.js', async () => {
6869
return {
6970
...actual,
7071
loadConfig: vi.fn().mockImplementation(async () => {
71-
config = {
72+
const mockConfig = createMockConfig({
7273
getToolRegistry: getToolRegistrySpy,
7374
getApprovalMode: getApprovalModeSpy,
74-
getIdeMode: vi.fn().mockReturnValue(false),
75-
getAllowedTools: vi.fn().mockReturnValue([]),
76-
getIdeClient: vi.fn(),
77-
getWorkspaceContext: vi.fn().mockReturnValue({
78-
isPathWithinWorkspace: () => true,
79-
}),
80-
getTargetDir: () => '/test',
81-
getGeminiClient: vi.fn(),
82-
getDebugMode: vi.fn().mockReturnValue(false),
83-
getContentGeneratorConfig: vi
84-
.fn()
85-
.mockReturnValue({ model: 'gemini-pro' }),
86-
getModel: vi.fn().mockReturnValue('gemini-pro'),
87-
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
88-
setFlashFallbackHandler: vi.fn(),
89-
initialize: vi.fn().mockResolvedValue(undefined),
90-
} as unknown as Config;
75+
});
76+
config = mockConfig as Config;
9177
return config;
9278
}),
9379
};
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/**
2+
* @license
3+
* Copyright 2025 Google LLC
4+
* SPDX-License-Identifier: Apache-2.0
5+
*/
6+
7+
import { describe, it, expect, vi } from 'vitest';
8+
import { Task } from './task.js';
9+
import type { Config, ToolCallRequestInfo } from '@google/gemini-cli-core';
10+
import { createMockConfig } from './testing_utils.js';
11+
import type { ExecutionEventBus } from '@a2a-js/sdk/server';
12+
13+
describe('Task', () => {
14+
it('scheduleToolCalls should not modify the input requests array', async () => {
15+
const mockConfig = createMockConfig();
16+
17+
const mockEventBus: ExecutionEventBus = {
18+
publish: vi.fn(),
19+
on: vi.fn(),
20+
off: vi.fn(),
21+
once: vi.fn(),
22+
removeAllListeners: vi.fn(),
23+
finished: vi.fn(),
24+
};
25+
26+
// The Task constructor is private. We'll bypass it for this unit test.
27+
// @ts-expect-error - Calling private constructor for test purposes.
28+
const task = new Task(
29+
'task-id',
30+
'context-id',
31+
mockConfig as Config,
32+
mockEventBus,
33+
);
34+
35+
task['setTaskStateAndPublishUpdate'] = vi.fn();
36+
task['getProposedContent'] = vi.fn().mockResolvedValue('new content');
37+
38+
const requests: ToolCallRequestInfo[] = [
39+
{
40+
callId: '1',
41+
name: 'replace',
42+
args: {
43+
file_path: 'test.txt',
44+
old_string: 'old',
45+
new_string: 'new',
46+
},
47+
isClientInitiated: false,
48+
prompt_id: 'prompt-id-1',
49+
},
50+
];
51+
52+
const originalRequests = JSON.parse(JSON.stringify(requests));
53+
const abortController = new AbortController();
54+
55+
await task.scheduleToolCalls(requests, abortController.signal);
56+
57+
expect(requests).toEqual(originalRequests);
58+
});
59+
});

packages/a2a-server/src/task.ts

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -520,30 +520,36 @@ export class Task {
520520
return;
521521
}
522522

523-
for (const request of requests) {
524-
if (
525-
!request.args['newContent'] &&
526-
request.name === 'replace' &&
527-
request.args &&
528-
request.args['file_path'] &&
529-
request.args['old_string'] &&
530-
request.args['new_string']
531-
) {
532-
request.args['newContent'] = await this.getProposedContent(
533-
request.args['file_path'] as string,
534-
request.args['old_string'] as string,
535-
request.args['new_string'] as string,
536-
);
537-
}
538-
}
523+
const updatedRequests = await Promise.all(
524+
requests.map(async (request) => {
525+
if (
526+
request.name === 'replace' &&
527+
request.args &&
528+
!request.args['newContent'] &&
529+
request.args['file_path'] &&
530+
request.args['old_string'] &&
531+
request.args['new_string']
532+
) {
533+
const newContent = await this.getProposedContent(
534+
request.args['file_path'] as string,
535+
request.args['old_string'] as string,
536+
request.args['new_string'] as string,
537+
);
538+
return { ...request, args: { ...request.args, newContent } };
539+
}
540+
return request;
541+
}),
542+
);
539543

540-
logger.info(`[Task] Scheduling batch of ${requests.length} tool calls.`);
544+
logger.info(
545+
`[Task] Scheduling batch of ${updatedRequests.length} tool calls.`,
546+
);
541547
const stateChange: StateChange = {
542548
kind: CoderAgentEvent.StateChangeEvent,
543549
};
544550
this.setTaskStateAndPublishUpdate('working', stateChange);
545551

546-
await this.scheduler.schedule(requests, abortSignal);
552+
await this.scheduler.schedule(updatedRequests, abortSignal);
547553
}
548554

549555
async acceptAgentMessage(event: ServerGeminiStreamEvent): Promise<void> {

packages/a2a-server/src/testing_utils.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,52 @@ import type {
99
TaskStatusUpdateEvent,
1010
SendStreamingMessageSuccessResponse,
1111
} from '@a2a-js/sdk';
12+
import { ApprovalMode } from '@google/gemini-cli-core';
1213
import {
1314
BaseDeclarativeTool,
1415
BaseToolInvocation,
1516
Kind,
1617
} from '@google/gemini-cli-core';
1718
import type {
19+
Config,
1820
ToolCallConfirmationDetails,
1921
ToolResult,
2022
ToolInvocation,
2123
} from '@google/gemini-cli-core';
2224
import { expect, vi } from 'vitest';
2325

26+
export function createMockConfig(
27+
overrides: Partial<Config> = {},
28+
): Partial<Config> {
29+
const mockConfig = {
30+
getToolRegistry: vi.fn().mockReturnValue({
31+
getTool: vi.fn(),
32+
getAllToolNames: vi.fn().mockReturnValue([]),
33+
}),
34+
getApprovalMode: vi.fn().mockReturnValue(ApprovalMode.DEFAULT),
35+
getIdeMode: vi.fn().mockReturnValue(false),
36+
getAllowedTools: vi.fn().mockReturnValue([]),
37+
getIdeClient: vi.fn(),
38+
getWorkspaceContext: vi.fn().mockReturnValue({
39+
isPathWithinWorkspace: () => true,
40+
}),
41+
getTargetDir: () => '/test',
42+
getGeminiClient: vi.fn(),
43+
getDebugMode: vi.fn().mockReturnValue(false),
44+
getContentGeneratorConfig: vi.fn().mockReturnValue({ model: 'gemini-pro' }),
45+
getModel: vi.fn().mockReturnValue('gemini-pro'),
46+
getUsageStatisticsEnabled: vi.fn().mockReturnValue(false),
47+
setFlashFallbackHandler: vi.fn(),
48+
initialize: vi.fn().mockResolvedValue(undefined),
49+
getProxy: vi.fn().mockReturnValue(undefined),
50+
getHistory: vi.fn().mockReturnValue([]),
51+
getEmbeddingModel: vi.fn().mockReturnValue('text-embedding-004'),
52+
getSessionId: vi.fn().mockReturnValue('test-session-id'),
53+
...overrides,
54+
};
55+
return mockConfig;
56+
}
57+
2458
export const mockOnUserConfirmForToolConfirmation = vi.fn();
2559

2660
export class MockToolInvocation extends BaseToolInvocation<object, ToolResult> {

0 commit comments

Comments
 (0)