From 475ed6bf502c98d1cc6f52ebb19ef6468a9e67a3 Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Wed, 21 Jun 2023 10:04:42 -0400 Subject: [PATCH] Log unknown pointer key in the test code (#53) The TF2 test code is only checking tensor parameters right now. Log if we encounter a tensor that is not a parameter. --- .../python/ml/test/TestTensorflowModel.java | 59 +++++++++++-------- 1 file changed, 33 insertions(+), 26 deletions(-) diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index 8af15a597..738a76ca6 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -11,6 +11,7 @@ import java.util.Iterator; import java.util.Map; import java.util.Set; +import java.util.logging.Logger; import java.util.stream.Collectors; import org.junit.Test; @@ -30,6 +31,8 @@ public class TestTensorflowModel extends TestPythonMLCallGraphShape { + private static final Logger logger = Logger.getLogger(TestTensorflowModel.class.getName()); + @Test public void testTf1() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { PythonAnalysisEngine E = makeEngine("tf1.py"); @@ -100,32 +103,36 @@ private void testTf2(String filename, String functionName, int expectedNumberOfT // for each pointer key, tensor variable pair. analysis.forEach(p -> { PointerKey pointerKey = p.fst; - LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey; - - // get the call graph node associated with the - CGNode node = localPointerKey.getNode(); - - // get the method associated with the call graph node. - IMethod method = node.getMethod(); - String methodSignature = method.getSignature(); - - // associate the method to the pointer key. - methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> { - if (v == null) - v = new HashSet<>(); - v.add(localPointerKey); - return v; - }); - - TensorVariable tensorVariable = p.snd; - - // associate the method to the tensor variables. - methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> { - if (v == null) - v = new HashSet<>(); - v.add(tensorVariable); - return v; - }); + + if (pointerKey instanceof LocalPointerKey) { + LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey; + + // get the call graph node associated with the + CGNode node = localPointerKey.getNode(); + + // get the method associated with the call graph node. + IMethod method = node.getMethod(); + String methodSignature = method.getSignature(); + + // associate the method to the pointer key. + methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> { + if (v == null) + v = new HashSet<>(); + v.add(localPointerKey); + return v; + }); + + TensorVariable tensorVariable = p.snd; + + // associate the method to the tensor variables. + methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> { + if (v == null) + v = new HashSet<>(); + v.add(tensorVariable); + return v; + }); + } else + logger.warning(() -> "Encountered: " + pointerKey.getClass()); }); // check the maps.