Skip to content

Commit 4a9b722

Browse files
Ensure we properly serialize messages in the meta struct
1 parent c3d2911 commit 4a9b722

File tree

6 files changed

+251
-15
lines changed

6 files changed

+251
-15
lines changed

communication/src/main/java/datadog/communication/serialization/Codec.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
package datadog.communication.serialization;
22

3+
import datadog.communication.serialization.custom.aiguard.MessageWriter;
4+
import datadog.communication.serialization.custom.aiguard.ToolCallWriter;
35
import datadog.communication.serialization.custom.stacktrace.StackTraceEventFrameWriter;
46
import datadog.communication.serialization.custom.stacktrace.StackTraceEventWriter;
7+
import datadog.trace.api.aiguard.AIGuard;
58
import datadog.trace.util.stacktrace.StackTraceEvent;
69
import datadog.trace.util.stacktrace.StackTraceFrame;
710
import java.nio.ByteBuffer;
@@ -19,6 +22,8 @@ public final class Codec extends ClassValue<ValueWriter<?>> {
1922
new Object[][] {
2023
{StackTraceEvent.class, new StackTraceEventWriter()},
2124
{StackTraceFrame.class, new StackTraceEventFrameWriter()},
25+
{AIGuard.Message.class, new MessageWriter()},
26+
{AIGuard.ToolCall.class, new ToolCallWriter()},
2227
})
2328
.collect(Collectors.toMap(data -> (Class<?>) data[0], data -> (ValueWriter<?>) data[1]));
2429

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package datadog.communication.serialization.custom.aiguard;
2+
3+
import datadog.communication.serialization.EncodingCache;
4+
import datadog.communication.serialization.ValueWriter;
5+
import datadog.communication.serialization.Writable;
6+
import datadog.trace.api.aiguard.AIGuard;
7+
import datadog.trace.util.Strings;
8+
import java.util.List;
9+
10+
public class MessageWriter implements ValueWriter<AIGuard.Message> {
11+
12+
@Override
13+
public void write(
14+
final AIGuard.Message value, final Writable writable, final EncodingCache encodingCache) {
15+
final int[] size = {0};
16+
final boolean hasRole = isNotBlank(value.getRole(), size);
17+
final boolean hasContent = isNotBlank(value.getContent(), size);
18+
final boolean hasToolCallId = isNotBlank(value.getToolCallId(), size);
19+
final boolean hasToolCalls = isNotEmpty(value.getToolCalls(), size);
20+
writable.startMap(size[0]);
21+
writeString(hasRole, "role", value.getRole(), writable, encodingCache);
22+
writeString(hasContent, "content", value.getContent(), writable, encodingCache);
23+
writeString(hasToolCallId, "tool_call_id", value.getToolCallId(), writable, encodingCache);
24+
writeToolCallArray(hasToolCalls, "tool_calls", value.getToolCalls(), writable, encodingCache);
25+
}
26+
27+
private static void writeString(
28+
final boolean present,
29+
final String key,
30+
final String value,
31+
final Writable writable,
32+
final EncodingCache encodingCache) {
33+
if (present) {
34+
writable.writeString(key, encodingCache);
35+
writable.writeString(value, encodingCache);
36+
}
37+
}
38+
39+
private static void writeToolCallArray(
40+
final boolean present,
41+
final String key,
42+
final List<AIGuard.ToolCall> values,
43+
final Writable writable,
44+
final EncodingCache encodingCache) {
45+
if (present) {
46+
writable.writeString(key, encodingCache);
47+
writable.startArray(values.size());
48+
for (final AIGuard.ToolCall toolCall : values) {
49+
writable.writeObject(toolCall, encodingCache);
50+
}
51+
}
52+
}
53+
54+
private static boolean isNotBlank(final String value, final int[] nonBlankCount) {
55+
final boolean hasText = Strings.isNotBlank(value);
56+
if (hasText) {
57+
nonBlankCount[0]++;
58+
}
59+
return hasText;
60+
}
61+
62+
private static boolean isNotEmpty(final List<?> value, final int[] nonEmptyCount) {
63+
final boolean nonEmpty = value != null && !value.isEmpty();
64+
if (nonEmpty) {
65+
nonEmptyCount[0]++;
66+
}
67+
return nonEmpty;
68+
}
69+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package datadog.communication.serialization.custom.aiguard;
2+
3+
import datadog.communication.serialization.EncodingCache;
4+
import datadog.communication.serialization.ValueWriter;
5+
import datadog.communication.serialization.Writable;
6+
import datadog.trace.api.aiguard.AIGuard;
7+
8+
public class ToolCallWriter implements ValueWriter<AIGuard.ToolCall> {
9+
10+
@Override
11+
public void write(
12+
final AIGuard.ToolCall value, final Writable writable, final EncodingCache encodingCache) {
13+
writable.startMap(2);
14+
writable.writeString("id", encodingCache);
15+
writable.writeString(value.getId(), encodingCache);
16+
writable.writeString("function", encodingCache);
17+
18+
final AIGuard.ToolCall.Function function = value.getFunction();
19+
if (function != null) {
20+
writable.startMap(2);
21+
writable.writeString("name", encodingCache);
22+
writable.writeString(function.getName(), encodingCache);
23+
writable.writeString("arguments", encodingCache);
24+
writable.writeString(function.getArguments(), encodingCache);
25+
}
26+
}
27+
}
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package datadog.communication.serialization.aiguard;
2+
3+
import static org.hamcrest.MatcherAssert.assertThat;
4+
import static org.hamcrest.Matchers.equalTo;
5+
import static org.hamcrest.Matchers.hasSize;
6+
7+
import datadog.communication.serialization.EncodingCache;
8+
import datadog.communication.serialization.GrowableBuffer;
9+
import datadog.communication.serialization.msgpack.MsgPackWriter;
10+
import datadog.trace.api.aiguard.AIGuard;
11+
import java.io.IOException;
12+
import java.nio.charset.StandardCharsets;
13+
import java.util.HashMap;
14+
import java.util.List;
15+
import java.util.Map;
16+
import java.util.function.Function;
17+
import java.util.stream.Collectors;
18+
import org.junit.jupiter.api.BeforeEach;
19+
import org.junit.jupiter.api.Test;
20+
import org.msgpack.core.MessagePack;
21+
import org.msgpack.core.MessageUnpacker;
22+
import org.msgpack.value.Value;
23+
24+
public class MessageWriterTest {
25+
26+
private EncodingCache encodingCache;
27+
private GrowableBuffer buffer;
28+
private MsgPackWriter writer;
29+
30+
@BeforeEach
31+
public void setup() {
32+
final HashMap<CharSequence, byte[]> cache = new HashMap<>();
33+
encodingCache =
34+
string -> cache.computeIfAbsent(string, s -> s.toString().getBytes(StandardCharsets.UTF_8));
35+
buffer = new GrowableBuffer(1024);
36+
writer = new MsgPackWriter(buffer);
37+
}
38+
39+
@Test
40+
public void testWriteMessage() throws IOException {
41+
final AIGuard.Message message = AIGuard.Message.message("user", "What day is today?");
42+
43+
writer.writeObject(message, encodingCache);
44+
45+
try (final MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
46+
final Map<String, String> value = asStringValueMap(unpacker.unpackValue().asMapValue());
47+
assertThat(value.size(), equalTo(2));
48+
assertThat(value.get("role"), equalTo("user"));
49+
assertThat(value.get("content"), equalTo("What day is today?"));
50+
}
51+
}
52+
53+
@Test
54+
public void testWriteToolCall() throws IOException {
55+
final AIGuard.Message message =
56+
AIGuard.Message.assistant(
57+
AIGuard.ToolCall.toolCall("call_1", "function_1", "args_1"),
58+
AIGuard.ToolCall.toolCall("call_2", "function_2", "args_2"));
59+
60+
writer.writeObject(message, encodingCache);
61+
62+
try (final MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
63+
final Map<String, Value> value = asStringKeyMap(unpacker.unpackValue());
64+
assertThat(value.size(), equalTo(2));
65+
assertThat(asString(value.get("role")), equalTo("assistant"));
66+
67+
final List<Value> toolCalls = value.get("tool_calls").asArrayValue().list();
68+
assertThat(toolCalls, hasSize(2));
69+
70+
final Map<String, Value> firstCall = asStringKeyMap(toolCalls.get(0));
71+
assertThat(asString(firstCall.get("id")), equalTo("call_1"));
72+
final Map<String, String> firstFunction = asStringValueMap(firstCall.get("function"));
73+
assertThat(firstFunction.get("name"), equalTo("function_1"));
74+
assertThat(firstFunction.get("arguments"), equalTo("args_1"));
75+
76+
final Map<String, Value> secondCall = asStringKeyMap(toolCalls.get(1));
77+
assertThat(asString(secondCall.get("id")), equalTo("call_2"));
78+
final Map<String, String> secondFunction = asStringValueMap(secondCall.get("function"));
79+
assertThat(secondFunction.get("name"), equalTo("function_2"));
80+
assertThat(secondFunction.get("arguments"), equalTo("args_2"));
81+
}
82+
}
83+
84+
@Test
85+
public void testWriteToolOutput() throws IOException {
86+
final AIGuard.Message message = AIGuard.Message.tool("call_1", "output");
87+
88+
writer.writeObject(message, encodingCache);
89+
90+
try (final MessageUnpacker unpacker = MessagePack.newDefaultUnpacker(buffer.slice())) {
91+
final Map<String, Value> value = asStringKeyMap(unpacker.unpackValue());
92+
assertThat(value.size(), equalTo(3));
93+
assertThat(asString(value.get("role")), equalTo("tool"));
94+
assertThat(asString(value.get("tool_call_id")), equalTo("call_1"));
95+
assertThat(asString(value.get("content")), equalTo("output"));
96+
}
97+
}
98+
99+
private <K, V> Map<K, V> mapValue(
100+
final Value values,
101+
final Function<Value, K> keyMapper,
102+
final Function<Value, V> valueMapper) {
103+
return values.asMapValue().entrySet().stream()
104+
.collect(
105+
Collectors.toMap(
106+
entry -> keyMapper.apply(entry.getKey()),
107+
entry -> valueMapper.apply(entry.getValue())));
108+
}
109+
110+
private Map<String, Value> asStringKeyMap(final Value values) {
111+
return mapValue(values, this::asString, Function.identity());
112+
}
113+
114+
private Map<String, String> asStringValueMap(final Value values) {
115+
return mapValue(values, this::asString, this::asString);
116+
}
117+
118+
private String asString(final Value value) {
119+
return value.asStringValue().asString();
120+
}
121+
}

dd-java-agent/agent-aiguard/src/main/java/com/datadog/aiguard/AIGuardInternal.java

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.io.IOException;
2727
import java.lang.annotation.Annotation;
2828
import java.lang.reflect.Type;
29+
import java.util.ArrayList;
2930
import java.util.Collection;
3031
import java.util.HashMap;
3132
import java.util.List;
@@ -104,14 +105,16 @@ static void uninstall() {
104105
this.meta = mapOf("service", config.getServiceName(), "env", config.getEnv());
105106
}
106107

107-
private static List<Message> truncate(List<Message> messages) {
108+
/**
109+
* Creates a deep copy of the messages before storing them in the metastruct to avoid concurrent
110+
* modifications prior to trace serialization.
111+
*/
112+
private static List<Message> messagesForMetaStruct(List<Message> messages) {
108113
final Config config = Config.get();
109-
final int maxMessages = config.getAiGuardMaxMessagesLength();
110-
if (messages.size() > maxMessages) {
111-
messages = messages.subList(messages.size() - maxMessages, messages.size());
112-
}
114+
final int size = Math.min(messages.size(), config.getAiGuardMaxMessagesLength());
115+
final List<Message> result = new ArrayList<>(size);
113116
final int maxContent = config.getAiGuardMaxContentSize();
114-
for (int i = 0; i < messages.size(); i++) {
117+
for (int i = 0; i < size; i++) {
115118
Message source = messages.get(i);
116119
final String content = source.getContent();
117120
if (content != null && content.length() > maxContent) {
@@ -121,10 +124,10 @@ private static List<Message> truncate(List<Message> messages) {
121124
content.substring(0, maxContent),
122125
source.getToolCalls(),
123126
source.getToolCallId());
124-
messages.set(i, source);
125127
}
128+
result.add(source);
126129
}
127-
return messages;
130+
return result;
128131
}
129132

130133
private static boolean isToolCall(final Message message) {
@@ -181,7 +184,8 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
181184
} else {
182185
span.setTag(TARGET_TAG, "prompt");
183186
}
184-
final Map<String, Object> metaStruct = singletonMap(META_STRUCT_KEY, truncate(messages));
187+
final Map<String, Object> metaStruct =
188+
singletonMap(META_STRUCT_KEY, messagesForMetaStruct(messages));
185189
span.setMetaStruct(META_STRUCT_TAG, metaStruct);
186190
final Request.Builder request =
187191
new Request.Builder()

dd-smoke-tests/appsec/springboot/src/test/groovy/datadog/smoketest/appsec/AIGuardSmokeTest.groovy

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ class AIGuardSmokeTest extends AbstractAppSecServerSmokeTest {
4343
final action = test.action as String
4444
final reason = test.reason as String
4545
def request = new Request.Builder()
46-
.url("http://localhost:${httpPort}/aiguard${test.endpoint}")
47-
.header('X-Blocking-Enabled', "${blocking}")
48-
.get()
49-
.build()
46+
.url("http://localhost:${httpPort}/aiguard${test.endpoint}")
47+
.header('X-Blocking-Enabled', "${blocking}")
48+
.get()
49+
.build()
5050

5151
when:
5252
final response = client.newCall(request).execute()
@@ -65,11 +65,21 @@ class AIGuardSmokeTest extends AbstractAppSecServerSmokeTest {
6565
and:
6666
waitForTraceCount(2) // default call + internal API mock
6767
final span = traces*.spans
68-
?.flatten()
69-
?.find { it.resource == 'ai_guard' } as DecodedSpan
68+
?.flatten()
69+
?.find { it.resource == 'ai_guard' } as DecodedSpan
7070
assert span.meta.get('ai_guard.action') == action
7171
assert span.meta.get('ai_guard.reason') == reason
7272
assert span.meta.get('ai_guard.target') == 'prompt'
73+
final messages = span.metaStruct.get('ai_guard').messages as List<Map<String, Object>>
74+
assert messages.size() == 2
75+
messages[0].with {
76+
assert role == 'system'
77+
assert content == 'You are a beautiful AI'
78+
}
79+
messages[1].with {
80+
assert role == 'user'
81+
assert content != null
82+
}
7383

7484
where:
7585
test << testSuite()

0 commit comments

Comments
 (0)