Skip to content

Commit 657b1e1

Browse files
Initial implementation of the AI Guard SDK
1 parent 4cf7670 commit 657b1e1

File tree

17 files changed

+1382
-1
lines changed

17 files changed

+1382
-1
lines changed

.github/CODEOWNERS

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
# @DataDog/asm-java (AppSec/IAST)
4848
/buildSrc/call-site-instrumentation-plugin/ @DataDog/asm-java
49+
/dd-java-agent/agent-aiguard/ @DataDog/asm-java
4950
/dd-java-agent/agent-iast/ @DataDog/asm-java
5051
/dd-java-agent/appsec/appsec-test-fixtures/ @DataDog/asm-java
5152
/dd-java-agent/instrumentation/*iast* @DataDog/asm-java
@@ -58,6 +59,7 @@
5859
/dd-smoke-tests/spring-security/ @DataDog/asm-java
5960
/dd-java-agent/instrumentation/commons-fileupload/ @DataDog/asm-java
6061
/dd-java-agent/instrumentation/spring/spring-security/ @DataDog/asm-java
62+
/dd-trace-api/src/main/java/datadog/trace/api/aiguard/ @DataDog/asm-java
6163
/dd-trace-api/src/main/java/datadog/trace/api/EventTracker.java @DataDog/asm-java
6264
/internal-api/src/main/java/datadog/trace/api/gateway/ @DataDog/asm-java
6365
**/appsec/ @DataDog/asm-java
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
plugins {
2+
id 'com.gradleup.shadow'
3+
}
4+
5+
apply from: "$rootDir/gradle/java.gradle"
6+
apply from: "$rootDir/gradle/version.gradle"
7+
8+
java {
9+
sourceCompatibility = JavaVersion.VERSION_1_8
10+
targetCompatibility = JavaVersion.VERSION_1_8
11+
}
12+
13+
dependencies {
14+
api libs.slf4j
15+
implementation libs.moshi
16+
implementation libs.okhttp
17+
18+
api project(':dd-trace-api')
19+
implementation project(':internal-api')
20+
implementation project(':communication')
21+
22+
testImplementation project(':utils:test-utils')
23+
testImplementation('org.skyscreamer:jsonassert:1.5.1')
24+
}
25+
26+
shadowJar {
27+
dependencies deps.excludeShared
28+
}
29+
30+
jar {
31+
archiveClassifier = 'unbundled'
32+
}
33+
34+
ext {
35+
minimumBranchCoverage = 0.6
36+
minimumInstructionCoverage = 0.8
37+
excludedClassesCoverage = []
38+
excludedClassesBranchCoverage = []
39+
excludedClassesInstructionCoverage = []
40+
}
41+
42+
spotless {
43+
java {
44+
target 'src/**/*.java'
45+
}
46+
}
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
package com.datadog.aiguard;
2+
3+
import com.squareup.moshi.JsonReader;
4+
import com.squareup.moshi.JsonWriter;
5+
import com.squareup.moshi.Moshi;
6+
import datadog.communication.http.OkHttpUtils;
7+
import datadog.trace.api.Config;
8+
import datadog.trace.api.aiguard.AIGuard;
9+
import datadog.trace.api.aiguard.AIGuard.AIGuardAbortError;
10+
import datadog.trace.api.aiguard.AIGuard.AIGuardClientError;
11+
import datadog.trace.api.aiguard.AIGuard.Action;
12+
import datadog.trace.api.aiguard.AIGuard.Evaluation;
13+
import datadog.trace.api.aiguard.AIGuard.Message;
14+
import datadog.trace.api.aiguard.AIGuard.Options;
15+
import datadog.trace.api.aiguard.AIGuard.ToolCall;
16+
import datadog.trace.api.aiguard.AIGuard.ToolCall.Function;
17+
import datadog.trace.api.aiguard.Evaluator;
18+
import datadog.trace.api.aiguard.noop.NoOpEvaluator;
19+
import datadog.trace.bootstrap.instrumentation.api.AgentScope;
20+
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
21+
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
22+
import java.io.IOException;
23+
import java.util.Collection;
24+
import java.util.HashMap;
25+
import java.util.List;
26+
import java.util.Map;
27+
import java.util.stream.Collectors;
28+
import javax.annotation.Nullable;
29+
import okhttp3.HttpUrl;
30+
import okhttp3.MediaType;
31+
import okhttp3.OkHttpClient;
32+
import okhttp3.Request;
33+
import okhttp3.RequestBody;
34+
import okhttp3.Response;
35+
import okhttp3.ResponseBody;
36+
import okio.BufferedSink;
37+
38+
public class AIGuardInternal implements Evaluator {
39+
40+
static final String SPAN_NAME = "ai_guard";
41+
static final String TARGET_TAG = "ai_guard.target";
42+
static final String TOOL_TAG = "ai_guard.tool";
43+
static final String ACTION_TAG = "ai_guard.action";
44+
static final String REASON_TAG = "ai_guard.reason";
45+
static final String BLOCKED_TAG = "ai_guard.blocked";
46+
static final String META_STRUCT_TAG = "ai_guard";
47+
static final String META_STRUCT_KEY = "messages";
48+
49+
public static void install() {
50+
final Config config = Config.get();
51+
final String apiKey = config.getApiKey();
52+
final String appKey = config.getApplicationKey();
53+
if (isEmpty(apiKey) || isEmpty(appKey)) {
54+
throw new RuntimeException(
55+
"AI Guard: Missing api and/or application key, use DD_API_KEY and DD_APP_KEY");
56+
}
57+
String endpoint = config.getAiGuardEndpoint();
58+
if (isEmpty(endpoint)) {
59+
endpoint = String.format("https://app.%s/api/v2/ai-guard", config.getSite());
60+
}
61+
final Map<String, String> headers = new HashMap<>(2);
62+
headers.put("DD-API-KEY", apiKey);
63+
headers.put("DD-APP-KEY", appKey);
64+
final HttpUrl url = HttpUrl.get(endpoint).newBuilder().addPathSegment("evaluate").build();
65+
final int timeout = config.getAiGuardTimeout();
66+
final OkHttpClient client = buildClient(url, timeout);
67+
Installer.install(new AIGuardInternal(url, headers, client));
68+
}
69+
70+
/** Used by tests to reset status */
71+
static void uninstall() {
72+
Installer.install(new NoOpEvaluator());
73+
}
74+
75+
private final HttpUrl url;
76+
private final Moshi moshi;
77+
private final OkHttpClient client;
78+
private final Map<String, String> meta;
79+
private final Map<String, String> headers;
80+
81+
AIGuardInternal(final HttpUrl url, final Map<String, String> headers, final OkHttpClient client) {
82+
this.url = url;
83+
this.headers = headers;
84+
this.client = client;
85+
this.moshi = new Moshi.Builder().build();
86+
final Config config = Config.get();
87+
this.meta = new HashMap<>(2);
88+
this.meta.put("service", config.getServiceName());
89+
this.meta.put("env", config.getEnv());
90+
}
91+
92+
private static List<Message> truncate(List<Message> messages) {
93+
final Config config = Config.get();
94+
final int maxMessages = config.getAiGuardMaxMessagesLength();
95+
if (messages.size() > maxMessages) {
96+
messages = messages.subList(messages.size() - maxMessages, messages.size());
97+
}
98+
final int maxContent = config.getAiGuardMaxContentSize();
99+
for (int i = 0; i < messages.size(); i++) {
100+
Message source = messages.get(i);
101+
if (source.getContent() != null && source.getContent().length() > maxContent) {
102+
source =
103+
new Message(
104+
source.getRole(),
105+
source.getContent().substring(0, maxContent),
106+
source.getToolCalls(),
107+
source.getToolCallId());
108+
messages.set(i, source);
109+
}
110+
}
111+
return messages;
112+
}
113+
114+
private static boolean isToolCall(final Message message) {
115+
return message.getToolCalls() != null || message.getToolCallId() != null;
116+
}
117+
118+
private static String getToolName(final Message current, final List<Message> messages) {
119+
if (current.getToolCalls() != null) {
120+
// assistant message with tool calls
121+
return current.getToolCalls().stream()
122+
.map(ToolCall::getFunction)
123+
.map(Function::getName)
124+
.collect(Collectors.joining(","));
125+
} else {
126+
// assistant message with tool output (search the linked tool call in reverse order)
127+
final String id = current.getToolCallId();
128+
for (int i = messages.size() - 1; i >= 0; i--) {
129+
final Message message = messages.get(i);
130+
if (message.getToolCalls() != null) {
131+
for (final ToolCall toolCall : message.getToolCalls()) {
132+
if (toolCall.getId().equals(id)) {
133+
return toolCall.getFunction() == null ? null : toolCall.getFunction().getName();
134+
}
135+
}
136+
}
137+
}
138+
return null;
139+
}
140+
}
141+
142+
private boolean isBlockingEnabled(final Object isBlockingEnabled) {
143+
return isBlockingEnabled != null && isBlockingEnabled.toString().equalsIgnoreCase("true");
144+
}
145+
146+
@Override
147+
public Evaluation evaluate(final List<Message> messages, final Options options) {
148+
if (messages == null || messages.isEmpty()) {
149+
throw new IllegalArgumentException("messages must not be empty");
150+
}
151+
final AgentTracer.TracerAPI tracer = AgentTracer.get();
152+
final AgentSpan span = tracer.buildSpan(SPAN_NAME, SPAN_NAME).start();
153+
try (final AgentScope scope = tracer.activateSpan(span)) {
154+
final Message current = messages.get(messages.size() - 1);
155+
if (isToolCall(current)) {
156+
span.setTag(TARGET_TAG, "tool");
157+
final String toolName = getToolName(current, messages);
158+
if (toolName != null) {
159+
span.setTag(TOOL_TAG, toolName);
160+
}
161+
} else {
162+
span.setTag(TARGET_TAG, "prompt");
163+
}
164+
final Map<String, Object> metaStruct = new HashMap<>(1);
165+
metaStruct.put(META_STRUCT_KEY, truncate(messages));
166+
span.setMetaStruct(META_STRUCT_TAG, metaStruct);
167+
final Request.Builder request =
168+
new Request.Builder()
169+
.url(url)
170+
.method("POST", new MoshiJsonRequestBody(moshi, messages, meta));
171+
headers.forEach(request::header);
172+
try (final Response response = client.newCall(request.build()).execute()) {
173+
final Map<String, Object> result = parseResponseBody(response);
174+
final String actionStr = (String) result.get("action");
175+
if (actionStr == null) {
176+
throw new IllegalArgumentException("action field is missing in the response");
177+
}
178+
final Action action = Action.valueOf(actionStr);
179+
final String reason = (String) result.get("reason");
180+
span.setTag(ACTION_TAG, action);
181+
span.setTag(REASON_TAG, reason);
182+
final boolean blockingEnabled = isBlockingEnabled(result.get("is_blocking_enabled"));
183+
if (blockingEnabled && options.block() && action != Action.ALLOW) {
184+
span.setTag(BLOCKED_TAG, true);
185+
throw new AIGuardAbortError(action, reason);
186+
}
187+
return new Evaluation(action, reason);
188+
}
189+
} catch (AIGuardAbortError | AIGuardClientError e) {
190+
span.addThrowable(e);
191+
throw e;
192+
} catch (final Exception e) {
193+
final AIGuardClientError error =
194+
new AIGuardClientError("AI Guard service returned unexpected response", e);
195+
span.addThrowable(error);
196+
throw error;
197+
}
198+
}
199+
200+
@SuppressWarnings("unchecked")
201+
private Map<String, Object> parseResponseBody(final Response response) throws IOException {
202+
final ResponseBody body = response.body();
203+
if (body == null) {
204+
throw fail(response.code(), null);
205+
}
206+
final JsonReader reader = JsonReader.of(body.source());
207+
final Map<?, ?> parsedBody = moshi.adapter(Map.class).fromJson(reader);
208+
final Object errors = parsedBody.get("errors");
209+
if (errors != null) {
210+
throw fail(response.code(), errors);
211+
}
212+
final Map<?, ?> data = (Map<?, ?>) parsedBody.get("data");
213+
return (Map<String, Object>) data.get("attributes");
214+
}
215+
216+
private AIGuardClientError fail(final int statusCode, final Object errors) {
217+
return new AIGuardClientError("AI Guard service call failed, status" + statusCode, errors);
218+
}
219+
220+
private static OkHttpClient buildClient(final HttpUrl url, final long timeout) {
221+
return OkHttpUtils.buildHttpClient(url, timeout).newBuilder().build();
222+
}
223+
224+
private static boolean isEmpty(final String value) {
225+
return value == null || value.isEmpty();
226+
}
227+
228+
private static class Installer extends AIGuard {
229+
public static void install(final Evaluator evaluator) {
230+
AIGuard.EVALUATOR = evaluator;
231+
}
232+
}
233+
234+
static class MoshiJsonRequestBody extends RequestBody {
235+
236+
private static final MediaType JSON = MediaType.parse("application/json");
237+
238+
private final Moshi moshi;
239+
private final Map<String, String> meta;
240+
private final Collection<Message> messages;
241+
242+
public MoshiJsonRequestBody(
243+
final Moshi moshi, final Collection<Message> messages, final Map<String, String> meta) {
244+
this.moshi = moshi;
245+
this.messages = messages;
246+
this.meta = meta;
247+
}
248+
249+
@Nullable
250+
@Override
251+
public MediaType contentType() {
252+
return JSON;
253+
}
254+
255+
@Override
256+
public void writeTo(final BufferedSink sink) throws IOException {
257+
final JsonWriter writer = JsonWriter.of(sink);
258+
writer.beginObject(); // request
259+
writer.name("data");
260+
writer.beginObject(); // data
261+
writer.name("attributes");
262+
writer.beginObject(); // attributes
263+
writer.name("messages");
264+
moshi.adapter(Object.class).toJson(writer, messages);
265+
writer.name("meta");
266+
writer.jsonValue(meta);
267+
writer.endObject(); // attributes
268+
writer.endObject(); // data
269+
writer.endObject(); // request
270+
}
271+
}
272+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package com.datadog.aiguard;
2+
3+
public abstract class AIGuardSystem {
4+
5+
private AIGuardSystem() {}
6+
7+
public static void start() {
8+
initializeSDK();
9+
}
10+
11+
private static void initializeSDK() {
12+
AIGuardInternal.install();
13+
}
14+
}

0 commit comments

Comments
 (0)