Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using XML summaries to add some TF2 APIs #64

Merged
merged 4 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,12 @@ public void testTf1()
PythonSSAPropagationCallGraphBuilder builder = E.defaultCallGraphBuilder();
CallGraph CG = builder.makeCallGraph(builder.getOptions());

// CAstCallGraphUtil.AVOID_DUMP = false;
// CAstCallGraphUtil.AVOID_DUMP = false;
//
// CAstCallGraphUtil.dumpCG(((SSAPropagationCallGraphBuilder)builder).getCFAContextInterpreter(), builder.getPointerAnalysis(), CG);
// CAstCallGraphUtil.dumpCG(((SSAPropagationCallGraphBuilder)builder).getCFAContextInterpreter(),
// builder.getPointerAnalysis(), CG);

// System.err.println(CG);
// System.err.println(CG);

Collection<CGNode> nodes = getNodes(CG, "script tf1.py/model_fn");
assert !nodes.isEmpty() : "model_fn should be called";
Expand All @@ -62,42 +63,67 @@ public void testTf1()
@Test
public void testTf2()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
testTf2("tf2.py", "add", 2, 2, 3);
testTf2("tf2b.py", "add", 2, 2, 3);
testTf2("tf2c.py", "add", 2, 2, 3);
testTf2("tf2d.py", "add", 2, 2, 3);
testTf2("tf2e.py", "add", 2, 2, 3);
testTf2("tf2f.py", "add", 2, 2, 3);
testTf2("tf2g.py", "add", 2, 2, 3);
testTf2("tf2h.py", "add", 2, 2, 3);
testTf2("tf2i.py", "add", 2, 2, 3);
testTf2("tf2j.py", "add", 2, 2, 3);
testTf2("tf2k.py", "add", 2, 2, 3);
testTf2("tf2l.py", "add", 2, 2, 3);
testTf2("tf2m.py", "add", 2, 2, 3);
// TODO: Uncomment below test when https://github.com/wala/ML/issues/49 is fixed.
// testTf2("tf2n.py", "func2", 1, 2);
testTf2("tf2o.py", "add", 2, 2, 3);
testTf2("tf2p.py", "value_index", 2, 2, 3);
testTf2("tf2.py", "add", 2, 3, 2, 3);
testTf2("tf2b.py", "add", 2, 3, 2, 3);
testTf2("tf2c.py", "add", 2, 4, 2, 3);
testTf2("tf2d.py", "add", 2, 3, 2, 3);
testTf2("tf2d2.py", "add", 2, 3, 2, 3);
testTf2("tf2d3.py", "add", 2, 3, 2, 3);
testTf2("tf2d4.py", "add", 2, 4, 2, 3);
testTf2("tf2e.py", "add", 2, 3, 2, 3);
testTf2("tf2e2.py", "add", 2, 4, 2, 3);
testTf2("tf2e3.py", "add", 2, 3, 2, 3);
testTf2("tf2e4.py", "add", 2, 4, 2, 3);
testTf2("tf2e5.py", "add", 2, 3, 2, 3);
testTf2("tf2e6.py", "add", 2, 3, 2, 3);
testTf2("tf2e7.py", "add", 2, 3, 2, 3);
testTf2("tf2f.py", "add", 2, 3, 2, 3);
testTf2("tf2f2.py", "add", 2, 4, 2, 3);
testTf2("tf2f3.py", "add", 2, 4, 2, 3);
testTf2("tf2g.py", "add", 2, 3, 2, 3);
testTf2("tf2g2.py", "add", 2, 4, 2, 3);
testTf2("tf2h.py", "add", 2, 3, 2, 3);
testTf2("tf2h2.py", "add", 2, 4, 2, 3);
testTf2("tf2i.py", "add", 2, 3, 2, 3);
testTf2("tf2i2.py", "add", 2, 4, 2, 3);
testTf2("tf2j.py", "add", 2, 3, 2, 3);
testTf2("tf2j2.py", "add", 2, 4, 2, 3);
testTf2("tf2k.py", "add", 2, 3, 2, 3);
testTf2("tf2k2.py", "add", 2, 4, 2, 3);
testTf2("tf2l.py", "add", 2, 3, 2, 3);
testTf2("tf2l2.py", "add", 2, 4, 2, 3);
testTf2("tf2m.py", "add", 2, 3, 2, 3);
testTf2("tf2m2.py", "add", 2, 4, 2, 3);
testTf2("tf2n.py", "func2", 1, 4, 2);
testTf2("tf2n2.py", "func2", 1, 4, 2);
testTf2("tf2n3.py", "func2", 1, 4, 2);
testTf2("tf2o.py", "add", 2, 3, 2, 3);
testTf2("tf2o2.py", "add", 2, 4, 2, 3);
testTf2("tf2p.py", "value_index", 2, 4, 2, 3);
testTf2("tf2p2.py", "value_index", 2, 4, 2, 3);
testTf2("tf2q.py", "add", 2, 3, 2, 3);
testTf2("tf2r.py", "add", 2, 3, 2, 3);
// TODO: Uncomment below test when https://github.com/wala/ML/issues/65 is fixed.
// testTf2("tf2s.py", "add", 2, 3, 2, 3);
}

