Skip to content

Commit 6d5c67d

Browse files
Added post-processing limit
1 parent 94715ff commit 6d5c67d

File tree

3 files changed

+88
-1
lines changed

3 files changed

+88
-1
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package com.datadog.appsec.api.security;
2+
3+
import java.util.concurrent.atomic.AtomicLong;
4+
5+
// Number of post-processing tasks (e.g. AppSecRequestContext keep opened)
6+
public class PostProcessingCounter extends AtomicLong {
7+
public static final long MAX_POST_PROCESSING_TASKS = 16;
8+
9+
public boolean tryIncrement() {
10+
while (true) {
11+
long current = this.get();
12+
if (current >= MAX_POST_PROCESSING_TASKS) {
13+
// Do not increment it's already at the maximum
14+
return false;
15+
}
16+
if (this.compareAndSet(current, current + 1)) {
17+
return true;
18+
}
19+
}
20+
}
21+
}

dd-java-agent/appsec/src/main/java/com/datadog/appsec/gateway/GatewayBridge.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import com.datadog.appsec.AppSecSystem;
1616
import com.datadog.appsec.api.security.ApiSecurityRequestSampler;
17+
import com.datadog.appsec.api.security.PostProcessingCounter;
1718
import com.datadog.appsec.config.TraceSegmentPostProcessor;
1819
import com.datadog.appsec.event.EventProducerService;
1920
import com.datadog.appsec.event.EventProducerService.DataSubscriberInfo;
@@ -94,6 +95,8 @@ public class GatewayBridge {
9495

9596
private static final String METASTRUCT_EXPLOIT = "exploit";
9697

98+
private final PostProcessingCounter postProcessingCounter = new PostProcessingCounter();
99+
97100
private final SubscriptionService subscriptionService;
98101
private final EventProducerService producerService;
99102
private final ApiSecurityRequestSampler requestSampler;
@@ -833,9 +836,10 @@ private NoopFlow onRequestEnded(RequestContext ctx_, IGSpanInfo spanInfo) {
833836
if (route instanceof String) {
834837
ctx.setRoute((String) route);
835838
}
836-
if (requestSampler.preSampleRequest(ctx)) {
839+
if (requestSampler.preSampleRequest(ctx) && postProcessingCounter.tryIncrement()) {
837840
// The request is pre-sampled - we need to post-process it
838841
spanInfo.setRequiresPostProcessing(true);
842+
postProcessingCounter.incrementAndGet();
839843
}
840844

841845
return NoopFlow.INSTANCE;
@@ -905,6 +909,8 @@ private void onPostProcessing(RequestContext ctx_) {
905909

906910
maybeExtractSchemas(ctx);
907911
ctx.close();
912+
// Decrease the counter to allow the next request to be post-processed
913+
postProcessingCounter.decrementAndGet();
908914
}
909915

910916
public void stop() {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package com.datadog.appsec.api.security
2+
3+
import spock.lang.Specification
4+
5+
import java.util.concurrent.CountDownLatch
6+
import java.util.concurrent.Executors
7+
8+
class PostProcessingCounterSpec extends Specification {
9+
10+
def "should increment successfully if below the limit"() {
11+
given:
12+
def counter = new PostProcessingCounter()
13+
14+
when:
15+
def result = counter.tryIncrement()
16+
17+
then:
18+
result == true
19+
counter.get() == 1
20+
}
21+
22+
def "should not increment if max limit is reached"() {
23+
given:
24+
def counter = new PostProcessingCounter()
25+
counter.set(PostProcessingCounter.MAX_POST_PROCESSING_TASKS) // Manually setting to max
26+
27+
when:
28+
def result = counter.tryIncrement()
29+
30+
then:
31+
result == false
32+
counter.get() == PostProcessingCounter.MAX_POST_PROCESSING_TASKS // Should remain unchanged
33+
}
34+
35+
def "should handle concurrent increments safely"() {
36+
given:
37+
def counter = new PostProcessingCounter()
38+
def threads = 20
39+
def executor = Executors.newFixedThreadPool(threads)
40+
def latch = new CountDownLatch(threads)
41+
def successes = Collections.synchronizedList([])
42+
43+
when:
44+
(1..threads).each {
45+
executor.submit {
46+
def result = counter.tryIncrement()
47+
if (result) {
48+
successes.add(1)
49+
}
50+
latch.countDown()
51+
}
52+
}
53+
latch.await()
54+
executor.shutdown()
55+
56+
then:
57+
successes.size() <= PostProcessingCounter.MAX_POST_PROCESSING_TASKS // Only max increments should succeed
58+
counter.get() == PostProcessingCounter.MAX_POST_PROCESSING_TASKS // Should be exactly at the limit
59+
}
60+
}

0 commit comments

Comments
 (0)