Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@BeforeParam/@AfterParam for Parameterized runner #1435

Merged
merged 18 commits into from
Apr 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,19 @@ public void evaluate() throws Throwable {
} finally {
for (FrameworkMethod each : afters) {
try {
each.invokeExplosively(target);
invokeMethod(each);
} catch (Throwable e) {
errors.add(e);
}
}
}
MultipleFailureException.assertEmpty(errors);
}

/**
* @since 4.13
*/
protected void invokeMethod(FrameworkMethod method) throws Throwable {
method.invokeExplosively(target);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,15 @@ public RunBefores(Statement next, List<FrameworkMethod> befores, Object target)
@Override
public void evaluate() throws Throwable {
for (FrameworkMethod before : befores) {
before.invokeExplosively(target);
invokeMethod(before);
}
next.evaluate();
}

/**
* @since 4.13
*/
protected void invokeMethod(FrameworkMethod method) throws Throwable {
method.invokeExplosively(target);
}
}
132 changes: 110 additions & 22 deletions src/main/java/org/junit/runners/Parameterized.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.junit.runners;

import java.lang.annotation.Annotation;
import java.lang.annotation.ElementType;
import java.lang.annotation.Inherited;
import java.lang.annotation.Retention;
Expand All @@ -8,11 +9,13 @@
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;

import org.junit.runner.Runner;
import org.junit.runners.model.FrameworkMethod;
import org.junit.runners.model.InvalidTestClassError;
import org.junit.runners.model.TestClass;
import org.junit.runners.parameterized.BlockJUnit4ClassRunnerWithParametersFactory;
import org.junit.runners.parameterized.ParametersRunnerFactory;
Expand Down Expand Up @@ -134,6 +137,19 @@
* }
* </pre>
*
* <h3>Executing code before/after executing tests for specific parameters</h3>
* <p>
* If your test needs to perform some preparation or cleanup based on the
* parameters, this can be done by adding public static methods annotated with
* {@code @BeforeParam}/{@code @AfterParam}. Such methods should either have no
* parameters or the same parameters as the test.
* <pre>
* &#064;BeforeParam
* public static void beforeTestsForParameter(String onlyParameter) {
* System.out.println("Testing " + onlyParameter);
* }
* </pre>
*
* <h3>Create different runners</h3>
* <p>
* By default the {@code Parameterized} runner creates a slightly modified
Expand Down Expand Up @@ -234,32 +250,91 @@ public class Parameterized extends Suite {
Class<? extends ParametersRunnerFactory> value() default BlockJUnit4ClassRunnerWithParametersFactory.class;
}

/**
* Annotation for {@code public static void} methods which should be executed before
* evaluating tests with particular parameters.
*
* @see org.junit.BeforeClass
* @see org.junit.Before
* @since 4.13
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface BeforeParam {
}

/**
* Annotation for {@code public static void} methods which should be executed after
* evaluating tests with particular parameters.
*
* @see org.junit.AfterClass
* @see org.junit.After
* @since 4.13
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface AfterParam {
}

/**
* Only called reflectively. Do not use programmatically.
*/
public Parameterized(Class<?> klass) throws Throwable {
super(klass, RunnersFactory.createRunnersForClass(klass));
this(klass, new RunnersFactory(klass));
}

private Parameterized(Class<?> klass, RunnersFactory runnersFactory) throws Exception {
super(klass, runnersFactory.createRunners());
validateBeforeParamAndAfterParamMethods(runnersFactory.parameterCount);
}

private void validateBeforeParamAndAfterParamMethods(Integer parameterCount)
throws InvalidTestClassError {
List<Throwable> errors = new ArrayList<Throwable>();
validatePublicStaticVoidMethods(Parameterized.BeforeParam.class, parameterCount, errors);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we delete validateBeforeParamAndAfterParamMethods() and instead override collectInitializationErrors(List<Throwable>) and call validatePublicStaticVoidMethods there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that before, collectInitializationErrors() is called from the super class constructor and the number of method parameters can't be easily validated (only with some tricks, e.g. using ThreadLocal).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see. I've never been a fan of calling non-static methods in constructors. It's caused no end of problems in JUnit4. Hopefully they avoided that in the JUnit5 code base

validatePublicStaticVoidMethods(Parameterized.AfterParam.class, parameterCount, errors);
if (!errors.isEmpty()) {
throw new InvalidTestClassError(getTestClass().getJavaClass(), errors);
}
}

private void validatePublicStaticVoidMethods(
Class<? extends Annotation> annotation, Integer parameterCount,
List<Throwable> errors) {
List<FrameworkMethod> methods = getTestClass().getAnnotatedMethods(annotation);
for (FrameworkMethod fm : methods) {
fm.validatePublicVoid(true, errors);
if (parameterCount != null) {
int methodParameterCount = fm.getMethod().getParameterTypes().length;
if (methodParameterCount != 0 && methodParameterCount != parameterCount) {
errors.add(new Exception("Method " + fm.getName()
+ "() should have 0 or " + parameterCount + " parameter(s)"));
}
}
}
}

