diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java
index 7669febb1..3b8dd06c5 100644
--- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java
+++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java
@@ -293,11 +293,9 @@ public void testR()
   @Test
   public void testS()
       throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
-    test(
-        "tf2s.py", "add", 0,
-        0); // NOTE: Set the expected number of tensor parameters, variables, and tensor parameter
-    // value numbers to 2, 3, and 2 and 3, respectively, when
-    // https://github.com/wala/ML/issues/65 is fixed.
+    // NOTE: Set the expected number of tensor variables to 3 once
+    // https://github.com/wala/ML/issues/135 is fixed.
+    test("tf2s.py", "add", 2, 2, 2, 3);
   }
 
   @Test
@@ -1422,6 +1420,75 @@ public void testTFRange3()
     test("test_tf_range.py", "f", 1, 1, 2);
   }
 
+  @Test
+  public void testImport()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import.py", "f", 1, 1, 2);
+  }
+
+  @Test
+  public void testImport2()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import2.py", "f", 1, 1, 2);
+  }
+
+  @Test
+  public void testImport3()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import3.py", "f", 1, 2, 2);
+    test("tf2_test_import3.py", "g", 1, 1, 2);
+  }
+
+  @Test
+  public void testImport4()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import4.py", "f", 1, 2, 2);
+    test("tf2_test_import4.py", "g", 1, 1, 2);
+  }
+
+  @Test
+  public void testImport5()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import5.py", "f", 0, 1);
+    test("tf2_test_import5.py", "g", 1, 1, 2);
+  }
+
+  @Test
+  public void testImport6()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import6.py", "f", 0, 1);
+    test("tf2_test_import6.py", "g", 1, 1, 2);
+  }
+
+  /**
+   * This is an invalid case. If there are no wildcard imports, we should not resolve names as if
+   * there were.
+   */
+  @Test
+  public void testImport7()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import7.py", "f", 0, 0);
+    test("tf2_test_import7.py", "g", 0, 0);
+  }
+
+  /**
+   * This is an invalid case. If there are no wildcard imports, we should not resolve names as if
+   * there were.
+   */
+  @Test
+  public void testImport8()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import8.py", "f", 0, 0);
+    test("tf2_test_import8.py", "g", 0, 0);
+  }
+
+  @Test
+  public void testImport9()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    test("tf2_test_import9.py", "f", 1, 1, 2);
+    test("tf2_test_import9.py", "g", 1, 1, 2);
+  }
+
   private void test(
       String filename,
       String functionName,
diff --git a/com.ibm.wala.cast.python.test/data/multi3.py b/com.ibm.wala.cast.python.test/data/multi3.py
new file mode 100644
index 000000000..a0994cb3d
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/multi3.py
@@ -0,0 +1,4 @@
+from multi2 import silly, x
+
+print(silly(1))
+print(x)
diff --git a/com.ibm.wala.cast.python.test/data/multi4.py b/com.ibm.wala.cast.python.test/data/multi4.py
new file mode 100644
index 000000000..eee2499dc
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/multi4.py
@@ -0,0 +1,4 @@
+from multi2 import x, silly
+
+print(silly(1))
+print(x)
diff --git a/com.ibm.wala.cast.python.test/data/multi5.py b/com.ibm.wala.cast.python.test/data/multi5.py
new file mode 100644
index 000000000..51dabbabf
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/multi5.py
@@ -0,0 +1,4 @@
+from multi2 import *
+
+print(silly(1))
+print(x)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import.py b/com.ibm.wala.cast.python.test/data/tf2_test_import.py
new file mode 100644
index 000000000..223a3846f
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import.py
@@ -0,0 +1,10 @@
+# Test https://github.com/wala/ML/issues/65.
+
+from tensorflow import ones, Tensor
+
+
+def f(a):
+    assert isinstance(a, Tensor)
+
+
+f(ones([1, 2]))
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import2.py b/com.ibm.wala.cast.python.test/data/tf2_test_import2.py
new file mode 100644
index 000000000..4cc6c367a
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import2.py
@@ -0,0 +1,10 @@
+# Test https://github.com/wala/ML/issues/65.
+
+from tensorflow import *
+
+
+def f(a):
+    assert isinstance(a, Tensor)
+
+
+f(ones([1, 2]))
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import3.py b/com.ibm.wala.cast.python.test/data/tf2_test_import3.py
new file mode 100644
index 000000000..24420e243
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import3.py
@@ -0,0 +1,15 @@
+# Test https://github.com/wala/ML/issues/65.
+
+from tensorflow import ones, Tensor
+
+
+def g(a):
+    assert isinstance(a, Tensor)
+
+
+def f(a):
+    assert isinstance(a, Tensor)
+    g(ones([1, 2]))
+
+
+f(ones([1, 2]))
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import4.py b/com.ibm.wala.cast.python.test/data/tf2_test_import4.py
new file mode 100644
index 000000000..be0fa979a
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import4.py
@@ -0,0 +1,15 @@
+# Test https://github.com/wala/ML/issues/65.
+
+from tensorflow import *
+
+
+def g(a):
+    assert isinstance(a, Tensor)
+
+
+def f(a):
+    assert isinstance(a, Tensor)
+    g(ones([1, 2]))
+
+
+f(ones([1, 2]))
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import5.py b/com.ibm.wala.cast.python.test/data/tf2_test_import5.py
new file mode 100644
index 000000000..556e106ba
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import5.py
@@ -0,0 +1,14 @@
+# Test https://github.com/wala/ML/issues/65.
+
+from tensorflow import ones, Tensor
+
+
+def g(a):
+    assert isinstance(a, Tensor)
+
+
+def f(a):
+    g(ones([1, 2]))
+
+
+f(5)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import6.py b/com.ibm.wala.cast.python.test/data/tf2_test_import6.py
new file mode 100644
index 000000000..68028d1a1
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import6.py
@@ -0,0 +1,14 @@
+# Test https://github.com/wala/ML/issues/65.
+
+from tensorflow import *
+
+
+def g(a):
+    assert isinstance(a, Tensor)
+
+
+def f(a):
+    g(ones([1, 2]))
+
+
+f(5)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import7.py b/com.ibm.wala.cast.python.test/data/tf2_test_import7.py
new file mode 100644
index 000000000..bdbe7b8ab
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import7.py
@@ -0,0 +1,13 @@
+# Test https://github.com/wala/ML/issues/65.
+# This is an invalid case. No wildcard import; we shouldn't pretend that there is one.
+
+
+def g(a):
+    assert isinstance(a, Tensor)
+
+
+def f(a):
+    g(ones([1, 2]))
+
+
+f(5)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import8.py b/com.ibm.wala.cast.python.test/data/tf2_test_import8.py
new file mode 100644
index 000000000..c86c5587e
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import8.py
@@ -0,0 +1,15 @@
+# Test https://github.com/wala/ML/issues/65.
+# This is an invalid case. No wildcard import; we shouldn't pretend that there is one.
+
+from tensorflow import Tensor
+
+
+def g(a):
+    assert isinstance(a, Tensor)
+
+
+def f(a):
+    g(ones([1, 2]))
+
+
+f(5)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_import9.py b/com.ibm.wala.cast.python.test/data/tf2_test_import9.py
new file mode 100644
index 000000000..3df7b350b
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_import9.py
@@ -0,0 +1,15 @@
+# Test https://github.com/wala/ML/issues/65.
+
+from tensorflow import *
+
+
+def g(a):
+    assert isinstance(a, Tensor)
+
+
+def f(a):
+    assert isinstance(a, Tensor)
+    g(a)
+
+
+f(ones([1, 2]))
diff --git a/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestMulti.java b/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestMulti.java
index 7792b7e35..2c4a33613 100644
--- a/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestMulti.java
+++ b/com.ibm.wala.cast.python.test/source/com/ibm/wala/cast/python/test/TestMulti.java
@@ -42,4 +42,67 @@ public void testMulti1()
     CAstCallGraphUtil.dumpCG(
         (SSAContextInterpreter) builder.getContextInterpreter(), builder.getPointerAnalysis(), CG);
   }
