Skip to content

Commit

Permalink
Declarative ordering of @Rule/@ClassRule rules
Browse files Browse the repository at this point in the history
  • Loading branch information
panchenko authored and kcooney committed Aug 7, 2017
1 parent bb48ff9 commit aad22b8
Show file tree
Hide file tree
Showing 16 changed files with 554 additions and 61 deletions.
8 changes: 8 additions & 0 deletions src/main/java/org/junit/ClassRule.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,12 @@
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD, ElementType.METHOD})
public @interface ClassRule {

/**
* Specifies the order in which rules are applied. The rules with a higher value are inner.
*
* @since 4.13
*/
int order() default Rule.DEFAULT_ORDER;

}
9 changes: 9 additions & 0 deletions src/main/java/org/junit/Rule.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,13 @@
@Target({ElementType.FIELD, ElementType.METHOD})
public @interface Rule {

int DEFAULT_ORDER = -1;

/**
* Specifies the order in which rules are applied. The rules with a higher value are inner.
*
* @since 4.13
*/
int order() default DEFAULT_ORDER;

}
2 changes: 1 addition & 1 deletion src/main/java/org/junit/internal/Throwables.java
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ private static boolean isTestFrameworkMethod(String methodName) {
"java.lang.reflect.",
"org.junit.rules.RunRules.<init>(",
"org.junit.rules.RunRules.applyAll(", // calls TestRules
"org.junit.runners.BlockJUnit4ClassRunner.withMethodRules(", // calls MethodRules
"org.junit.runners.RuleContainer.apply(", // calls MethodRules & TestRules
"junit.framework.TestCase.runBare(", // runBare() directly calls setUp() and tearDown()
};

Expand Down
96 changes: 46 additions & 50 deletions src/main/java/org/junit/runners/BlockJUnit4ClassRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static org.junit.internal.runners.rules.RuleMemberValidator.RULE_METHOD_VALIDATOR;
import static org.junit.internal.runners.rules.RuleMemberValidator.RULE_VALIDATOR;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand All @@ -22,12 +23,13 @@
import org.junit.internal.runners.statements.RunAfters;
import org.junit.internal.runners.statements.RunBefores;
import org.junit.rules.MethodRule;
import org.junit.rules.RunRules;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.model.FrameworkMember;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.MemberValueConsumer;
import org.junit.runners.model.MultipleFailureException;
import org.junit.runners.model.Statement;
import org.junit.runners.model.TestClass;
Expand Down Expand Up @@ -390,29 +392,23 @@ protected Statement withAfters(FrameworkMethod method, Object target,
target);
}

private Statement withRules(FrameworkMethod method, Object target,
Statement statement) {
List<TestRule> testRules = getTestRules(target);
Statement result = statement;
result = withMethodRules(method, testRules, target, result);
result = withTestRules(method, testRules, result);

return result;
}

