Skip to content
Open
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import org.opensearch.javaagent.bootstrap.AgentPolicy;

import java.lang.StackWalker.Option;
import java.lang.reflect.Method;
import java.net.InetSocketAddress;
import java.net.NetPermission;
Expand Down Expand Up @@ -46,7 +45,7 @@ public static void intercept(@Advice.AllArguments Object[] args, @Origin Method
return; /* noop */
}

final StackWalker walker = StackWalker.getInstance(Option.RETAIN_CLASS_REFERENCE);
final StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE);
final Collection<ProtectionDomain> callers = walker.walk(StackCallerProtectionDomainChainExtractor.INSTANCE);

if (args[0] instanceof InetSocketAddress address) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ private StackCallerProtectionDomainChainExtractor() {}
*/
@Override
public Collection<ProtectionDomain> apply(Stream<StackFrame> frames) {
return frames.map(StackFrame::getDeclaringClass)
return frames.takeWhile(
frame -> !(frame.getClassName().equals("java.security.AccessController") && frame.getMethodName().equals("doPrivileged"))
)
.map(StackFrame::getDeclaringClass)
.map(Class::getProtectionDomain)
.filter(pd -> pd.getCodeSource() != null) /* JDK */
.filter(pd -> pd.getCodeSource() != null) // Filter out JDK classes
.collect(Collectors.toSet());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.javaagent;

import org.junit.Test;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.ProtectionDomain;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.hasItem;
import static org.junit.Assert.assertEquals;

public class StackCallerProtectionDomainExtractorTests {

private static List<StackWalker.StackFrame> indirectlyCaptureStackFrames() {
return captureStackFrames();
}

private static List<StackWalker.StackFrame> captureStackFrames() {
// OPTION.RETAIN_CLASS_REFERENCE lets you do f.getDeclaringClass() if you need it
StackWalker walker = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE);
return walker.walk(frames -> frames.collect(Collectors.toList()));
}

@Test
public void testSimpleProtectionDomainExtraction() throws Exception {
StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE;
Set<ProtectionDomain> protectionDomains = (Set<ProtectionDomain>) extractor.apply(captureStackFrames().stream());
assertEquals(7, protectionDomains.size());
List<String> simpleNames = protectionDomains.stream().map(pd -> {
try {
return pd.getCodeSource().getLocation().toURI();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
})
.map(URI::getPath)
.map(Paths::get)
.map(Path::getFileName)
.map(Path::toString)
// strip trailing “-VERSION.jar” if present
.map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", ""))
// otherwise strip “.jar”
.map(name -> name.replaceFirst("\\.jar$", ""))
.toList();
assertThat(
simpleNames,
containsInAnyOrder(
"gradle-worker",
"gradle-worker-main",
"gradle-messaging",
"gradle-testing-base-infrastructure",
"test", // from the build/classes/java/test directory
"junit",
"gradle-testing-jvm-infrastructure"
)
);
}

@Test
public void testIndirectlyCaptureStackFramesInListOfFrames() throws Exception {
List<StackWalker.StackFrame> stackFrames = indirectlyCaptureStackFrames();
List<String> methodNames = stackFrames.stream().map(StackWalker.StackFrame::getMethodName).toList();
assertThat(methodNames, hasItem("indirectlyCaptureStackFrames"));
}

@Test
@SuppressWarnings("removal")
public void testStackTruncationWithAccessController() throws Exception {
AccessController.doPrivileged(new PrivilegedAction<Void>() {
@Override
public Void run() {
StackCallerProtectionDomainChainExtractor extractor = StackCallerProtectionDomainChainExtractor.INSTANCE;
Set<ProtectionDomain> protectionDomains = (Set<ProtectionDomain>) extractor.apply(captureStackFrames().stream());
assertEquals(1, protectionDomains.size());
List<String> simpleNames = protectionDomains.stream().map(pd -> {
try {
return pd.getCodeSource().getLocation().toURI();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
})
.map(URI::getPath)
.map(Paths::get)
.map(Path::getFileName)
.map(Path::toString)
// strip trailing “-VERSION.jar” if present
.map(name -> name.replaceFirst("-\\d[\\d\\.]*\\.jar$", ""))
// otherwise strip “.jar”
.map(name -> name.replaceFirst("\\.jar$", ""))
.toList();
assertThat(
simpleNames,
containsInAnyOrder(
"test" // from the build/classes/java/test directory
)
);
return null;
}
});
}
}