+
+  protected static final Object[][] assertionsMulti3 =
+      new Object[][] {
+        new Object[] {ROOT, new String[] {"script multi3.py", "script multi2.py"}},
+        new Object[] {"script multi3.py", new String[] {"script multi2.py/silly"}},
+        new Object[] {"script multi2.py/silly", new String[] {"script multi2.py/silly/inner"}},
+      };
+
+  /** Tests multi3.py, which uses {@code from multi2 import silly, x}. */
+  @Test
+  public void testMulti3()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    PythonAnalysisEngine<?> engine = makeEngine("multi2.py", "multi3.py");
+    PropagationCallGraphBuilder builder =
+        (PropagationCallGraphBuilder) engine.defaultCallGraphBuilder();
+    CallGraph CG = builder.makeCallGraph(engine.getOptions(), new NullProgressMonitor());
+    CAstCallGraphUtil.AVOID_DUMP = false;
+    CAstCallGraphUtil.dumpCG(
+        (SSAContextInterpreter) builder.getContextInterpreter(), builder.getPointerAnalysis(), CG);
+    verifyGraphAssertions(CG, assertionsMulti3);
+  }
+
+  protected static final Object[][] assertionsMulti4 =
+      new Object[][] {
+        new Object[] {ROOT, new String[] {"script multi4.py", "script multi2.py"}},
+        new Object[] {"script multi4.py", new String[] {"script multi2.py/silly"}},
+        new Object[] {"script multi2.py/silly", new String[] {"script multi2.py/silly/inner"}},
+      };
+
+  /** Tests multi4.py, which imports the names of multi3.py in the opposite order. */
+  @Test
+  public void testMulti4()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    PythonAnalysisEngine<?> engine = makeEngine("multi2.py", "multi4.py");
+    PropagationCallGraphBuilder builder =
+        (PropagationCallGraphBuilder) engine.defaultCallGraphBuilder();
+    CallGraph CG = builder.makeCallGraph(engine.getOptions(), new NullProgressMonitor());
+    CAstCallGraphUtil.AVOID_DUMP = false;
+    CAstCallGraphUtil.dumpCG(
+        (SSAContextInterpreter) builder.getContextInterpreter(), builder.getPointerAnalysis(), CG);
+    verifyGraphAssertions(CG, assertionsMulti4);
+  }
+
+  protected static final Object[][] assertionsMulti5 =
+      new Object[][] {
+        new Object[] {ROOT, new String[] {"script multi5.py", "script multi2.py"}},
+        new Object[] {"script multi5.py", new String[] {"script multi2.py/silly"}},
+        new Object[] {"script multi2.py/silly", new String[] {"script multi2.py/silly/inner"}},
+      };
+
+  /** Tests multi5.py, which uses the wildcard import {@code from multi2 import *}. */
+  @Test
+  public void testMulti5()
+      throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
+    PythonAnalysisEngine<?> engine = makeEngine("multi2.py", "multi5.py");
+    PropagationCallGraphBuilder builder =
+        (PropagationCallGraphBuilder) engine.defaultCallGraphBuilder();
+    CallGraph CG = builder.makeCallGraph(engine.getOptions(), new NullProgressMonitor());
+    CAstCallGraphUtil.AVOID_DUMP = false;
+    CAstCallGraphUtil.dumpCG(
+        (SSAContextInterpreter) builder.getContextInterpreter(), builder.getPointerAnalysis(), CG);
+    verifyGraphAssertions(CG, assertionsMulti5);
+  }
 }
diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java
index 59eaa0743..403c3017b 100644
--- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java
+++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java
@@ -10,18 +10,23 @@
  *****************************************************************************/
 package com.ibm.wala.cast.python.ipa.callgraph;
 
+import com.google.common.collect.Maps;
 import com.ibm.wala.cast.ipa.callgraph.AstSSAPropagationCallGraphBuilder;
 import com.ibm.wala.cast.ipa.callgraph.GlobalObjectKey;
+import com.ibm.wala.cast.ir.ssa.AstGlobalRead;
+import com.ibm.wala.cast.ir.ssa.AstPropertyRead;
 import com.ibm.wala.cast.python.ir.PythonLanguage;
 import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor;
 import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
 import com.ibm.wala.cast.python.types.PythonTypes;
 import com.ibm.wala.classLoader.IClass;
 import com.ibm.wala.classLoader.IField;
+import com.ibm.wala.classLoader.NewSiteReference;
 import com.ibm.wala.core.util.strings.Atom;
 import com.ibm.wala.fixpoint.AbstractOperator;
 import com.ibm.wala.ipa.callgraph.AnalysisOptions;
 import com.ibm.wala.ipa.callgraph.CGNode;
+import com.ibm.wala.ipa.callgraph.CallGraph;
 import com.ibm.wala.ipa.callgraph.IAnalysisCacheView;
 import com.ibm.wala.ipa.callgraph.propagation.AbstractFieldPointerKey;
 import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
@@ -36,8 +41,13 @@
 import com.ibm.wala.ssa.SSAArrayStoreInstruction;
 import com.ibm.wala.ssa.SSABinaryOpInstruction;
 import com.ibm.wala.ssa.SSAGetInstruction;
+import com.ibm.wala.ssa.SSAInstruction;
+import com.ibm.wala.ssa.SSAInvokeInstruction;
 import com.ibm.wala.ssa.SymbolTable;
+import com.ibm.wala.types.Descriptor;
 import com.ibm.wala.types.FieldReference;
