1+ using A2A ;
2+ using Dapr . Client ;
3+ using System . Collections . Concurrent ;
4+
5+ namespace Dapr . AI . A2A ;
6+
7+ /// <summary>
8+ /// Represents a Dapr-backed state store implementation for the A2A Dotnet ITaskStore interface, allowing agents to keep persistent state.
9+ /// </summary>
10+ public class DaprTaskStore : ITaskStore
11+ {
12+ private readonly DaprClient _daprClient ;
13+ private readonly string _stateStoreName ;
14+
15+ private static string PushCfgKey ( string taskId , string configId ) => $ "a2a:task:{ taskId } :pushcfg:{ configId } ";
16+ private static string PushCfgIndexKey ( string taskId ) => $ "a2a:task:{ taskId } :pushcfg:index";
17+
18+ // Dapr state operation settings: strong consistency, with default concurrency (override per operation as needed)
19+ private static readonly StateOptions StrongConsistency = new StateOptions
20+ {
21+ Consistency = ConsistencyMode . Strong // Ensure reads/writes use strong consistency
22+ } ;
23+
24+ /// <summary>
25+ /// Constructor for the Task Store.
26+ /// </summary>
27+ /// <param name="daprClient">A Dapr Client insance</param>
28+ /// <param name="stateStoreName">The name of the state store component to use</param>
29+ public DaprTaskStore ( DaprClient daprClient , string stateStoreName = "statestore" )
30+ {
31+ _daprClient = daprClient ?? throw new ArgumentNullException ( nameof ( daprClient ) ) ;
32+ _stateStoreName = stateStoreName ;
33+ }
34+
35+ /// <summary>
36+ /// Retrieves a task by its ID.
37+ /// </summary>
38+ /// <param name="taskId">The ID of the task to retrieve.</param>
39+ /// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
40+ /// <returns>The task if found, null otherwise.</returns>
41+ public async Task < AgentTask ? > GetTaskAsync ( string taskId , CancellationToken cancellationToken = default )
42+ {
43+ if ( taskId == null ) throw new ArgumentNullException ( nameof ( taskId ) ) ;
44+
45+ // Retrieve the AgentTask from Dapr state store with strong consistency to get the latest data
46+ AgentTask ? task = await _daprClient . GetStateAsync < AgentTask > (
47+ _stateStoreName ,
48+ key : taskId ,
49+ consistencyMode : ConsistencyMode . Strong ,
50+ metadata : null ,
51+ cancellationToken : cancellationToken ) ;
52+ return task ;
53+ }
54+
55+ /// <summary>
56+ /// Stores or updates a task.
57+ /// </summary>
58+ /// <param name="task">The task to store.</param>
59+ /// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
60+ /// <returns>A task representing the operation.</returns>
61+ public async Task SetTaskAsync ( AgentTask task , CancellationToken cancellationToken = default )
62+ {
63+ if ( task == null ) throw new ArgumentNullException ( nameof ( task ) ) ;
64+ // The task.Id will be used as the key. We save the entire AgentTask object.
65+ // Use strong consistency on write; concurrency defaults to last-write-wins for new entries.
66+ await _daprClient . SaveStateAsync (
67+ _stateStoreName ,
68+ key : task . Id ,
69+ value : task ,
70+ stateOptions : StrongConsistency , // strong consistency ensures durability before ack
71+ metadata : null ,
72+ cancellationToken : cancellationToken
73+ ) ;
74+ // Note: If the task already existed, this will overwrite it (last-write-wins behavior since no ETag used).
75+ }
76+
77+ /// <summary>
78+ /// Updates the status of a task.
79+ /// </summary>
80+ /// <param name="taskId">The ID of the task.</param>
81+ /// <param name="status">The new status.</param>
82+ /// <param name="message">Optional message associated with the status.</param>
83+ /// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
84+ /// <returns>The updated task status.</returns>
85+ public async Task < AgentTaskStatus > UpdateStatusAsync ( string taskId , TaskState status , Message ? message = null , CancellationToken cancellationToken = default )
86+ {
87+ if ( taskId == null ) throw new ArgumentNullException ( nameof ( taskId ) ) ;
88+ // Fetch state with its ETag for concurrency control.
89+ // We use strong consistency to get the latest state and ETag.
90+ var ( existingTask , etag ) = await _daprClient . GetStateAndETagAsync < AgentTask > (
91+ _stateStoreName ,
92+ key : taskId ,
93+ consistencyMode : ConsistencyMode . Strong ,
94+ metadata : null ,
95+ cancellationToken : cancellationToken ) ;
96+ if ( existingTask == null )
97+ {
98+ throw new KeyNotFoundException ( $ "Task with ID '{ taskId } ' not found.") ;
99+ }
100+
101+ // Update the status field of the retrieved task object.
102+ var st = existingTask . Status ;
103+ st . State = status ;
104+ if ( message != null )
105+ {
106+ st . Message = message ;
107+ }
108+
109+ existingTask . Status = st ;
110+
111+ // Attempt to save the updated task back with the ETag for optimistic concurrency.
112+ var stateOptions = new StateOptions
113+ {
114+ Consistency = ConsistencyMode . Strong ,
115+ Concurrency = ConcurrencyMode . FirstWrite // enable first-write-wins
116+ } ;
117+ bool saved = await _daprClient . TrySaveStateAsync (
118+ _stateStoreName ,
119+ key : taskId ,
120+ value : existingTask ,
121+ etag : etag , // use ETag to ensure no concurrent modification
122+ stateOptions : stateOptions ,
123+ metadata : null ,
124+ cancellationToken : cancellationToken
125+ ) ;
126+ if ( ! saved )
127+ {
128+ // The save failed due to an ETag mismatch (concurrent update happened).
129+ throw new InvalidOperationException ( $ "Concurrent update detected for task '{ taskId } '. Update was not saved.") ;
130+ }
131+
132+ return existingTask . Status ;
133+ }
134+
135+ /// <summary>
136+ /// Retrieves push notification configuration for a task.
137+ /// </summary>
138+ /// <param name="taskId">The ID of the task.</param>
139+ /// <param name="notificationConfigId">The ID of the push notification configuration.</param>
140+ /// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
141+ /// <returns>The push notification configuration if found, null otherwise.</returns>
142+ public async Task < TaskPushNotificationConfig ? > GetPushNotificationAsync ( string taskId , string notificationConfigId , CancellationToken cancellationToken = default )
143+ {
144+ if ( string . IsNullOrWhiteSpace ( taskId ) ) throw new ArgumentNullException ( nameof ( taskId ) ) ;
145+ if ( string . IsNullOrWhiteSpace ( notificationConfigId ) ) throw new ArgumentNullException ( nameof ( notificationConfigId ) ) ;
146+
147+ return await _daprClient . GetStateAsync < TaskPushNotificationConfig > (
148+ _stateStoreName ,
149+ key : PushCfgKey ( taskId , notificationConfigId ) ,
150+ consistencyMode : ConsistencyMode . Strong ,
151+ metadata : null ,
152+ cancellationToken : cancellationToken ) ;
153+ }
154+
155+ /// <summary>
156+ /// Stores push notification configuration for a task.
157+ /// </summary>
158+ /// <param name="pushNotificationConfig">The push notification configuration.</param>
159+ /// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
160+ /// <returns>A task representing the operation.</returns>
161+ public async Task SetPushNotificationConfigAsync ( TaskPushNotificationConfig pushNotificationConfig , CancellationToken cancellationToken = default )
162+ {
163+ if ( pushNotificationConfig is null ) throw new ArgumentNullException ( nameof ( pushNotificationConfig ) ) ;
164+
165+ // Adjust these property names if your model differs:
166+ var taskId = pushNotificationConfig . TaskId ?? throw new ArgumentException ( "Config.TaskId is required." ) ;
167+ var configId = pushNotificationConfig . PushNotificationConfig . Id ?? throw new ArgumentException ( "Config.Id is required." ) ;
168+
169+ //Save/Upsert the config itself
170+ await _daprClient . SaveStateAsync (
171+ _stateStoreName ,
172+ key : PushCfgKey ( taskId , configId ) ,
173+ value : pushNotificationConfig ,
174+ stateOptions : StrongConsistency ,
175+ metadata : null ,
176+ cancellationToken : cancellationToken ) ;
177+
178+ // Add the configId to the per-task index with ETag (avoid races)
179+ for ( var attempt = 0 ; attempt < 5 ; attempt ++ )
180+ {
181+ var ( index , etag ) = await _daprClient . GetStateAndETagAsync < string [ ] > (
182+ _stateStoreName ,
183+ key : PushCfgIndexKey ( taskId ) ,
184+ consistencyMode : ConsistencyMode . Strong ,
185+ metadata : null ,
186+ cancellationToken : cancellationToken ) ;
187+
188+ var list = ( index ?? Array . Empty < string > ( ) ) . ToList ( ) ;
189+ if ( ! list . Contains ( configId , StringComparer . Ordinal ) )
190+ list . Add ( configId ) ;
191+
192+ var ok = await _daprClient . TrySaveStateAsync (
193+ _stateStoreName ,
194+ key : PushCfgIndexKey ( taskId ) ,
195+ value : list . ToArray ( ) ,
196+ etag : etag ,
197+ stateOptions : new StateOptions { Consistency = ConsistencyMode . Strong , Concurrency = ConcurrencyMode . FirstWrite } ,
198+ metadata : null ,
199+ cancellationToken : cancellationToken ) ;
200+
201+ if ( ok ) break ;
202+
203+ // small backoff before retry
204+ await Task . Delay ( 50 * ( attempt + 1 ) , cancellationToken ) ;
205+ }
206+ }
207+
208+ /// <summary>
209+ /// Retrieves push notification configuration for a task.
210+ /// </summary>
211+ /// <param name="taskId">The ID of the task.</param>
212+ /// <param name="cancellationToken">A cancellation token that can be used to cancel the operation.</param>
213+ /// <returns>The push notification configuration if found, null otherwise.</returns>
214+ public async Task < IEnumerable < TaskPushNotificationConfig > > GetPushNotificationsAsync ( string taskId , CancellationToken cancellationToken = default )
215+ {
216+ if ( string . IsNullOrWhiteSpace ( taskId ) ) throw new ArgumentNullException ( nameof ( taskId ) ) ;
217+
218+ var ids = await _daprClient . GetStateAsync < string [ ] > (
219+ _stateStoreName ,
220+ key : PushCfgIndexKey ( taskId ) ,
221+ consistencyMode : ConsistencyMode . Strong ,
222+ metadata : null ,
223+ cancellationToken : cancellationToken ) ?? Array . Empty < string > ( ) ;
224+
225+ if ( ids . Length == 0 ) return Array . Empty < TaskPushNotificationConfig > ( ) ;
226+
227+ const int maxParallel = 8 ;
228+ using var gate = new SemaphoreSlim ( maxParallel ) ;
229+ var bag = new ConcurrentBag < TaskPushNotificationConfig > ( ) ;
230+
231+ await Task . WhenAll ( ids . Select ( async id =>
232+ {
233+ await gate . WaitAsync ( cancellationToken ) ;
234+ try
235+ {
236+ var cfg = await _daprClient . GetStateAsync < TaskPushNotificationConfig > (
237+ _stateStoreName ,
238+ key : PushCfgKey ( taskId , id ) ,
239+ consistencyMode : ConsistencyMode . Strong ,
240+ metadata : null ,
241+ cancellationToken : cancellationToken ) ;
242+
243+ if ( cfg is not null ) bag . Add ( cfg ) ;
244+ }
245+ finally
246+ {
247+ gate . Release ( ) ;
248+ }
249+ } ) ) ;
250+
251+ return bag . ToArray ( ) ;
252+ }
253+ }
0 commit comments