private void testTf2(
String filename,
String functionName,
int expectedNumberOfTensorParameters,
int... expectedValueNumbers)
int expectedNumberOfTensorVariables,
int... expectedTensorParameterValueNumbers)
throws ClassHierarchyException, CancelException, IOException {
PythonAnalysisEngine<TensorTypeAnalysis> E = makeEngine(filename);
PythonSSAPropagationCallGraphBuilder builder = E.defaultCallGraphBuilder();

CallGraph CG = builder.makeCallGraph(builder.getOptions());
assertNotNull(CG);

// CAstCallGraphUtil.AVOID_DUMP = false;
// CAstCallGraphUtil.dumpCG(builder.getCFAContextInterpreter(), builder.getPointerAnalysis(),
// CG);

// System.err.println(CG);
// CAstCallGraphUtil.AVOID_DUMP = false;
// CAstCallGraphUtil.dumpCG(((SSAPropagationCallGraphBuilder)builder).getCFAContextInterpreter(),
// builder.getPointerAnalysis(), CG);
// System.err.println(CG);

TensorTypeAnalysis analysis = E.performAnalysis(builder);

Expand Down Expand Up @@ -145,8 +171,8 @@ private void testTf2(
});

// check the maps.
assertEquals(expectedNumberOfTensorParameters, methodSignatureToPointerKeys.size());
assertEquals(expectedNumberOfTensorParameters, methodSignatureToTensorVariables.size());
assertEquals(expectedNumberOfTensorVariables, methodSignatureToPointerKeys.size());
assertEquals(expectedNumberOfTensorVariables, methodSignatureToTensorVariables.size());

final String functionSignature = "script " + filename + "." + functionName + ".do()LRoot;";

Expand All @@ -162,8 +188,9 @@ private void testTf2(
.map(LocalPointerKey::getValueNumber)
.collect(Collectors.toSet());

assertEquals(expectedValueNumbers.length, actualValueNumberSet.size());
Arrays.stream(expectedValueNumbers).forEach(ev -> actualValueNumberSet.contains(ev));
assertEquals(expectedTensorParameterValueNumbers.length, actualValueNumberSet.size());
Arrays.stream(expectedTensorParameterValueNumbers)
.forEach(ev -> actualValueNumberSet.contains(ev));

// get the tensor variables for the function.
Set<TensorVariable> functionTensors = methodSignatureToTensorVariables.get(functionSignature);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ public void testParsing()
CallGraph CG = builder.makeCallGraph(builder.getOptions());
assertNotNull(CG);

// CAstCallGraphUtil.AVOID_DUMP = false;
// CAstCallGraphUtil.dumpCG(builder.getCFAContextInterpreter(), builder.getPointerAnalysis(),
// CAstCallGraphUtil.AVOID_DUMP = false;
// CAstCallGraphUtil.dumpCG(builder.getCFAContextInterpreter(), builder.getPointerAnalysis(),
// CG);

// System.err.println(CG);
// System.err.println(CG);

TensorTypeAnalysis analysis = E.performAnalysis(builder);

Expand Down Expand Up @@ -79,8 +79,8 @@ public void testParsing()
logger.warning(() -> "Encountered pointer key type: " + pointerKey.getClass() + ".");
});

// we should have two methods.
assertEquals(2, methodSignatureToPointerKeys.size());
// we should have 3 methods.
assertEquals(3, methodSignatureToPointerKeys.size());

final String addFunctionSignature = "script " + filename + ".add.do()LRoot;";

Expand Down
Loading