Skip to content

Commit 31dd984

Browse files
Implementation of http client request analysis for OkHttp3
1 parent c442de6 commit 31dd984

File tree

9 files changed

+343
-33
lines changed

9 files changed

+343
-33
lines changed

dd-java-agent/instrumentation-testing/src/main/groovy/datadog/trace/agent/test/base/HttpClientTest.groovy

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,12 +1021,14 @@ abstract class HttpClientTest extends VersionedNamingTestBase {
10211021
{ RequestContext rqCtxt, HttpClientRequest req ->
10221022
if (req.headers?.containsKey('X-AppSec-Test')) {
10231023
final context = rqCtxt.getData(RequestContextSlot.APPSEC) as Context
1024-
context.hasAppSecData = true
1025-
activeSpan()
1026-
.setTag('downstream.request.url', req.url)
1027-
.setTag('downstream.request.method', req.method)
1028-
.setTag('downstream.request.headers', JsonOutput.toJson(req.headers))
1029-
.setTag('downstream.request.body', req.body?.text)
1024+
if (context != null) {
1025+
context.hasAppSecData = true
1026+
activeSpan()
1027+
.setTag('downstream.request.url', req.url)
1028+
.setTag('downstream.request.method', req.method)
1029+
.setTag('downstream.request.headers', JsonOutput.toJson(req.headers))
1030+
.setTag('downstream.request.body', req.body?.text)
1031+
}
10301032

10311033
}
10321034
Flow.ResultFlow.empty()
@@ -1035,7 +1037,7 @@ abstract class HttpClientTest extends VersionedNamingTestBase {
10351037
final BiFunction<RequestContext, HttpClientResponse, Flow<Void>> httpClientResponseCb =
10361038
{ RequestContext rqCtxt, HttpClientResponse res ->
10371039
final context = rqCtxt.getData(RequestContextSlot.APPSEC) as Context
1038-
if (context.hasAppSecData) {
1040+
if (context?.hasAppSecData) {
10391041
activeSpan()
10401042
.setTag('downstream.response.status', res.status)
10411043
.setTag('downstream.response.headers', JsonOutput.toJson(res.headers))

dd-java-agent/instrumentation/okhttp/okhttp-2.2/src/main/java/datadog/trace/instrumentation/okhttp2/AppSecInterceptor.java

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,32 @@
3434
import okio.BufferedSource;
3535
import okio.Okio;
3636
import okio.Sink;
37+
import org.slf4j.Logger;
38+
import org.slf4j.LoggerFactory;
3739

3840
public class AppSecInterceptor implements Interceptor {
3941

4042
private static final int BODY_PARSING_SIZE_LIMIT = Config.get().getAppSecBodyParsingSizeLimit();
4143

44+
private static final Logger LOGGER = LoggerFactory.getLogger(AppSecInterceptor.class);
45+
4246
@Override
4347
public Response intercept(final Chain chain) throws IOException {
44-
final AgentSpan span = AgentTracer.activeSpan();
45-
final RequestContext ctx = span.getRequestContext();
46-
final long requestId = span.getSpanId();
47-
final boolean sampled = sampleRequest(ctx, requestId);
48-
final Request request = onRequest(span, sampled, chain.request());
49-
final Response response = chain.proceed(request);
50-
return onResponse(span, sampled, response);
48+
try {
49+
final AgentSpan span = AgentTracer.activeSpan();
50+
final RequestContext ctx = span == null ? null : span.getRequestContext();
51+
if (ctx == null) {
52+
return chain.proceed(chain.request());
53+
}
54+
final long requestId = span.getSpanId();
55+
final boolean sampled = sampleRequest(ctx, requestId);
56+
final Request request = onRequest(span, sampled, chain.request());
57+
final Response response = chain.proceed(request);
58+
return onResponse(span, sampled, response);
59+
} catch (final Exception e) {
60+
LOGGER.debug("Failed to intercept request", e);
61+
return chain.proceed(chain.request());
62+
}
5163
}
5264

5365
private Request onRequest(final AgentSpan span, final boolean sampled, final Request request) {
Lines changed: 232 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,232 @@
1+
package datadog.trace.instrumentation.okhttp3;
2+
3+
import static datadog.trace.api.gateway.Events.EVENTS;
4+
5+
import datadog.appsec.api.blocking.BlockingException;
6+
import datadog.trace.api.Config;
7+
import datadog.trace.api.appsec.HttpClientPayload;
8+
import datadog.trace.api.appsec.HttpClientRequest;
9+
import datadog.trace.api.appsec.HttpClientResponse;
10+
import datadog.trace.api.appsec.MediaType;
11+
import datadog.trace.api.gateway.BlockResponseFunction;
12+
import datadog.trace.api.gateway.CallbackProvider;
13+
import datadog.trace.api.gateway.Flow;
14+
import datadog.trace.api.gateway.RequestContext;
15+
import datadog.trace.api.gateway.RequestContextSlot;
16+
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
17+
import datadog.trace.bootstrap.instrumentation.api.AgentTracer;
18+
import datadog.trace.bootstrap.instrumentation.api.Tags;
19+
import java.io.ByteArrayInputStream;
20+
import java.io.ByteArrayOutputStream;
21+
import java.io.IOException;
22+
import java.util.Collections;
23+
import java.util.HashMap;
24+
import java.util.List;
25+
import java.util.Map;
26+
import java.util.function.BiFunction;
27+
import okhttp3.Headers;
28+
import okhttp3.Interceptor;
29+
import okhttp3.Request;
30+
import okhttp3.RequestBody;
31+
import okhttp3.Response;
32+
import okhttp3.ResponseBody;
33+
import okio.BufferedSink;
34+
import okio.BufferedSource;
35+
import okio.Okio;
36+
import okio.Sink;
37+
import org.slf4j.Logger;
38+
import org.slf4j.LoggerFactory;
39+
40+
public class AppSecInterceptor implements Interceptor {
41+
42+
private static final int BODY_PARSING_SIZE_LIMIT = Config.get().getAppSecBodyParsingSizeLimit();
43+
44+
private static final Logger LOGGER = LoggerFactory.getLogger(AppSecInterceptor.class);
45+
46+
@Override
47+
public Response intercept(final Chain chain) throws IOException {
48+
try {
49+
final AgentSpan span = AgentTracer.activeSpan();
50+
final RequestContext ctx = span == null ? null : span.getRequestContext();
51+
if (ctx == null) {
52+
return chain.proceed(chain.request());
53+
}
54+
final long requestId = span.getSpanId();
55+
final boolean sampled = sampleRequest(ctx, requestId);
56+
final Request request = onRequest(span, sampled, chain.request());
57+
final Response response = chain.proceed(request);
58+
return onResponse(span, sampled, response);
59+
} catch (final Exception e) {
60+
LOGGER.debug("Failed to intercept request", e);
61+
return chain.proceed(chain.request());
62+
}
63+
}
64+
65+
private Request onRequest(final AgentSpan span, final boolean sampled, final Request request) {
66+
Request result = request;
67+
CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC);
68+
BiFunction<RequestContext, HttpClientRequest, Flow<Void>> requestCb =
69+
cbp.getCallback(EVENTS.httpClientRequest());
70+
if (requestCb == null) {
71+
return request;
72+
}
73+
74+
final RequestBody requestBody = request.body();
75+
final RequestContext ctx = span.getRequestContext();
76+
final long requestId = span.getSpanId();
77+
final String url = span.getTag(Tags.HTTP_URL).toString();
78+
final HttpClientRequest clientRequest =
79+
new HttpClientRequest(requestId, url, request.method(), mapHeaders(request.headers()));
80+
if (sampled && requestBody != null) {
81+
// we are going to effectively read all the request body in memory to be analyzed by the WAF,
82+
// we also modify the outbound request accordingly
83+
final MediaType mediaType = contentType(requestBody);
84+
try {
85+
final long contentLength = requestBody.contentLength();
86+
if (shouldProcessBody(contentLength, mediaType)) {
87+
final byte[] payload = readBody(requestBody, (int) contentLength);
88+
if (payload.length <= BODY_PARSING_SIZE_LIMIT) {
89+
clientRequest.setBody(mediaType, new ByteArrayInputStream(payload));
90+
}
91+
result =
92+
request
93+
.newBuilder()
94+
.method(request.method(), RequestBody.create(requestBody.contentType(), payload))
95+
.build(); // update request
96+
}
97+
} catch (IOException e) {
98+
// ignore it and keep the original request
99+
}
100+
}
101+
publish(ctx, clientRequest, requestCb);
102+
return result;
103+
}
104+
105+
private Response onResponse(
106+
final AgentSpan span, final boolean sampled, final Response response) {
107+
Response result = response;
108+
CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC);
109+
BiFunction<RequestContext, HttpClientResponse, Flow<Void>> responseCb =
110+
cbp.getCallback(EVENTS.httpClientResponse());
111+
if (responseCb == null) {
112+
return response;
113+
}
114+
final ResponseBody responseBody = response.body();
115+
final RequestContext ctx = span.getRequestContext();
116+
final long requestId = span.getSpanId();
117+
final HttpClientResponse clientResponse =
118+
new HttpClientResponse(requestId, response.code(), mapHeaders(response.headers()));
119+
if (sampled && responseBody != null) {
120+
// we are going to effectively read all the response body in memory to be analyzed by the WAF,
121+
// we also
122+
// modify the inbound response accordingly
123+
final MediaType mediaType = contentType(responseBody);
124+
try {
125+
final long contentLength = responseBody.contentLength();
126+
if (shouldProcessBody(contentLength, mediaType)) {
127+
final byte[] payload = readBody(responseBody, (int) contentLength);
128+
if (payload.length <= BODY_PARSING_SIZE_LIMIT) {
129+
clientResponse.setBody(mediaType, new ByteArrayInputStream(payload));
130+
}
131+
result =
132+
response
133+
.newBuilder()
134+
.body(ResponseBody.create(responseBody.contentType(), payload))
135+
.build();
136+
}
137+
} catch (IOException e) {
138+
// ignore it and keep the original response
139+
}
140+
}
141+
142+
publish(ctx, clientResponse, responseCb);
143+
return result;
144+
}
145+
146+
private <P extends HttpClientPayload> void publish(
147+
final RequestContext ctx,
148+
final P request,
149+
final BiFunction<RequestContext, P, Flow<Void>> callback) {
150+
Flow<Void> flow = callback.apply(ctx, request);
151+
Flow.Action action = flow.getAction();
152+
if (action instanceof Flow.Action.RequestBlockingAction) {
153+
BlockResponseFunction brf = ctx.getBlockResponseFunction();
154+
if (brf != null) {
155+
Flow.Action.RequestBlockingAction rba = (Flow.Action.RequestBlockingAction) action;
156+
brf.tryCommitBlockingResponse(
157+
ctx.getTraceSegment(),
158+
rba.getStatusCode(),
159+
rba.getBlockingContentType(),
160+
rba.getExtraHeaders());
161+
}
162+
throw new BlockingException("Blocked request (for http downstream request)");
163+
}
164+
}
165+
166+
private boolean sampleRequest(final RequestContext ctx, final long requestId) {
167+
// Check if the current http request was sampled
168+
CallbackProvider cbp = AgentTracer.get().getCallbackProvider(RequestContextSlot.APPSEC);
169+
BiFunction<RequestContext, Long, Flow<Boolean>> samplingCb =
170+
cbp.getCallback(EVENTS.httpClientSampling());
171+
if (samplingCb == null) {
172+
return false;
173+
}
174+
final Flow<Boolean> sampled = samplingCb.apply(ctx, requestId);
175+
return sampled.getResult() != null && sampled.getResult();
176+
}
177+
178+
/**
179+
* Ensure we are only consuming payloads we can safely deserialize with a bounded size to prevent
180+
* from OOM
181+
*/
182+
private boolean shouldProcessBody(final long contentLength, final MediaType mediaType) {
183+
if (contentLength <= 0) {
184+
return false; // prevent from copying from unbounded source (just to be safe)
185+
}
186+
if (BODY_PARSING_SIZE_LIMIT <= 0) {
187+
return false; // effectively disabled by configuration
188+
}
189+
if (contentLength > BODY_PARSING_SIZE_LIMIT) {
190+
return false;
191+
}
192+
return mediaType.isDeserializable();
193+
}
194+
195+
private byte[] readBody(final RequestBody body, final int contentLength) throws IOException {
196+
final ByteArrayOutputStream buffer = new ByteArrayOutputStream(contentLength);
197+
try (final BufferedSink sink = Okio.buffer(Okio.sink(buffer))) {
198+
body.writeTo(sink);
199+
}
200+
return buffer.toByteArray();
201+
}
202+
203+
private byte[] readBody(final ResponseBody body, final int contentLength) throws IOException {
204+
final ByteArrayOutputStream buffer = new ByteArrayOutputStream(contentLength);
205+
try (final BufferedSource source = body.source();
206+
final Sink sink = Okio.sink(buffer)) {
207+
source.readAll(sink);
208+
}
209+
return buffer.toByteArray();
210+
}
211+
212+
private Map<String, List<String>> mapHeaders(final Headers headers) {
213+
if (headers == null) {
214+
return Collections.emptyMap();
215+
}
216+
final Map<String, List<String>> result = new HashMap<>(headers.size());
217+
for (final String name : headers.names()) {
218+
result.put(name, headers.values(name));
219+
}
220+
return result;
221+
}
222+
223+
private MediaType contentType(final RequestBody body) {
224+
return MediaType.parse(
225+
body == null || body.contentType() == null ? null : body.contentType().toString());
226+
}
227+
228+
private MediaType contentType(final ResponseBody body) {
229+
return MediaType.parse(
230+
body == null || body.contentType() == null ? null : body.contentType().toString());
231+
}
232+
}

dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttp3Instrumentation.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import com.google.auto.service.AutoService;
88
import datadog.trace.agent.tooling.Instrumenter;
99
import datadog.trace.agent.tooling.InstrumenterModule;
10+
import datadog.trace.bootstrap.ActiveSubsystems;
1011
import net.bytebuddy.asm.Advice;
1112
import okhttp3.Interceptor;
1213
import okhttp3.OkHttpClient;
@@ -30,6 +31,7 @@ public String[] helperClassNames() {
3031
packageName + ".RequestBuilderInjectAdapter",
3132
packageName + ".OkHttpClientDecorator",
3233
packageName + ".TracingInterceptor",
34+
packageName + ".AppSecInterceptor",
3335
};
3436
}
3537

@@ -51,6 +53,9 @@ public static void addTracingInterceptor(
5153
}
5254
final TracingInterceptor interceptor = new TracingInterceptor();
5355
builder.addInterceptor(interceptor);
56+
if (ActiveSubsystems.APPSEC_ACTIVE) {
57+
builder.addInterceptor(new AppSecInterceptor());
58+
}
5459
}
5560
}
5661
}

dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/main/java/datadog/trace/instrumentation/okhttp3/OkHttpClientDecorator.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package datadog.trace.instrumentation.okhttp3;
22

3+
import datadog.trace.bootstrap.instrumentation.api.AgentSpan;
34
import datadog.trace.bootstrap.instrumentation.api.UTF8BytesString;
45
import datadog.trace.bootstrap.instrumentation.decorator.HttpClientDecorator;
56
import java.net.URI;
@@ -58,4 +59,10 @@ protected String getRequestHeader(Request request, String headerName) {
5859
protected String getResponseHeader(Response response, String headerName) {
5960
return response.header(headerName);
6061
}
62+
63+
/** Overridden by {@link AppSecInterceptor} */
64+
@Override
65+
protected void onHttpClientRequest(AgentSpan span, String url) {
66+
// do nothing
67+
}
6168
}

dd-java-agent/instrumentation/okhttp/okhttp-3.0/src/test/groovy/OkHttp3AsyncTest.groovy

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ import static java.util.concurrent.TimeUnit.SECONDS
1919
abstract class OkHttp3AsyncTest extends OkHttp3Test {
2020
@Override
2121
int doRequest(String method, URI uri, Map<String, String> headers, String body, Closure callback) {
22-
def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse("text/plain"), body) : null
22+
final contentType = headers.remove("Content-Type")
23+
def reqBody = HttpMethod.requiresRequestBody(method) ? RequestBody.create(MediaType.parse(contentType ?: "text/plain"), body) : null
2324
def request = new Request.Builder()
2425
.url(uri.toURL())
2526
.method(method, reqBody)
@@ -33,13 +34,13 @@ abstract class OkHttp3AsyncTest extends OkHttp3Test {
3334
client.newCall(request).enqueue(new Callback() {
3435
void onResponse(Call call, Response response) {
3536
responseRef.set(response)
36-
callback?.call()
37+
callback?.call(response.body().byteStream())
3738
latch.countDown()
3839
}
3940

4041
void onFailure(Call call, IOException e) {
4142
exRef.set(e)
42-
callback?.call()
43+
callback?.call(e)
4344
latch.countDown()
4445
}
4546
})

0 commit comments

Comments
 (0)