private static class RunnersFactory {
private static final ParametersRunnerFactory DEFAULT_FACTORY = new BlockJUnit4ClassRunnerWithParametersFactory();

private final TestClass testClass;
private final FrameworkMethod parametersMethod;
private final List<Object> allParameters;
private final int parameterCount;

static List<Runner> createRunnersForClass(Class<?> klass)
throws Throwable {
return new RunnersFactory(klass).createRunners();
}

private RunnersFactory(Class<?> klass) {
private RunnersFactory(Class<?> klass) throws Throwable {
testClass = new TestClass(klass);
parametersMethod = getParametersMethod(testClass);
allParameters = allParameters(testClass, parametersMethod);
parameterCount =
allParameters.isEmpty() ? 0 : normalizeParameters(allParameters.get(0)).length;
}

private List<Runner> createRunners() throws Throwable {
Parameters parameters = getParametersMethod().getAnnotation(
Parameters.class);
private List<Runner> createRunners() throws Exception {
Parameters parameters = parametersMethod.getAnnotation(Parameters.class);
return Collections.unmodifiableList(createRunnersForParameters(
allParameters(), parameters.name(),
allParameters, parameters.name(),
getParametersRunnerFactory()));
}

Expand All @@ -278,25 +353,37 @@ private ParametersRunnerFactory getParametersRunnerFactory()

private TestWithParameters createTestWithNotNormalizedParameters(
String pattern, int index, Object parametersOrSingleParameter) {
Object[] parameters = (parametersOrSingleParameter instanceof Object[]) ? (Object[]) parametersOrSingleParameter
Object[] parameters = normalizeParameters(parametersOrSingleParameter);
return createTestWithParameters(testClass, pattern, index, parameters);
}

private static Object[] normalizeParameters(Object parametersOrSingleParameter) {
return (parametersOrSingleParameter instanceof Object[]) ? (Object[]) parametersOrSingleParameter
: new Object[] { parametersOrSingleParameter };
return createTestWithParameters(testClass, pattern, index,
parameters);
}

@SuppressWarnings("unchecked")
private Iterable<Object> allParameters() throws Throwable {
Object parameters = getParametersMethod().invokeExplosively(null);
if (parameters instanceof Iterable) {
return (Iterable<Object>) parameters;
private static List<Object> allParameters(
TestClass testClass, FrameworkMethod parametersMethod) throws Throwable {
Object parameters = parametersMethod.invokeExplosively(null);
if (parameters instanceof List) {
return (List<Object>) parameters;
} else if (parameters instanceof Collection) {
return new ArrayList<Object>((Collection<Object>) parameters);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kcooney IMHO there is no need to handle Collection specially, it can be treated as Iterable.

} else if (parameters instanceof Iterable) {
List<Object> result = new ArrayList<Object>();
for (Object entry : ((Iterable<Object>) parameters)) {
result.add(entry);
}
return result;
} else if (parameters instanceof Object[]) {
return Arrays.asList((Object[]) parameters);
} else {
throw parametersMethodReturnedWrongType();
throw parametersMethodReturnedWrongType(testClass, parametersMethod);
}
}

private FrameworkMethod getParametersMethod() throws Exception {
private static FrameworkMethod getParametersMethod(TestClass testClass) throws Exception {
List<FrameworkMethod> methods = testClass
.getAnnotatedMethods(Parameters.class);
for (FrameworkMethod each : methods) {
Expand All @@ -322,7 +409,7 @@ private List<Runner> createRunnersForParameters(
}
return runners;
} catch (ClassCastException e) {
throw parametersMethodReturnedWrongType();
throw parametersMethodReturnedWrongType(testClass, parametersMethod);
}
}

Expand All @@ -338,9 +425,10 @@ private List<TestWithParameters> createTestsForParameters(
return children;
}

private Exception parametersMethodReturnedWrongType() throws Exception {
private static Exception parametersMethodReturnedWrongType(
TestClass testClass, FrameworkMethod parametersMethod) throws Exception {
String className = testClass.getName();
String methodName = getParametersMethod().getName();
String methodName = parametersMethod.getName();
String message = MessageFormat.format(
"{0}.{1}() must return an Iterable of arrays.", className,
methodName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
import java.lang.reflect.Field;
import java.util.List;

import org.junit.internal.runners.statements.RunAfters;
import org.junit.internal.runners.statements.RunBefores;
import org.junit.runner.RunWith;
import org.junit.runner.notification.RunNotifier;
import org.junit.runners.BlockJUnit4ClassRunner;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.model.FrameworkField;
import org.junit.runners.model.FrameworkMethod;
Expand Down Expand Up @@ -135,7 +138,46 @@ protected void validateFields(List<Throwable> errors) {

@Override
protected Statement classBlock(RunNotifier notifier) {
return childrenInvoker(notifier);
Statement statement = childrenInvoker(notifier);
statement = withBeforeParams(statement);
statement = withAfterParams(statement);
return statement;
}

private Statement withBeforeParams(Statement statement) {
List<FrameworkMethod> befores = getTestClass()
.getAnnotatedMethods(Parameterized.BeforeParam.class);
return befores.isEmpty() ? statement : new RunBeforeParams(statement, befores);
}

private class RunBeforeParams extends RunBefores {
RunBeforeParams(Statement next, List<FrameworkMethod> befores) {
super(next, befores, null);
}

@Override
protected void invokeMethod(FrameworkMethod method) throws Throwable {
int paramCount = method.getMethod().getParameterTypes().length;
method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
}
}

private Statement withAfterParams(Statement statement) {
List<FrameworkMethod> afters = getTestClass()
.getAnnotatedMethods(Parameterized.AfterParam.class);
return afters.isEmpty() ? statement : new RunAfterParams(statement, afters);
}

private class RunAfterParams extends RunAfters {
RunAfterParams(Statement next, List<FrameworkMethod> afters) {
super(next, afters, null);
}

@Override
protected void invokeMethod(FrameworkMethod method) throws Throwable {
int paramCount = method.getMethod().getParameterTypes().length;
method.invokeExplosively(null, paramCount == 0 ? (Object[]) null : parameters);
}
}

@Override
Expand Down
Loading