Skip to content

Commit

Permalink
Log unknown pointer key in the test code (#53)
Browse files Browse the repository at this point in the history
The TF2 test code is only checking tensor parameters right now. Log if we encounter a tensor that is not a parameter.
  • Loading branch information
khatchad authored Jun 21, 2023
1 parent ed9dd4f commit 475ed6b
Showing 1 changed file with 33 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;
import java.util.stream.Collectors;

import org.junit.Test;
Expand All @@ -30,6 +31,8 @@

public class TestTensorflowModel extends TestPythonMLCallGraphShape {

private static final Logger logger = Logger.getLogger(TestTensorflowModel.class.getName());

@Test
public void testTf1() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
PythonAnalysisEngine<TensorTypeAnalysis> E = makeEngine("tf1.py");
Expand Down Expand Up @@ -100,32 +103,36 @@ private void testTf2(String filename, String functionName, int expectedNumberOfT
// for each pointer key, tensor variable pair.
analysis.forEach(p -> {
PointerKey pointerKey = p.fst;
LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey;

// get the call graph node associated with the
CGNode node = localPointerKey.getNode();

// get the method associated with the call graph node.
IMethod method = node.getMethod();
String methodSignature = method.getSignature();

// associate the method to the pointer key.
methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> {
if (v == null)
v = new HashSet<>();
v.add(localPointerKey);
return v;
});

TensorVariable tensorVariable = p.snd;

// associate the method to the tensor variables.
methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> {
if (v == null)
v = new HashSet<>();
v.add(tensorVariable);
return v;
});

if (pointerKey instanceof LocalPointerKey) {
LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey;

// get the call graph node associated with the
CGNode node = localPointerKey.getNode();

// get the method associated with the call graph node.
IMethod method = node.getMethod();
String methodSignature = method.getSignature();

// associate the method to the pointer key.
methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> {
if (v == null)
v = new HashSet<>();
v.add(localPointerKey);
return v;
});

TensorVariable tensorVariable = p.snd;

// associate the method to the tensor variables.
methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> {
if (v == null)
v = new HashSet<>();
v.add(tensorVariable);
return v;
});
} else
logger.warning(() -> "Encountered: " + pointerKey.getClass());
});

// check the maps.
Expand Down

0 comments on commit 475ed6b

Please sign in to comment.