Skip to content

Commit 9e071b6

Browse files
yaron2WhitWaldo
andauthored
Add A2A Dapr Task Store (#1593)
Signed-off-by: yaron2 <schneider.yaron@live.com> Co-authored-by: Whit Waldo <whit.waldo@innovian.net>
1 parent ae4b509 commit 9e071b6

File tree

3 files changed

+287
-0
lines changed

3 files changed

+287
-0
lines changed

Directory.Packages.props

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
<CentralPackageTransitivePinningEnabled>true</CentralPackageTransitivePinningEnabled>
55
</PropertyGroup>
66
<ItemGroup>
7+
<PackageVersion Include="A2A" Version="0.1.0-preview.2" />
78
<PackageVersion Include="BenchmarkDotNet" Version="0.15.2" />
89
<PackageVersion Include="coverlet.collector" Version="6.0.4" />
910
<PackageVersion Include="coverlet.msbuild" Version="6.0.4" />
@@ -48,6 +49,7 @@
4849
<PackageVersion Include="Serilog.Sinks.File" Version="7.0.0" />
4950
<PackageVersion Include="Shouldly" Version="4.3.0" />
5051
<PackageVersion Include="System.Formats.Asn1" Version="9.0.6" />
52+
<PackageVersion Include="System.Net.ServerSentEvents" Version="10.0.0-preview.6.25358.103" />
5153
<PackageVersion Include="System.Text.Json" Version="9.0.6" />
5254
<PackageVersion Include="xunit" Version="2.9.3" />
5355
<PackageVersion Include="xunit.extensibility.core" Version="2.9.3" />

src/Dapr.AI.A2A/Dapr.AI.A2A.csproj

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<ImplicitUsings>enable</ImplicitUsings>
5+
<Nullable>enable</Nullable>
6+
<PackageId>Dapr.AI.A2a</PackageId>
7+
<Title>Dapr AI Agent to Agent SDK</Title>
8+
<Description>Dapr SDK for implementing agent-to-agent operations.</Description>
9+
<VersionSuffix>alpha</VersionSuffix>
10+
</PropertyGroup>
11+
12+
<PropertyGroup>
13+
<SignAssembly>false</SignAssembly>
14+
</PropertyGroup>
15+
16+
<ItemGroup>
17+
<FrameworkReference Include="Microsoft.AspNetCore.App" />
18+
</ItemGroup>
19+
20+
<ItemGroup>
21+
<PackageReference Include="A2A" />
22+
<PackageReference Include="Google.Protobuf" />
23+
<PackageReference Include="System.Net.ServerSentEvents" />
24+
</ItemGroup>
25+
26+
<ItemGroup>
27+
<ProjectReference Include="..\Dapr.Common\Dapr.Common.csproj" />
28+
<ProjectReference Include="..\Dapr.Protos\Dapr.Protos.csproj" />
29+
<ProjectReference Include="..\Dapr.Client\Dapr.Client.csproj" />
30+
</ItemGroup>
31+
32+
</Project>

src/Dapr.AI.A2A/DaprTaskStore.cs

Lines changed: 253 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,253 @@
1+
using A2A;
2+
using Dapr.Client;
3+
using System.Collections.Concurrent;
4+
5+
namespace Dapr.AI.A2A;
6+
7+
/// <summary>
8+
/// Represents a Dapr-backed state store implementation for the A2A Dotnet ITaskStore interface, allowing agents to keep persistent state.
9+
/// </summary>
10+
public class DaprTaskStore : ITaskStore
11+
{
12+
private readonly DaprClient _daprClient;
13+
private readonly string _stateStoreName;
14+
15+
private static string PushCfgKey(string taskId, string configId) => $"a2a:task:{taskId}:pushcfg:{configId}";
16+
private static string PushCfgIndexKey(string taskId) => $"a2a:task:{taskId}:pushcfg:index";
17+
18+
// Dapr state operation settings: strong consistency, with default concurrency (override per operation as needed)
19+
private static readonly StateOptions StrongConsistency = new StateOptions
20+
{
21+
Consistency = ConsistencyMode.Strong // Ensure reads/writes use strong consistency
22+
};
23+
24+
/// <summary>
25+
/// Constructor for the Task Store.
26+
/// </summary>
27+
/// <param name="daprClient">A Dapr Client insance</param>
28+
/// <param name="stateStoreName">The name of the state store component to use</param>
29+
public DaprTaskStore(DaprClient daprClient, string stateStoreName = "statestore")
30+
{
31+
_daprClient = daprClient ?? throw new ArgumentNullException(nameof(daprClient));
32+
_stateStoreName = stateStoreName;
33+
}
34+
35+
/// <summary>
36+
/// Retrieves a task by its ID.
37+
/// </summary>
38+
/// <param name="taskId">The ID of the task to retrieve.</param>
39+
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
40+
/// <returns>The task if found, null otherwise.</returns>
41+
public async Task<AgentTask?> GetTaskAsync(string taskId, CancellationToken cancellationToken = default)
42+
{
43+
if (taskId == null) throw new ArgumentNullException(nameof(taskId));
44+
45+
// Retrieve the AgentTask from Dapr state store with strong consistency to get the latest data
46+
AgentTask? task = await _daprClient.GetStateAsync<AgentTask>(
47+
_stateStoreName,
48+
key: taskId,
49+
consistencyMode: ConsistencyMode.Strong,
50+
metadata: null,
51+
cancellationToken: cancellationToken);
52+
return task;
53+
}
54+
55+
/// <summary>
56+
/// Stores or updates a task.
57+
/// </summary>
58+
/// <param name="task">The task to store.</param>
59+
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
60+
/// <returns>A task representing the operation.</returns>
61+
public async Task SetTaskAsync(AgentTask task, CancellationToken cancellationToken = default)
62+
{
63+
if (task == null) throw new ArgumentNullException(nameof(task));
64+
// The task.Id will be used as the key. We save the entire AgentTask object.
65+
// Use strong consistency on write; concurrency defaults to last-write-wins for new entries.
66+
await _daprClient.SaveStateAsync(
67+
_stateStoreName,
68+
key: task.Id,
69+
value: task,
70+
stateOptions: StrongConsistency, // strong consistency ensures durability before ack
71+
metadata: null,
72+
cancellationToken: cancellationToken
73+
);
74+
// Note: If the task already existed, this will overwrite it (last-write-wins behavior since no ETag used).
75+
}
76+
77+
/// <summary>
78+
/// Updates the status of a task.
79+
/// </summary>
80+
/// <param name="taskId">The ID of the task.</param>
81+
/// <param name="status">The new status.</param>
82+
/// <param name="message">Optional message associated with the status.</param>
83+
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
84+
/// <returns>The updated task status.</returns>
85+
public async Task<AgentTaskStatus> UpdateStatusAsync(string taskId, TaskState status, Message? message = null, CancellationToken cancellationToken = default)
86+
{
87+
if (taskId == null) throw new ArgumentNullException(nameof(taskId));
88+
// Fetch state with its ETag for concurrency control.
89+
// We use strong consistency to get the latest state and ETag.
90+
var (existingTask, etag) = await _daprClient.GetStateAndETagAsync<AgentTask>(
91+
_stateStoreName,
92+
key: taskId,
93+
consistencyMode: ConsistencyMode.Strong,
94+
metadata: null,
95+
cancellationToken: cancellationToken);
96+
if (existingTask == null)
97+
{
98+
throw new KeyNotFoundException($"Task with ID '{taskId}' not found.");
99+
}
100+
101+
// Update the status field of the retrieved task object.
102+
var st = existingTask.Status;
103+
st.State = status;
104+
if (message != null)
105+
{
106+
st.Message = message;
107+
}
108+
109+
existingTask.Status = st;
110+
111+
// Attempt to save the updated task back with the ETag for optimistic concurrency.
112+
var stateOptions = new StateOptions
113+
{
114+
Consistency = ConsistencyMode.Strong,
115+
Concurrency = ConcurrencyMode.FirstWrite // enable first-write-wins
116+
};
117+
bool saved = await _daprClient.TrySaveStateAsync(
118+
_stateStoreName,
119+
key: taskId,
120+
value: existingTask,
121+
etag: etag, // use ETag to ensure no concurrent modification
122+
stateOptions: stateOptions,
123+
metadata: null,
124+
cancellationToken: cancellationToken
125+
);
126+
if (!saved)
127+
{
128+
// The save failed due to an ETag mismatch (concurrent update happened).
129+
throw new InvalidOperationException($"Concurrent update detected for task '{taskId}'. Update was not saved.");
130+
}
131+
132+
return existingTask.Status;
133+
}
134+
135+
/// <summary>
136+
/// Retrieves push notification configuration for a task.
137+
/// </summary>
138+
/// <param name="taskId">The ID of the task.</param>
139+
/// <param name="notificationConfigId">The ID of the push notification configuration.</param>
140+
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
141+
/// <returns>The push notification configuration if found, null otherwise.</returns>
142+
public async Task<TaskPushNotificationConfig?> GetPushNotificationAsync(string taskId, string notificationConfigId, CancellationToken cancellationToken = default)
143+
{
144+
if (string.IsNullOrWhiteSpace(taskId)) throw new ArgumentNullException(nameof(taskId));
145+
if (string.IsNullOrWhiteSpace(notificationConfigId)) throw new ArgumentNullException(nameof(notificationConfigId));
146+
147+
return await _daprClient.GetStateAsync<TaskPushNotificationConfig>(
148+
_stateStoreName,
149+
key: PushCfgKey(taskId, notificationConfigId),
150+
consistencyMode: ConsistencyMode.Strong,
151+
metadata: null,
152+
cancellationToken: cancellationToken);
153+
}
154+
155+
/// <summary>
156+
/// Stores push notification configuration for a task.
157+
/// </summary>
158+
/// <param name="pushNotificationConfig">The push notification configuration.</param>
159+
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
160+
/// <returns>A task representing the operation.</returns>
161+
public async Task SetPushNotificationConfigAsync(TaskPushNotificationConfig pushNotificationConfig, CancellationToken cancellationToken = default)
162+
{
163+
if (pushNotificationConfig is null) throw new ArgumentNullException(nameof(pushNotificationConfig));
164+
165+
// Adjust these property names if your model differs:
166+
var taskId = pushNotificationConfig.TaskId ?? throw new ArgumentException("Config.TaskId is required.");
167+
var configId = pushNotificationConfig.PushNotificationConfig.Id ?? throw new ArgumentException("Config.Id is required.");
168+
169+
//Save/Upsert the config itself
170+
await _daprClient.SaveStateAsync(
171+
_stateStoreName,
172+
key: PushCfgKey(taskId, configId),
173+
value: pushNotificationConfig,
174+
stateOptions: StrongConsistency,
175+
metadata: null,
176+
cancellationToken: cancellationToken);
177+
178+
// Add the configId to the per-task index with ETag (avoid races)
179+
for (var attempt = 0; attempt < 5; attempt++)
180+
{
181+
var (index, etag) = await _daprClient.GetStateAndETagAsync<string[]>(
182+
_stateStoreName,
183+
key: PushCfgIndexKey(taskId),
184+
consistencyMode: ConsistencyMode.Strong,
185+
metadata: null,
186+
cancellationToken: cancellationToken);
187+
188+
var list = (index ?? Array.Empty<string>()).ToList();
189+
if (!list.Contains(configId, StringComparer.Ordinal))
190+
list.Add(configId);
191+
192+
var ok = await _daprClient.TrySaveStateAsync(
193+
_stateStoreName,
194+
key: PushCfgIndexKey(taskId),
195+
value: list.ToArray(),
196+
etag: etag,
197+
stateOptions: new StateOptions { Consistency = ConsistencyMode.Strong, Concurrency = ConcurrencyMode.FirstWrite },
198+
metadata: null,
199+
cancellationToken: cancellationToken);
200+
201+
if (ok) break;
202+
203+
// small backoff before retry
204+
await Task.Delay(50 * (attempt + 1), cancellationToken);
205+
}
206+
}
207+
208+
/// <summary>
209+
/// Retrieves push notification configuration for a task.
210+
/// </summary>
211+
/// <param name="taskId">The ID of the task.</param>
212+
/// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
213+
/// <returns>The push notification configuration if found, null otherwise.</returns>
214+
public async Task<IEnumerable<TaskPushNotificationConfig>> GetPushNotificationsAsync(string taskId, CancellationToken cancellationToken = default)
215+
{
216+
if (string.IsNullOrWhiteSpace(taskId)) throw new ArgumentNullException(nameof(taskId));
217+
218+
var ids = await _daprClient.GetStateAsync<string[]>(
219+
_stateStoreName,
220+
key: PushCfgIndexKey(taskId),
221+
consistencyMode: ConsistencyMode.Strong,
222+
metadata: null,
223+
cancellationToken: cancellationToken) ?? Array.Empty<string>();
224+
225+
if (ids.Length == 0) return Array.Empty<TaskPushNotificationConfig>();
226+
227+
const int maxParallel = 8;
228+
using var gate = new SemaphoreSlim(maxParallel);
229+
var bag = new ConcurrentBag<TaskPushNotificationConfig>();
230+
231+
await Task.WhenAll(ids.Select(async id =>
232+
{
233+
await gate.WaitAsync(cancellationToken);
234+
try
235+
{
236+
var cfg = await _daprClient.GetStateAsync<TaskPushNotificationConfig>(
237+
_stateStoreName,
238+
key: PushCfgKey(taskId, id),
239+
consistencyMode: ConsistencyMode.Strong,
240+
metadata: null,
241+
cancellationToken: cancellationToken);
242+
243+
if (cfg is not null) bag.Add(cfg);
244+
}
245+
finally
246+
{
247+
gate.Release();
248+
}
249+
}));
250+
251+
return bag.ToArray();
252+
}
253+
}

0 commit comments

Comments
 (0)