Skip to content

Commit

Permalink
Add pytest entrypoints (#151)
Browse files Browse the repository at this point in the history
Modeled after the turtle entrypoints. See http://pytest.org.
  • Loading branch information
khatchad authored Feb 23, 2024
1 parent 634eb70 commit f1427ae
Show file tree
Hide file tree
Showing 8 changed files with 427 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ gradle-app.setting

# Cache of project
.gradletasknamecache
/.pytest_cache/

# Eclipse Gradle plugin generated files
# Eclipse Core
Expand Down
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# From: https://docs.pytest.org/en/8.0.x/getting-started.html#group-multiple-tests-in-a-class.


# content of test_class.py
class TestClass:

def test_one(self):
x = "this"
assert "h" in x

def test_two(self):
x = "hello"
assert hasattr(x, "check")
16 changes: 16 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_class2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# From: https://docs.pytest.org/en/8.0.x/getting-started.html#group-multiple-tests-in-a-class.


# content of test_class.py
class TestClass:

def __init__(self):
pass

def test_one(self):
x = "this"
assert "h" in x

def test_two(self):
x = "hello"
assert hasattr(x, "check")
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.test/data/test_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# From https://docs.pytest.org/en/8.0.x/getting-started.html#create-your-first-test.


# content of test_sample.py
def func(x):
return x + 1


def test_answer():
assert func(3) == 5
Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
package com.ibm.wala.cast.python.test;

import static com.google.common.collect.Iterables.concat;
import static java.util.Collections.singleton;

import com.ibm.wala.cast.ipa.callgraph.CAstCallGraphUtil;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ipa.callgraph.PytestEntrypointBuilder;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
import com.ibm.wala.ipa.callgraph.propagation.SSAContextInterpreter;
import com.ibm.wala.ipa.cha.ClassHierarchyException;
import com.ibm.wala.util.CancelException;
import java.io.IOException;
import java.util.Collections;
import java.util.logging.Logger;
import org.junit.Test;

public class TestCalls extends TestPythonCallGraphShape {

private static final Logger LOGGER = Logger.getLogger(TestCalls.class.getName());

protected static final Object[][] assertionsCalls1 =
new Object[][] {
new Object[] {ROOT, new String[] {"script calls1.py"}},
Expand Down Expand Up @@ -232,4 +240,112 @@ public Void performAnalysis(PropagationCallGraphBuilder builder) throws CancelEx
CG);
verifyGraphAssertions(CG, assertionsDefaultValues);
}

protected static final Object[][] PYTEST_ASSERTIONS =
new Object[][] {
new Object[] {
ROOT, new String[] {"script test_sample.py", "script test_sample.py/test_answer"}
},
new Object[] {
"script test_sample.py/test_answer", new String[] {"script test_sample.py/func"}
},
};

@Test
public void testPytestCalls()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {

PythonAnalysisEngine<?> engine =
new PythonAnalysisEngine<Void>() {
@Override
public Void performAnalysis(PropagationCallGraphBuilder builder) throws CancelException {
assert false;
return null;
}
};

engine.setModuleFiles(singleton(getScript("test_sample.py")));

PropagationCallGraphBuilder callGraphBuilder =
(PropagationCallGraphBuilder) engine.defaultCallGraphBuilder();

addPytestEntrypoints(callGraphBuilder);

CallGraph callGraph = callGraphBuilder.makeCallGraph(callGraphBuilder.getOptions());

CAstCallGraphUtil.AVOID_DUMP = false;
CAstCallGraphUtil.dumpCG(
(SSAContextInterpreter) callGraphBuilder.getContextInterpreter(),
callGraphBuilder.getPointerAnalysis(),
callGraph);

verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS);
}

private static void addPytestEntrypoints(PropagationCallGraphBuilder callGraphBuilder) {
Iterable<? extends Entrypoint> defaultEntrypoints =
callGraphBuilder.getOptions().getEntrypoints();

Iterable<Entrypoint> pytestEntrypoints =
new PytestEntrypointBuilder().createEntrypoints(callGraphBuilder.getClassHierarchy());

Iterable<Entrypoint> entrypoints = concat(defaultEntrypoints, pytestEntrypoints);

callGraphBuilder.getOptions().setEntrypoints(entrypoints);

for (Entrypoint ep : callGraphBuilder.getOptions().getEntrypoints())
LOGGER.info(() -> "Using entrypoint: " + ep.getMethod().getDeclaringClass().getName() + ".");
}

protected static final Object[][] PYTEST_ASSERTIONS2 =
new Object[][] {
new Object[] {
ROOT,
new String[] {
"script test_class.py",
"script test_class.py/TestClass",
"$script test_class.py/TestClass/test_one:trampoline2",
"$script test_class.py/TestClass/test_two:trampoline2"
}
},
};

@Test
public void testPytestCalls2()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
PythonAnalysisEngine<?> engine = this.makeEngine("test_class.py");
PropagationCallGraphBuilder callGraphBuilder = engine.defaultCallGraphBuilder();

addPytestEntrypoints(callGraphBuilder);