private Statement withMethodRules(FrameworkMethod method, List<TestRule> testRules,
Object target, Statement result) {
Statement withMethodRules = result;
for (org.junit.rules.MethodRule each : getMethodRules(target)) {
if (!(each instanceof TestRule && testRules.contains(each))) {
withMethodRules = each.apply(withMethodRules, method, target);
private Statement withRules(FrameworkMethod method, Object target, Statement statement) {
RuleContainer ruleContainer = new RuleContainer();
CURRENT_RULE_CONTAINER.set(ruleContainer);
try {
List<TestRule> testRules = getTestRules(target);
for (MethodRule each : rules(target)) {
if (!(each instanceof TestRule && testRules.contains(each))) {
ruleContainer.add(each);
}
}
for (TestRule rule : testRules) {
ruleContainer.add(rule);
}
} finally {
CURRENT_RULE_CONTAINER.remove();
}
return withMethodRules;
}

private List<org.junit.rules.MethodRule> getMethodRules(Object target) {
return rules(target);
return ruleContainer.apply(method, describeChild(method), target, statement);
}

/**
Expand All @@ -421,27 +417,12 @@ private List<org.junit.rules.MethodRule> getMethodRules(Object target) {
* test
*/
protected List<MethodRule> rules(Object target) {
List<MethodRule> rules = getTestClass().getAnnotatedMethodValues(target,
Rule.class, MethodRule.class);

rules.addAll(getTestClass().getAnnotatedFieldValues(target,
Rule.class, MethodRule.class));

return rules;
}

/**
* Returns a {@link Statement}: apply all non-static fields
* annotated with {@link Rule}.
*
* @param statement The base statement
* @return a RunRules statement if any class-level {@link Rule}s are
* found, or the base statement
*/
private Statement withTestRules(FrameworkMethod method, List<TestRule> testRules,
Statement statement) {
return testRules.isEmpty() ? statement :
new RunRules(statement, testRules, describeChild(method));
RuleCollector<MethodRule> collector = new RuleCollector<MethodRule>();
getTestClass().collectAnnotatedMethodValues(target, Rule.class, MethodRule.class,
collector);
getTestClass().collectAnnotatedFieldValues(target, Rule.class, MethodRule.class,
collector);
return collector.result;
}

/**
Expand All @@ -450,13 +431,10 @@ private Statement withTestRules(FrameworkMethod method, List<TestRule> testRules
* test
*/
protected List<TestRule> getTestRules(Object target) {
List<TestRule> result = getTestClass().getAnnotatedMethodValues(target,
Rule.class, TestRule.class);

result.addAll(getTestClass().getAnnotatedFieldValues(target,
Rule.class, TestRule.class));

return result;
RuleCollector<TestRule> collector = new RuleCollector<TestRule>();
getTestClass().collectAnnotatedMethodValues(target, Rule.class, TestRule.class, collector);
getTestClass().collectAnnotatedFieldValues(target, Rule.class, TestRule.class, collector);
return collector.result;
}

private Class<? extends Throwable> getExpectedException(Test annotation) {
Expand All @@ -473,4 +451,22 @@ private long getTimeout(Test annotation) {
}
return annotation.timeout();
}

private static final ThreadLocal<RuleContainer> CURRENT_RULE_CONTAINER =
new ThreadLocal<RuleContainer>();

private static class RuleCollector<T> implements MemberValueConsumer<T> {
final List<T> result = new ArrayList<T>();

public void accept(FrameworkMember member, T value) {
Rule rule = member.getAnnotation(Rule.class);
if (rule != null) {
RuleContainer container = CURRENT_RULE_CONTAINER.get();
if (container != null) {
container.setOrder(value, rule.order());
}
}
result.add(value);
}
}
}
31 changes: 28 additions & 3 deletions src/main/java/org/junit/runners/ParentRunner.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@
import org.junit.runner.manipulation.Sorter;
import org.junit.runner.notification.RunNotifier;
import org.junit.runner.notification.StoppedByUserException;
import org.junit.runners.model.FrameworkMember;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InitializationError;
import org.junit.runners.model.InvalidTestClassError;
import org.junit.runners.model.MemberValueConsumer;
import org.junit.runners.model.RunnerScheduler;
import org.junit.runners.model.Statement;
import org.junit.runners.model.TestClass;
Expand Down Expand Up @@ -269,9 +271,10 @@ private Statement withClassRules(Statement statement) {
* each method in the tested class.
*/
protected List<TestRule> classRules() {
List<TestRule> result = testClass.getAnnotatedMethodValues(null, ClassRule.class, TestRule.class);
result.addAll(testClass.getAnnotatedFieldValues(null, ClassRule.class, TestRule.class));
return result;
ClassRuleCollector collector = new ClassRuleCollector();
testClass.collectAnnotatedMethodValues(null, ClassRule.class, TestRule.class, collector);
testClass.collectAnnotatedFieldValues(null, ClassRule.class, TestRule.class, collector);
return collector.getOrderedRules();
}

/**
Expand Down Expand Up @@ -487,4 +490,26 @@ public int compare(T o1, T o2) {
public void setScheduler(RunnerScheduler scheduler) {
this.scheduler = scheduler;
}

private static class ClassRuleCollector implements MemberValueConsumer<TestRule> {
final List<RuleContainer.RuleEntry> entries = new ArrayList<RuleContainer.RuleEntry>();

public void accept(FrameworkMember member, TestRule value) {
ClassRule rule = member.getAnnotation(ClassRule.class);
entries.add(new RuleContainer.RuleEntry(value, RuleContainer.RuleEntry.TYPE_TEST_RULE,
rule != null ? rule.order() : null));
}

public List<TestRule> getOrderedRules() {
if (entries.isEmpty()) {
return Collections.emptyList();
}
Collections.sort(entries, RuleContainer.ENTRY_COMPARATOR);
List<TestRule> result = new ArrayList<TestRule>(entries.size());
for (RuleContainer.RuleEntry entry : entries) {
result.add((TestRule) entry.rule);
}
return result;
}
}
}
113 changes: 113 additions & 0 deletions src/main/java/org/junit/runners/RuleContainer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package org.junit.runners;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.List;

import org.junit.Rule;
import org.junit.rules.MethodRule;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.Statement;

/**
* Data structure for ordering of {@link TestRule}/{@link MethodRule} instances.
*
* @since 4.13
*/
class RuleContainer {
private final IdentityHashMap<Object, Integer> orderValues = new IdentityHashMap<Object, Integer>();
private final List<TestRule> testRules = new ArrayList<TestRule>();
private final List<MethodRule> methodRules = new ArrayList<MethodRule>();

/**
* Sets order value for the specified rule.
*/
public void setOrder(Object rule, int order) {
orderValues.put(rule, order);
}

public void add(MethodRule methodRule) {
methodRules.add(methodRule);
}

public void add(TestRule testRule) {
testRules.add(testRule);
}

static final Comparator<RuleEntry> ENTRY_COMPARATOR = new Comparator<RuleEntry>() {
public int compare(RuleEntry o1, RuleEntry o2) {
int result = compareInt(o1.order, o2.order);
return result != 0 ? result : o1.type - o2.type;
}

private int compareInt(int a, int b) {
return (a < b) ? 1 : (a == b ? 0 : -1);
}
};

/**
* Returns entries in the order how they should be applied, i.e. inner-to-outer.
*/
private List<RuleEntry> getSortedEntries() {
List<RuleEntry> ruleEntries = new ArrayList<RuleEntry>(
methodRules.size() + testRules.size());
for (MethodRule rule : methodRules) {
ruleEntries.add(new RuleEntry(rule, RuleEntry.TYPE_METHOD_RULE, orderValues.get(rule)));
}
for (TestRule rule : testRules) {
ruleEntries.add(new RuleEntry(rule, RuleEntry.TYPE_TEST_RULE, orderValues.get(rule)));
}
Collections.sort(ruleEntries, ENTRY_COMPARATOR);
return ruleEntries;
}

/**
* Applies all the rules ordered accordingly to the specified {@code statement}.
*/
public Statement apply(FrameworkMethod method, Description description, Object target,
Statement statement) {
if (methodRules.isEmpty() && testRules.isEmpty()) {
return statement;
}
Statement result = statement;
for (RuleEntry ruleEntry : getSortedEntries()) {
if (ruleEntry.type == RuleEntry.TYPE_TEST_RULE) {
result = ((TestRule) ruleEntry.rule).apply(result, description);
} else {
result = ((MethodRule) ruleEntry.rule).apply(result, method, target);
}
}
return result;
}

/**
* Returns rule instances in the order how they should be applied, i.e. inner-to-outer.
* VisibleForTesting
*/
List<Object> getSortedRules() {
List<Object> result = new ArrayList<Object>();
for (RuleEntry entry : getSortedEntries()) {
result.add(entry.rule);
}
return result;
}

static class RuleEntry {
static final int TYPE_TEST_RULE = 1;
static final int TYPE_METHOD_RULE = 0;

final Object rule;
final int type;
final int order;

RuleEntry(Object rule, int type, Integer order) {
this.rule = rule;
this.type = type;
this.order = order != null ? order.intValue() : Rule.DEFAULT_ORDER;
}
}
}
18 changes: 18 additions & 0 deletions src/main/java/org/junit/runners/model/MemberValueConsumer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.junit.runners.model;

/**
* Represents a receiver for values of annotated fields/methods together with the declaring member.
*
* @see TestClass#collectAnnotatedFieldValues(Object, Class, Class, MemberValueConsumer)
* @see TestClass#collectAnnotatedMethodValues(Object, Class, Class, MemberValueConsumer)
* @since 4.13
*/
public interface MemberValueConsumer<T> {
/**
* Receives the next value and its declaring member.
*
* @param member declaring member ({@link FrameworkMethod or {@link FrameworkField}}
* @param value the value of the next member
*/
void accept(FrameworkMember member, T value);
}
Loading

0 comments on commit aad22b8

Please sign in to comment.