11package com .datadog .aiguard ;
22
3+ import static java .util .Collections .singletonMap ;
4+
5+ import com .squareup .moshi .JsonAdapter ;
36import com .squareup .moshi .JsonReader ;
47import com .squareup .moshi .JsonWriter ;
58import com .squareup .moshi .Moshi ;
9+ import com .squareup .moshi .Types ;
610import datadog .communication .http .OkHttpUtils ;
711import datadog .trace .api .Config ;
812import datadog .trace .api .aiguard .AIGuard ;
2024import datadog .trace .bootstrap .instrumentation .api .AgentSpan ;
2125import datadog .trace .bootstrap .instrumentation .api .AgentTracer ;
2226import java .io .IOException ;
27+ import java .lang .annotation .Annotation ;
28+ import java .lang .reflect .Type ;
2329import java .util .Collection ;
24- import java .util .Collections ;
2530import java .util .HashMap ;
2631import java .util .List ;
2732import java .util .Map ;
33+ import java .util .Set ;
2834import java .util .stream .Collectors ;
2935import javax .annotation .Nullable ;
3036import okhttp3 .HttpUrl ;
3642import okhttp3 .ResponseBody ;
3743import okio .BufferedSink ;
3844
45+ /**
46+ * Actual implementation of the SDK for calling the AIGuard REST API , the instance is configured
47+ * during the startup of the agent {@link AIGuardSystem#start()}
48+ */
3949public class AIGuardInternal implements Evaluator {
4050
4151 public static class BadConfigurationException extends RuntimeException {
@@ -87,7 +97,7 @@ static void uninstall() {
8797 this .url = url ;
8898 this .headers = headers ;
8999 this .client = client ;
90- this .moshi = new Moshi .Builder ().build ();
100+ this .moshi = new Moshi .Builder ().add ( new AIGuardFactory ()). build ();
91101 final Config config = Config .get ();
92102 this .meta = mapOf ("service" , config .getServiceName (), "env" , config .getEnv ());
93103 }
@@ -126,21 +136,20 @@ private static String getToolName(final Message current, final List<Message> mes
126136 .map (ToolCall ::getFunction )
127137 .map (Function ::getName )
128138 .collect (Collectors .joining ("," ));
129- } else {
130- // assistant message with tool output (search the linked tool call in reverse order)
131- final String id = current .getToolCallId ();
132- for (int i = messages .size () - 1 ; i >= 0 ; i --) {
133- final Message message = messages .get (i );
134- if (message .getToolCalls () != null ) {
135- for (final ToolCall toolCall : message .getToolCalls ()) {
136- if (toolCall .getId ().equals (id )) {
137- return toolCall .getFunction () == null ? null : toolCall .getFunction ().getName ();
138- }
139+ }
140+ // assistant message with tool output (search the linked tool call in reverse order)
141+ final String id = current .getToolCallId ();
142+ for (int i = messages .size () - 1 ; i >= 0 ; i --) {
143+ final Message message = messages .get (i );
144+ if (message .getToolCalls () != null ) {
145+ for (final ToolCall toolCall : message .getToolCalls ()) {
146+ if (toolCall .getId ().equals (id )) {
147+ return toolCall .getFunction () == null ? null : toolCall .getFunction ().getName ();
139148 }
140149 }
141150 }
142- return null ;
143151 }
152+ return null ;
144153 }
145154
146155 private boolean isBlockingEnabled (final Object isBlockingEnabled ) {
@@ -155,18 +164,17 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
155164 final AgentTracer .TracerAPI tracer = AgentTracer .get ();
156165 final AgentSpan span = tracer .buildSpan (SPAN_NAME , SPAN_NAME ).start ();
157166 try (final AgentScope scope = tracer .activateSpan (span )) {
158- final Message current = messages .get (messages .size () - 1 );
159- if (isToolCall (current )) {
167+ final Message last = messages .get (messages .size () - 1 );
168+ if (isToolCall (last )) {
160169 span .setTag (TARGET_TAG , "tool" );
161- final String toolName = getToolName (current , messages );
170+ final String toolName = getToolName (last , messages );
162171 if (toolName != null ) {
163172 span .setTag (TOOL_TAG , toolName );
164173 }
165174 } else {
166175 span .setTag (TARGET_TAG , "prompt" );
167176 }
168- final Map <String , Object > metaStruct =
169- Collections .singletonMap (META_STRUCT_KEY , truncate (messages ));
177+ final Map <String , Object > metaStruct = singletonMap (META_STRUCT_KEY , truncate (messages ));
170178 span .setMetaStruct (META_STRUCT_TAG , metaStruct );
171179 final Request .Builder request =
172180 new Request .Builder ()
@@ -243,6 +251,69 @@ public static void install(final Evaluator evaluator) {
243251 }
244252 }
245253
254+ static class AIGuardFactory implements JsonAdapter .Factory {
255+
256+ @ Nullable
257+ @ Override
258+ public JsonAdapter <?> create (
259+ final Type type , final Set <? extends Annotation > annotations , final Moshi moshi ) {
260+ final Class <?> rawType = Types .getRawType (type );
261+ if (rawType != AIGuard .Message .class ) {
262+ return null ;
263+ }
264+ return new MessageAdapter (moshi .adapter (AIGuard .ToolCall .class ));
265+ }
266+ }
267+
268+ static class MessageAdapter extends JsonAdapter <Message > {
269+
270+ private final JsonAdapter <AIGuard .ToolCall > toolCallAdapter ;
271+
272+ MessageAdapter (final JsonAdapter <ToolCall > toolCallAdapter ) {
273+ this .toolCallAdapter = toolCallAdapter ;
274+ }
275+
276+ @ Nullable
277+ @ Override
278+ public Message fromJson (JsonReader reader ) throws IOException {
279+ throw new UnsupportedOperationException ("Serializing only adapter" );
280+ }
281+
282+ @ Override
283+ public void toJson (final JsonWriter writer , @ Nullable final Message value ) throws IOException {
284+ if (value == null ) {
285+ writer .nullValue ();
286+ return ;
287+ }
288+ writer .beginObject ();
289+ writeValue (writer , "role" , value .getRole ());
290+ writeValue (writer , "content" , value .getContent ());
291+ writeArray (writer , "tool_calls" , value .getToolCalls ());
292+ writeValue (writer , "tool_call_id" , value .getToolCallId ());
293+ writer .endObject ();
294+ }
295+
296+ private void writeValue (final JsonWriter writer , final String name , final Object value )
297+ throws IOException {
298+ if (value != null ) {
299+ writer .name (name );
300+ writer .jsonValue (value );
301+ }
302+ }
303+
304+ private void writeArray (final JsonWriter writer , final String name , final List <ToolCall > value )
305+ throws IOException {
306+ if (value != null ) {
307+ writer .name (name );
308+ writer .beginArray ();
309+ for (final ToolCall toolCall : value ) {
310+ toolCallAdapter .toJson (writer , toolCall );
311+ }
312+ writer .endArray ();
313+ }
314+ }
315+ }
316+
246317 static class MoshiJsonRequestBody extends RequestBody {
247318
248319 private static final MediaType JSON = MediaType .parse ("application/json" );
0 commit comments