CallGraph callGraph = callGraphBuilder.makeCallGraph(callGraphBuilder.getOptions());

CAstCallGraphUtil.AVOID_DUMP = false;
CAstCallGraphUtil.dumpCG(
(SSAContextInterpreter) callGraphBuilder.getContextInterpreter(),
callGraphBuilder.getPointerAnalysis(),
callGraph);

verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS2);
}

protected static final Object[][] PYTEST_ASSERTIONS3 =
new Object[][] {
new Object[] {ROOT, new String[] {"script test_class2.py"}},
};

@Test
public void testPytestCalls3()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
PythonAnalysisEngine<?> engine = this.makeEngine("test_class2.py");
PropagationCallGraphBuilder callGraphBuilder = engine.defaultCallGraphBuilder();
addPytestEntrypoints(callGraphBuilder);
CallGraph callGraph = callGraphBuilder.makeCallGraph(callGraphBuilder.getOptions());
CAstCallGraphUtil.AVOID_DUMP = false;
CAstCallGraphUtil.dumpCG(
(SSAContextInterpreter) callGraphBuilder.getContextInterpreter(),
callGraphBuilder.getPointerAnalysis(),
callGraph);
verifyGraphAssertions(callGraph, PYTEST_ASSERTIONS3);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package com.ibm.wala.cast.python.ipa.callgraph;

import static com.ibm.wala.cast.python.types.Util.getFilename;
import static java.util.Objects.requireNonNull;

import com.ibm.wala.cast.python.loader.PythonLoader.DynamicMethodBody;
import com.ibm.wala.cast.python.loader.PythonLoader.PythonClass;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.client.AbstractAnalysisEngine.EntrypointBuilder;
import com.ibm.wala.core.util.strings.Atom;
import com.ibm.wala.ipa.callgraph.Entrypoint;
import com.ibm.wala.ipa.cha.IClassHierarchy;
import com.ibm.wala.types.MethodReference;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.util.collections.HashSetFactory;
import java.util.HashSet;
import java.util.logging.Logger;

/**
* This class represents entry points ({@link Entrypoint})s of Pytest test functions. Pytest test
* functions are those invoked by the pytest framework reflectively. The entry points can be used to
* specify entry points of a call graph.
*/
public class PytestEntrypointBuilder implements EntrypointBuilder {

private static final Logger logger = Logger.getLogger(PytestEntrypointBuilder.class.getName());

/**
* Construct pytest entrypoints for all the pytest test functions in the given scope.
*
* @throws NullPointerException If the given {@link IClassHierarchy} is null.
*/
@Override
public Iterable<Entrypoint> createEntrypoints(IClassHierarchy cha) {
requireNonNull(cha);

final HashSet<Entrypoint> result = HashSetFactory.make();

for (IClass klass : cha) {
// if the class is a pytest test case,
if (isPytestCase(klass)) {
logger.fine(() -> "Pytest case: " + klass + ".");

MethodReference methodReference =
MethodReference.findOrCreate(klass.getReference(), AstMethodReference.fnSelector);

result.add(new PytesttEntrypoint(methodReference, cha));

logger.fine(() -> "Adding test method as entry point: " + methodReference.getName() + ".");
}
}

return result::iterator;
}

/**
* Check if the given class is a Pytest test class according to: https://bit.ly/3wj8nPY.
*
* @throws NullPointerException If the given {@link IClass} is null.
* @see https://bit.ly/3wj8nPY.
*/
public static boolean isPytestCase(IClass klass) {
requireNonNull(klass);

final TypeName typeName = klass.getReference().getName();

if (typeName.toString().startsWith("Lscript ")) {
final String fileName = getFilename(typeName);
final Atom className = typeName.getClassName();

// In Ariadne, a script is an invokable entity like a function.
final boolean script = className.toString().endsWith(".py");

if (!script // it's not an invokable script.
&& (fileName.startsWith("test_")
|| fileName.endsWith("_test")) // we're inside of a "test" file,
&& !(klass instanceof PythonClass)) { // classes aren't entrypoints.
if (klass instanceof DynamicMethodBody) {
// It's a method. In Ariadne, functions are also classes.
DynamicMethodBody dmb = (DynamicMethodBody) klass;
IClass container = dmb.getContainer();
String containerName = container.getReference().getName().getClassName().toString();

if (containerName.startsWith("Test") && container instanceof PythonClass) {
// It's a test class.
PythonClass containerClass = (PythonClass) container;

final boolean hasCtor =
containerClass.getMethodReferences().stream()
.anyMatch(
mr -> {
return mr.getName().toString().equals("__init__");
});

// Test classes can't have constructors.
if (!hasCtor) {
// In Ariadne, methods are modeled as classes. Thus, a class name in this case is the
// method name.
String methodName = className.toString();

// If the method starts with "test."
if (methodName.startsWith("test")) return true;
}
}
} else if (className.toString().startsWith("test")) return true; // It's a function.
}
}

return false;
}
}
Loading

0 comments on commit f1427ae

Please sign in to comment.