Skip to content

Commit

Permalink
Prevent infinite recursion in rootCause() condition
Browse files Browse the repository at this point in the history
See #3839
  • Loading branch information
sbrannen committed Jun 2, 2024
1 parent 71c3b05 commit 6c5a9a0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import static org.apiguardian.api.API.Status.MAINTAINED;
import static org.junit.platform.commons.util.FunctionUtils.where;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
Expand Down Expand Up @@ -166,8 +167,9 @@ private static Condition<Throwable> cause(Condition<Throwable> condition) {
private static Condition<Throwable> rootCause(Condition<Throwable> condition) {
Predicate<Throwable> predicate = throwable -> {
Preconditions.notNull(throwable, "Throwable must not be null");
Throwable cause = Preconditions.notNull(throwable.getCause(), "Throwable does not have a cause");
return condition.matches(getRootCause(cause));
Preconditions.notNull(throwable.getCause(), "Throwable does not have a cause");
Throwable rootCause = getRootCause(throwable, new ArrayList<>());
return condition.matches(rootCause);
};
return new Condition<>(predicate, "throwable root cause matches %s", condition);
}
Expand All @@ -176,12 +178,20 @@ private static Condition<Throwable> rootCause(Condition<Throwable> condition) {
* Get the root cause of the supplied {@link Throwable}, or the supplied
* {@link Throwable} if it has no cause.
*/
private static Throwable getRootCause(Throwable throwable) {
private static Throwable getRootCause(Throwable throwable, List<Throwable> causeChain) {
// If we have already seen the current Throwable, that means we have
// encountered recursion in the cause chain and therefore return the last
// Throwable in the cause chain, which was the root cause before the recursion.
if (causeChain.contains(throwable)) {
return causeChain.get(causeChain.size() - 1);
}
Throwable cause = throwable.getCause();
if (cause == null) {
return throwable;
}
return getRootCause(cause);
// Track current Throwable before recursing.
causeChain.add(throwable);
return getRootCause(cause, causeChain);
}

private static Condition<Throwable> suppressed(int index, Condition<Throwable> condition) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,42 @@ void rootCauseDoesNotMatchForRootCauseWithDifferentMessage() {
assertThat(rootCauseCondition.matches(throwable)).isFalse();
}

@Test
void rootCauseMatchesForRootCauseWithExpectedMessageAndSingleLevelRecursiveCauseChain() {
RuntimeException rootCause = new RuntimeException(EXPECTED);
Throwable throwable = new Throwable(rootCause);
rootCause.initCause(throwable);

assertThat(rootCauseCondition.matches(throwable)).isTrue();
}

@Test
void rootCauseDoesNotMatchForRootCauseWithDifferentMessageAndSingleLevelRecursiveCauseChain() {
RuntimeException rootCause = new RuntimeException(UNEXPECTED);
Throwable throwable = new Throwable(rootCause);
rootCause.initCause(throwable);

assertThat(rootCauseCondition.matches(throwable)).isFalse();
}

@Test
void rootCauseMatchesForRootCauseWithExpectedMessageAndDoubleLevelRecursiveCauseChain() {
RuntimeException rootCause = new RuntimeException(EXPECTED);
Exception intermediateCause = new Exception("intermediate cause", rootCause);
Throwable throwable = new Throwable(intermediateCause);
rootCause.initCause(throwable);

assertThat(rootCauseCondition.matches(throwable)).isTrue();
}

@Test
void rootCauseDoesNotMatchForRootCauseWithDifferentMessageAndDoubleLevelRecursiveCauseChain() {
RuntimeException rootCause = new RuntimeException(UNEXPECTED);
Exception intermediateCause = new Exception("intermediate cause", rootCause);
Throwable throwable = new Throwable(intermediateCause);
rootCause.initCause(throwable);

assertThat(rootCauseCondition.matches(throwable)).isFalse();
}

}

0 comments on commit 6c5a9a0

Please sign in to comment.