Skip to content

Commit

Permalink
Add support for conditional Transient header propagation (opensearch-…
Browse files Browse the repository at this point in the history
…project#11490)

* Clear transient header from system context

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Clear transient header from system context

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Adds changelog

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Update CHANGELOG.md

Co-authored-by: Andriy Redko <drreta@gmail.com>
Signed-off-by: Gagan Juneja <gagandeepjuneja@gmail.com>

* Adds unit tests

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Refactor code

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Refactor code

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Refactor code

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Supress warning

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

* Refactor code

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>

---------

Signed-off-by: Gagan Juneja <gjjuneja@amazon.com>
Signed-off-by: Gagan Juneja <gagandeepjuneja@gmail.com>
Co-authored-by: Gagan Juneja <gjjuneja@amazon.com>
Co-authored-by: Andriy Redko <drreta@gmail.com>
Signed-off-by: Shivansh Arora <hishiv@amazon.com>
  • Loading branch information
3 people authored and shiv0408 committed Apr 25, 2024
1 parent de71cae commit 8aec65f
Show file tree
Hide file tree
Showing 9 changed files with 193 additions and 28 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix template setting override for replication type ([#11417](https://github.com/opensearch-project/OpenSearch/pull/11417))
- Fix Automatic addition of protocol broken in #11512 ([#11609](https://github.com/opensearch-project/OpenSearch/pull/11609))
- Fix issue when calling Delete PIT endpoint and no PITs exist ([#11711](https://github.com/opensearch-project/OpenSearch/pull/11711))
- Fix tracing context propagation for local transport instrumentation ([#11490](https://github.com/opensearch-project/OpenSearch/pull/11490))

### Security

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public StoredContext stashContext() {
);
}

final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(context.transientHeaders, context.isSystemContext);
if (!transientHeaders.isEmpty()) {
threadContextStruct = threadContextStruct.putTransient(transientHeaders);
}
Expand All @@ -182,7 +182,7 @@ public StoredContext stashContext() {
public Writeable captureAsWriteable() {
final ThreadContextStruct context = threadLocal.get();
return out -> {
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
};
}
Expand Down Expand Up @@ -245,7 +245,7 @@ public StoredContext newStoredContext(boolean preserveResponseHeaders, Collectio
final Map<String, Object> newTransientHeaders = new HashMap<>(originalContext.transientHeaders);

boolean transientHeadersModified = false;
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders);
final Map<String, Object> transientHeaders = propagateTransients(originalContext.transientHeaders, originalContext.isSystemContext);
if (!transientHeaders.isEmpty()) {
newTransientHeaders.putAll(transientHeaders);
transientHeadersModified = true;
Expand Down Expand Up @@ -322,7 +322,7 @@ public Supplier<StoredContext> wrapRestorable(StoredContext storedContext) {
@Override
public void writeTo(StreamOutput out) throws IOException {
final ThreadContextStruct context = threadLocal.get();
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders);
final Map<String, String> propagatedHeaders = propagateHeaders(context.transientHeaders, context.isSystemContext);
context.writeTo(out, defaultHeader, propagatedHeaders);
}

Expand Down Expand Up @@ -534,7 +534,7 @@ boolean isDefaultContext() {
* by the system itself rather than by a user action.
*/
public void markAsSystemContext() {
threadLocal.set(threadLocal.get().setSystemContext());
threadLocal.set(threadLocal.get().setSystemContext(propagators));
}

/**
Expand Down Expand Up @@ -573,15 +573,15 @@ public static Map<String, String> buildDefaultHeaders(Settings settings) {
}
}

private Map<String, Object> propagateTransients(Map<String, Object> source) {
private Map<String, Object> propagateTransients(Map<String, Object> source, boolean isSystemContext) {
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(source)));
propagators.forEach(p -> transients.putAll(p.transients(source, isSystemContext)));
return transients;
}

private Map<String, String> propagateHeaders(Map<String, Object> source) {
private Map<String, String> propagateHeaders(Map<String, Object> source, boolean isSystemContext) {
final Map<String, String> headers = new HashMap<>();
propagators.forEach(p -> headers.putAll(p.headers(source)));
propagators.forEach(p -> headers.putAll(p.headers(source, isSystemContext)));
return headers;
}

Expand All @@ -603,11 +603,13 @@ private static final class ThreadContextStruct {
// saving current warning headers' size not to recalculate the size with every new warning header
private final long warningHeadersSize;

private ThreadContextStruct setSystemContext() {
private ThreadContextStruct setSystemContext(final List<ThreadContextStatePropagator> propagators) {
if (isSystemContext) {
return this;
}
return new ThreadContextStruct(requestHeaders, responseHeaders, transientHeaders, persistentHeaders, true);
final Map<String, Object> transients = new HashMap<>();
propagators.forEach(p -> transients.putAll(p.transients(transientHeaders, true)));
return new ThreadContextStruct(requestHeaders, responseHeaders, transients, persistentHeaders, true);
}

private ThreadContextStruct(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,41 @@
public interface ThreadContextStatePropagator {
/**
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
* @param source current context transient headers
*
* @param source current context transient headers
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
@Deprecated(since = "2.12.0", forRemoval = true)
Map<String, Object> transients(Map<String, Object> source);

/**
* Returns the list of transient headers that needs to be propagated from current context to new thread context.
*
* @param source current context transient headers
* @param isSystemContext if the propagation is for system context.
* @return the list of transient headers that needs to be propagated from current context to new thread context
*/
default Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
return transients(source);
};

/**
* Returns the list of request headers that needs to be propagated from current context to request.
* @param source current context headers
*
* @param source current context headers
* @return the list of request headers that needs to be propagated from current context to request
*/
@Deprecated(since = "2.12.0", forRemoval = true)
Map<String, String> headers(Map<String, Object> source);

/**
* Returns the list of request headers that needs to be propagated from current context to request.
*
* @param source current context headers
* @param isSystemContext if the propagation is for system context.
* @return the list of request headers that needs to be propagated from current context to request
*/
default Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return headers(source);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
* Propagates TASK_ID across thread contexts
*/
public class TaskThreadContextStatePropagator implements ThreadContextStatePropagator {

@Override
@SuppressWarnings("removal")
public Map<String, Object> transients(Map<String, Object> source) {
final Map<String, Object> transients = new HashMap<>();

Expand All @@ -32,7 +34,18 @@ public Map<String, Object> transients(Map<String, Object> source) {
}

@Override
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
return transients(source);
}

@Override
@SuppressWarnings("removal")
public Map<String, String> headers(Map<String, Object> source) {
return Collections.emptyMap();
}

@Override
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return headers(source);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.concurrent.ThreadContextStatePropagator;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -50,20 +51,29 @@ public void put(String key, Span span) {
}

@Override
@SuppressWarnings("removal")
public Map<String, Object> transients(Map<String, Object> source) {
final Map<String, Object> transients = new HashMap<>();

if (source.containsKey(CURRENT_SPAN)) {
final SpanReference current = (SpanReference) source.get(CURRENT_SPAN);
if (current != null) {
transients.put(CURRENT_SPAN, new SpanReference(current.getSpan()));
}
}

return transients;
}

@Override
public Map<String, Object> transients(Map<String, Object> source, boolean isSystemContext) {
if (isSystemContext == true) {
return Collections.emptyMap();
} else {
return transients(source);
}
}

@Override
@SuppressWarnings("removal")
public Map<String, String> headers(Map<String, Object> source) {
final Map<String, String> headers = new HashMap<>();

Expand All @@ -77,6 +87,11 @@ public Map<String, String> headers(Map<String, Object> source) {
return headers;
}

@Override
public Map<String, String> headers(Map<String, Object> source, boolean isSystemContext) {
return headers(source);
}

Span getCurrentSpan(String key) {
SpanReference currentSpanRef = threadContext.getTransient(key);
return (currentSpanRef == null) ? null : currentSpanRef.getSpan();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -868,19 +868,10 @@ public final <T extends TransportResponse> void sendRequest(
final TransportRequestOptions options,
final TransportResponseHandler<T> handler
) {
if (connection == localNodeConnection) {
// See please https://github.com/opensearch-project/OpenSearch/issues/10291
sendRequestAsync(connection, action, request, options, handler);
} else {
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(
handler,
span,
tracer
);
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
}
final Span span = tracer.startSpan(SpanBuilder.from(action, connection));
try (SpanScope spanScope = tracer.withSpanInScope(span)) {
TransportResponseHandler<T> traceableTransportResponseHandler = TraceableTransportResponseHandler.create(handler, span, tracer);
sendRequestAsync(connection, action, request, options, traceableTransportResponseHandler);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import java.util.Map;
import java.util.function.Supplier;

import org.mockito.Mockito;

import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasItem;
Expand Down Expand Up @@ -740,6 +742,71 @@ public void testMarkAsSystemContext() throws IOException {
assertFalse(threadContext.isSystemContext());
}

public void testSystemContextWithPropagator() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
Map<String, Object> transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> headerMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, String> headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
ThreadContext threadContext = new ThreadContext(build);
ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class);
Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap());
Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap);

Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap);
Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap);
threadContext.registerThreadContextStatePropagator(mockPropagator);
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", 1);
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("bar", threadContext.getHeader("foo"));
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
threadContext.markAsSystemContext();
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

assertEquals("bar", threadContext.getHeader("foo"));
assertEquals(Integer.valueOf(1), threadContext.getTransient("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

public void testSerializeSystemContext() throws IOException {
Settings build = Settings.builder().put("request.headers.default", "1").build();
Map<String, Object> transientHeaderMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> transientHeaderTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, Object> headerMap = Collections.singletonMap("test_transient_propagation_key", "test");
Map<String, String> headerTransformedMap = Collections.singletonMap("test_transient_propagation_key", "test");
ThreadContext threadContext = new ThreadContext(build);
ThreadContextStatePropagator mockPropagator = Mockito.mock(ThreadContextStatePropagator.class);
Mockito.when(mockPropagator.transients(transientHeaderMap, true)).thenReturn(Collections.emptyMap());
Mockito.when(mockPropagator.transients(transientHeaderMap, false)).thenReturn(transientHeaderTransformedMap);

Mockito.when(mockPropagator.headers(headerMap, true)).thenReturn(headerTransformedMap);
Mockito.when(mockPropagator.headers(headerMap, false)).thenReturn(headerTransformedMap);
threadContext.registerThreadContextStatePropagator(mockPropagator);
threadContext.putHeader("foo", "bar");
threadContext.putTransient("test_transient_propagation_key", "test");
BytesStreamOutput out = new BytesStreamOutput();
BytesStreamOutput outFromSystemContext = new BytesStreamOutput();
threadContext.writeTo(out);
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.markAsSystemContext();
threadContext.writeTo(outFromSystemContext);
assertNull(threadContext.getHeader("foo"));
assertNull(threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(outFromSystemContext.bytes().streamInput());
assertNull(threadContext.getHeader("test_transient_propagation_key"));
}
assertEquals("test", threadContext.getTransient("test_transient_propagation_key"));
threadContext.readHeaders(out.bytes().streamInput());
assertEquals("bar", threadContext.getHeader("foo"));
assertEquals("test", threadContext.getHeader("test_transient_propagation_key"));
assertEquals("1", threadContext.getHeader("default"));
}

public void testPutHeaders() {
Settings build = Settings.builder().put("request.headers.default", "1").build();
ThreadContext threadContext = new ThreadContext(build);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.tasks;

import org.opensearch.test.OpenSearchTestCase;

import java.util.HashMap;
import java.util.Map;

import static org.opensearch.tasks.TaskResourceTrackingService.TASK_ID;

public class TaskThreadContextStatePropagatorTests extends OpenSearchTestCase {
private final TaskThreadContextStatePropagator taskThreadContextStatePropagator = new TaskThreadContextStatePropagator();

public void testTransient() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, false);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}

public void testTransientForSystemContext() {
Map<String, Object> transientHeader = new HashMap<>();
transientHeader.put(TASK_ID, "t_1");
Map<String, Object> transientPropagatedHeader = taskThreadContextStatePropagator.transients(transientHeader, true);
assertEquals("t_1", transientPropagatedHeader.get(TASK_ID));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -252,4 +252,20 @@ public void run() {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}

public void testSpanNotPropagatedToChildSystemThreadContext() {
final Span span = tracer.startSpan(SpanCreationContext.internal().name("test"));

try (SpanScope scope = tracer.withSpanInScope(span)) {
try (StoredContext ignored = threadContext.stashContext()) {
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(span));
threadContext.markAsSystemContext();
assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}

assertThat(threadContext.getTransient(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(not(nullValue())));
assertThat(threadContextStorage.get(ThreadContextBasedTracerContextStorage.CURRENT_SPAN), is(nullValue()));
}
}

0 comments on commit 8aec65f

Please sign in to comment.