+import com.ibm.wala.types.MethodReference;
+import com.ibm.wala.types.TypeName;
 import com.ibm.wala.types.TypeReference;
 import com.ibm.wala.util.collections.Pair;
 import com.ibm.wala.util.intset.IntIterator;
@@ -45,8 +55,14 @@
 import com.ibm.wala.util.intset.IntSetUtil;
 import com.ibm.wala.util.intset.MutableIntSet;
 import com.ibm.wala.util.intset.OrdinalSet;
+import java.util.ArrayDeque;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 import java.util.logging.Logger;
 
 public class PythonSSAPropagationCallGraphBuilder extends AstSSAPropagationCallGraphBuilder {
@@ -98,6 +114,23 @@ protected boolean sameMethod(CGNode opNode, String definingMethod) {
 
   public static class PythonConstraintVisitor extends AstConstraintVisitor
       implements PythonInstructionVisitor {
+    private static final String GLOBAL_IDENTIFIER = "global";
+
+    private static final String IMPORT_WILDCARD_CHARACTER = "*";
+
+    private static final Atom IMPORT_FUNCTION_NAME = Atom.findOrCreateAsciiAtom("import");
+
+    /**
+     * A mapping of script names to wildcard imports. We use a {@link Deque} here because we want to
+     * always examine the last (front of the queue) encountered wildcard import library for known
+     * names assuming that import instructions are traversed from first to last.
+     *
+     * <p>NOTE(review): this map is static and nothing visible here removes entries, so they
+     * presumably persist across call graph builds in the same JVM. Confirm that is intended.
+     */
+    private static final Map<String, Deque<MethodReference>> scriptToWildcardImports =
+        Maps.newHashMap();
+
     @Override
     protected PythonSSAPropagationCallGraphBuilder getBuilder() {
       return (PythonSSAPropagationCallGraphBuilder) builder;
@@ -183,6 +216,181 @@ public void visitArrayLoad(SSAArrayLoadInstruction inst) {
     public void visitArrayStore(SSAArrayStoreInstruction inst) {
       newFieldWrite(node, inst.getArrayRef(), inst.getIndex(), inst.getValue());
     }
+
+    /**
+     * Records wildcard imports (e.g., {@code from m import *}) encountered in a script.
+     *
+     * <p>A property read whose member is the constant {@code "*"} is treated as a wildcard
+     * import. When the object read from is defined by an invocation of {@code import}, the
+     * invoked library is recorded. When it is defined by a read of a {@code global} field, the
+     * {@code do} method of the script named by that field is recorded instead.
+     *
+     * @param instruction The property read to examine.
+     */
+    @Override
+    public void visitPropertyRead(AstPropertyRead instruction) {
+      super.visitPropertyRead(instruction);
+
+      int memberRef = instruction.getMemberRef();
+      SymbolTable symbolTable = this.ir.getSymbolTable();
+
+      if (!symbolTable.isConstant(memberRef)
+          || !Objects.equals(symbolTable.getConstantValue(memberRef), IMPORT_WILDCARD_CHARACTER)) {
+        return; // Not a wildcard import.
+      }
+
+      // We have a wildcard.
+      logger.fine("Detected wildcard for " + memberRef + " in " + instruction + ".");
+
+      int objRef = instruction.getObjectRef();
+      logger.fine("Seeing if " + objRef + " refers to an import.");
+
+      SSAInstruction def = this.du.getDef(objRef);
+      logger.finer("Found definition: " + def + ".");
+
+      TypeName scriptTypeName = this.ir.getMethod().getReference().getDeclaringClass().getName();
+      assert scriptTypeName.getPackage() == null
+          : "Import statement should only occur at the top-level script.";
+
+      String scriptName = scriptTypeName.getClassName().toString();
+
+      if (def instanceof SSAInvokeInstruction) {
+        // Library case.
+        MethodReference declaredTarget = ((SSAInvokeInstruction) def).getDeclaredTarget();
+
+        if (declaredTarget.getName().equals(IMPORT_FUNCTION_NAME)) {
+          // It's an import "statement" importing a library.
+          logger.fine("Found library import statement in: " + scriptTypeName + ".");
+          addWildcardImport(scriptName, declaredTarget);
+        }
+      } else if (def instanceof SSAGetInstruction) {
+        // We are importing from a script.
+        String fieldName = ((SSAGetInstruction) def).getDeclaredField().getName().toString();
+        String globalPrefix = GLOBAL_IDENTIFIER + " ";
+
+        if (!fieldName.startsWith(globalPrefix)) {
+          // Not a global read, so there is no script name to extract.
+          logger.fine("Ignoring non-global definition: " + def + ".");
+          return;
+        }
+
+        String importedScriptName = fieldName.substring(globalPrefix.length());
+        TypeReference typeReference =
+            TypeReference.findOrCreate(PythonTypes.pythonLoader, "L" + importedScriptName);
+        MethodReference methodReference =
+            MethodReference.findOrCreate(
+                typeReference,
+                Atom.findOrCreateAsciiAtom("do"),
+                Descriptor.findOrCreate(null, PythonTypes.rootTypeName));
+
+        addWildcardImport(scriptName, methodReference);
+      }
+    }
+
+    /**
+     * Pushes an import target onto the front of the wildcard-import deque of the given script,
+     * creating the deque on first use.
+     *
+     * @param scriptName The name of the script containing the wildcard import.
+     * @param importTarget The method representing the imported library or script.
+     */
+    private void addWildcardImport(String scriptName, MethodReference importTarget) {
+      logger.info(
+          "Adding: "
+              + importTarget.getDeclaringClass().getName().getClassName()
+              + " to wildcard imports for: "
+              + scriptName
+              + ".");
+
+      scriptToWildcardImports
+          .computeIfAbsent(scriptName, k -> new ArrayDeque<>())
+          .push(importTarget);
+    }
+
+    /**
+     * Resolves a global read against the wildcard imports recorded for the enclosing script.
+     *
+     * <p>For each recorded import, most recently recorded first, the call graph nodes of the
+     * import's method are searched for an allocation site whose declared type name equals the
+     * global's name. The first match that yields a new constraint ends the search.
+     *
+     * @param instruction The global read to examine.
+     */
+    @Override
+    public void visitAstGlobalRead(AstGlobalRead instruction) {
+      super.visitAstGlobalRead(instruction);
+
+      TypeName scriptTypeName = this.ir.getMethod().getReference().getDeclaringClass().getName();
+
+      // The package, when present, is used as the script name; otherwise, the class name is.
+      String scriptName =
+          (scriptTypeName.getPackage() == null
+                  ? scriptTypeName.getClassName()
+                  : scriptTypeName.getPackage())
+              .toString();
+      logger.finer("Script name is: " + scriptName + ".");
+
+      // Are there any wildcard imports for this script?
+      Deque<MethodReference> wildcardImports = scriptToWildcardImports.get(scriptName);
+
+      if (wildcardImports == null) {
+        return;
+      }
+
+      logger.info("Found wildcard imports in " + scriptName + " for " + instruction + ".");
+
+      // The global's name and the call graph do not depend on the import being examined.
+      String declaredFieldName = getStrippedDeclaredFieldName(instruction);
+      logger.fine("Examining global: " + declaredFieldName + " for wildcard import.");
+
+      CallGraph callGraph = this.getBuilder().getCallGraph();
+
+      for (MethodReference importMethodReference : wildcardImports) {
+        logger.fine(
+            "Library with wildcard import is: "
+                + importMethodReference.getDeclaringClass().getName().getClassName()
+                + ".");
+
+        for (CGNode n : callGraph.getNodes(importMethodReference)) {
+          for (Iterator<NewSiteReference> nit = n.iterateNewSites(); nit.hasNext(); ) {
+            NewSiteReference newSiteReference = nit.next();
+
+            String name = newSiteReference.getDeclaredType().getName().getClassName().toString();
+            logger.finest("Examining: " + name + ".");
+
+            if (name.equals(declaredFieldName)) {
+              logger.info("Found wildcard import for: " + name + ".");
+
+              PointerKey def = getPointerKeyForLocal(instruction.getDef());
+              assert def != null;
+
+              InstanceKey instanceKey = this.getInstanceKeyForAllocation(newSiteReference);
+
+              if (this.system.newConstraint(def, instanceKey)) {
+                logger.fine("Added constraint that: " + def + " gets: " + instanceKey + ".");
+                return;
+              }
+            }
+          }
+        }
+      }
+    }
+
+    /**
+     * Removes the leading {@code "global "} marker from the given instruction's declared field
+     * name.
+     *
+     * @param instruction The field read whose declared field name starts with the marker.
+     * @return The declared field name without the marker.
+     */
+    private static String getStrippedDeclaredFieldName(SSAGetInstruction instruction) {
+      String declaredFieldName = instruction.getDeclaredField().getName().toString();
+      String globalPrefix = GLOBAL_IDENTIFIER + " ";
+      assert declaredFieldName.startsWith(globalPrefix);
+
+      // Remove the global identifier.
+      return declaredFieldName.substring(globalPrefix.length());
+    }
   }
 
   @Override