|
11 | 11 | import java.util.Iterator;
|
12 | 12 | import java.util.Map;
|
13 | 13 | import java.util.Set;
|
| 14 | +import java.util.logging.Logger; |
14 | 15 | import java.util.stream.Collectors;
|
15 | 16 |
|
16 | 17 | import org.junit.Test;
|
|
30 | 31 |
|
31 | 32 | public class TestTensorflowModel extends TestPythonMLCallGraphShape {
|
32 | 33 |
|
| 34 | + private static final Logger logger = Logger.getLogger(TestTensorflowModel.class.getName()); |
| 35 | + |
33 | 36 | @Test
|
34 | 37 | public void testTf1() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
|
35 | 38 | PythonAnalysisEngine<TensorTypeAnalysis> E = makeEngine("tf1.py");
|
@@ -100,32 +103,36 @@ private void testTf2(String filename, String functionName, int expectedNumberOfT
|
100 | 103 | // for each pointer key, tensor variable pair.
|
101 | 104 | analysis.forEach(p -> {
|
102 | 105 | PointerKey pointerKey = p.fst;
|
103 |
| - LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey; |
104 |
| - |
105 |
| - // get the call graph node associated with the |
106 |
| - CGNode node = localPointerKey.getNode(); |
107 |
| - |
108 |
| - // get the method associated with the call graph node. |
109 |
| - IMethod method = node.getMethod(); |
110 |
| - String methodSignature = method.getSignature(); |
111 |
| - |
112 |
| - // associate the method to the pointer key. |
113 |
| - methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> { |
114 |
| - if (v == null) |
115 |
| - v = new HashSet<>(); |
116 |
| - v.add(localPointerKey); |
117 |
| - return v; |
118 |
| - }); |
119 |
| - |
120 |
| - TensorVariable tensorVariable = p.snd; |
121 |
| - |
122 |
| - // associate the method to the tensor variables. |
123 |
| - methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> { |
124 |
| - if (v == null) |
125 |
| - v = new HashSet<>(); |
126 |
| - v.add(tensorVariable); |
127 |
| - return v; |
128 |
| - }); |
| 106 | + |
| 107 | + if (pointerKey instanceof LocalPointerKey) { |
| 108 | + LocalPointerKey localPointerKey = (LocalPointerKey) pointerKey; |
| 109 | + |
| 110 | + // get the call graph node associated with the |
| 111 | + CGNode node = localPointerKey.getNode(); |
| 112 | + |
| 113 | + // get the method associated with the call graph node. |
| 114 | + IMethod method = node.getMethod(); |
| 115 | + String methodSignature = method.getSignature(); |
| 116 | + |
| 117 | + // associate the method to the pointer key. |
| 118 | + methodSignatureToPointerKeys.compute(methodSignature, (k, v) -> { |
| 119 | + if (v == null) |
| 120 | + v = new HashSet<>(); |
| 121 | + v.add(localPointerKey); |
| 122 | + return v; |
| 123 | + }); |
| 124 | + |
| 125 | + TensorVariable tensorVariable = p.snd; |
| 126 | + |
| 127 | + // associate the method to the tensor variables. |
| 128 | + methodSignatureToTensorVariables.compute(methodSignature, (k, v) -> { |
| 129 | + if (v == null) |
| 130 | + v = new HashSet<>(); |
| 131 | + v.add(tensorVariable); |
| 132 | + return v; |
| 133 | + }); |
| 134 | + } else |
| 135 | + logger.warning(() -> "Encountered: " + pointerKey.getClass()); |
129 | 136 | });
|
130 | 137 |
|
131 | 138 | // check the maps.
|
|
0 commit comments