Skip to content

Commit 475ed6b

Browse files
authored
Log unknown pointer key in the test code (#53)
The TF2 test code is only checking tensor parameters right now. Log if we encounter a tensor that is not a parameter.
1 parent ed9dd4f commit 475ed6b

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import java.util.Iterator;
1212
import java.util.Map;
1313
import java.util.Set;
14+
import java.util.logging.Logger;
1415
import java.util.stream.Collectors;
1516

1617
import org.junit.Test;
@@ -30,6 +31,8 @@
3031

3132
public class TestTensorflowModel extends TestPythonMLCallGraphShape {
3233

34+
private static final Logger logger = Logger.getLogger(TestTensorflowModel.class.getName());
35+
3336
@Test
3437
public void testTf1() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
3538
PythonAnalysisEngine<TensorTypeAnalysis> E = makeEngine("tf1.py");
@@ -100,32 +103,36 @@ private void testTf2(String filename, String functionName, int expectedNumberOfT
100103
// for each pointer key, tensor variable pair.
101104
analysis.forEach(p -> {
102105
PointerKey pointerKey = p.fst;
103-
LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey;
104-
105-
// get the call graph node associated with the
106-
CGNode node = localPointerKey.getNode();
107-
108-
// get the method associated with the call graph node.
109-
IMethod method = node.getMethod();
110-
String methodSignature = method.getSignature();
111-
112-
// associate the method to the pointer key.
113-
methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> {
114-
if (v == null)
115-
v = new HashSet<>();
116-
v.add(localPointerKey);
117-
return v;
118-
});
119-
120-
TensorVariable tensorVariable = p.snd;
121-
122-
// associate the method to the tensor variables.
123-
methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> {
124-
if (v == null)
125-
v = new HashSet<>();
126-
v.add(tensorVariable);
127-
return v;
128-
});
106+
107+
if (pointerKey instanceof LocalPointerKey) {
108+
LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey;
109+
110+
// get the call graph node associated with the
111+
CGNode node = localPointerKey.getNode();
112+
113+
// get the method associated with the call graph node.
114+
IMethod method = node.getMethod();
115+
String methodSignature = method.getSignature();
116+
117+
// associate the method to the pointer key.
118+
methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> {
119+
if (v == null)
120+
v = new HashSet<>();
121+
v.add(localPointerKey);
122+
return v;
123+
});
124+
125+
TensorVariable tensorVariable = p.snd;
126+
127+
// associate the method to the tensor variables.
128+
methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> {
129+
if (v == null)
130+
v = new HashSet<>();
131+
v.add(tensorVariable);
132+
return v;
133+
});
134+
} else
135+
logger.warning(() -> "Encountered: " + pointerKey.getClass());
129136
});
130137

131138
// check the maps.

0 commit comments

Comments
 (0)