Skip to content

Commit

Permalink
Initial tensor dataset support (#123)
Browse files Browse the repository at this point in the history
Basic fix for #89; more APIs to come.
  • Loading branch information
khatchad authored Dec 20, 2023
1 parent e18809f commit 244697e
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,15 @@ public void testTf2()
testTf2("tf2_testing_decorator8.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator9.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator10.py", "returned", 1, 3, 2);
testTf2(
"tf2_test_dataset.py",
"add",
0,
0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once
// https://github.com/wala/ML/issues/89 is fixed.
// FIXME: Test tf2_test_dataset.py really has three tensors in its dataset. We are currently
// treating it as one. But, in the literal case, it should be possible to model it like the list
// tests below.
testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
testTf2("tf2_test_tensor_list4.py", "add", 0, 0);
testTf2("tf2_test_tensor_list5.py", "add", 0, 2);
testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 3);
testTf2("tf2_test_model_call2.py", "SequentialModel.call", 1, 4, 3);
testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 3);
Expand Down
23 changes: 23 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@

<new def="nn" class="Lobject" />
<putfield class="LRoot" field="nn" fieldType="LRoot" ref="x" value="nn" />
<new def="data" class="Lobject" />
<putfield class="LRoot" field="data" fieldType="LRoot" ref="x" value="data" />
<new def="Dataset" class="Lobject" />
<putfield class="LRoot" field="Dataset" fieldType="LRoot" ref="data" value="Dataset" />
<new def="random" class="Lobject" />
<putfield class="LRoot" field="random" fieldType="LRoot" ref="x" value="random" />
<new def="sparse" class="Lobject" />
Expand Down Expand Up @@ -122,6 +126,9 @@
<new def="array_ops" class="Lobject" />
<putfield class="LRoot" field="array_ops" fieldType="LRoot" ref="ops" value="array_ops" />

<new def="data_ops" class="Lobject" />
<putfield class="LRoot" field="data_ops" fieldType="LRoot" ref="ops" value="data_ops" />

<new def="random_ops" class="Lobject" />
<putfield class="LRoot" field="random_ops" fieldType="LRoot" ref="ops" value="random_ops" />

Expand Down Expand Up @@ -167,6 +174,10 @@
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="x" value="ones" />
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="array_ops" value="ones" />

<new def="from_tensor_slices" class="Ltensorflow/functions/from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="data_ops" value="from_tensor_slices" />

<new def="zeros" class="Ltensorflow/functions/zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="x" value="zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="array_ops" value="zeros" />
Expand Down Expand Up @@ -399,6 +410,18 @@
</method>
</class>

<!-- Summary model for tf.data.Dataset.from_tensor_slices (also reachable via
     data_ops); presumably models creation of a dataset (tensor iterable) from
     its "tensors" argument - confirm against the putfield wiring above. -->
<class name="from_tensor_slices" allocatable="true">
	<!-- "read_dataset" means that this function reads a tensor iterable. The
	     analysis engine looks up this synthetic method name to mark dataflow
	     sources. -->
	<method name="read_dataset" descriptor="()LRoot;">
		<new def="x" class="Ltensorflow/python/ops/data_ops/from_tensor_slices" />
		<return value="x" />
	</method>
	<!-- "do" is the function's call target: delegates to the synthetic
	     "read_dataset" marker and returns its result. -->
	<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="tensors name">
		<call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
		<return value="x" />
	</method>
</class>

<class name="Variable" allocatable="true">
<method name="read_data" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/python/ops/variables/Variable" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
package com.ibm.wala.cast.python.ml.client;

import static com.ibm.wala.cast.types.AstMethodReference.fnReference;

import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
import com.ibm.wala.cast.lsp.AnalysisError;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
import com.ibm.wala.cast.python.ml.types.TensorType;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.CallSiteReference;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IMethod;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
Expand All @@ -26,13 +35,22 @@
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.graph.Graph;
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeAnalysis> {

/** A "fake" function name in the summaries that indicates that an API produces a new tensor. */
private static final String TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME = "read_data";

/**
* A "fake" function name in the summaries that indicates that an API produces a tensor iterable.
*/
private static final String TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME = "read_dataset";

private static final Logger logger = Logger.getLogger(PythonTensorAnalysisEngine.class.getName());

private static final MethodReference conv2d =
Expand Down Expand Up @@ -69,31 +87,108 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeA

private final Map<PointerKey, AnalysisError> errorLog = HashMapFactory.make();

private static Set<PointsToSetVariable> getDataflowSources(Graph<PointsToSetVariable> dataflow) {
private static Set<PointsToSetVariable> getDataflowSources(
Graph<PointsToSetVariable> dataflow,
CallGraph callGraph,
PointerAnalysis<InstanceKey> pointerAnalysis) {
Set<PointsToSetVariable> sources = HashSetFactory.make();
for (PointsToSetVariable src : dataflow) {
PointerKey k = src.getPointerKey();

if (k instanceof LocalPointerKey) {
LocalPointerKey kk = (LocalPointerKey) k;
int vn = kk.getValueNumber();
DefUse du = kk.getNode().getDU();
CGNode localPointerKeyNode = kk.getNode();
DefUse du = localPointerKeyNode.getDU();
SSAInstruction inst = du.getDef(vn);

if (inst instanceof SSAAbstractInvokeInstruction) {
// We potentially have a function call that generates a tensor.
SSAAbstractInvokeInstruction ni = (SSAAbstractInvokeInstruction) inst;

if (ni.getCallSite().getDeclaredTarget().getName().toString().equals("read_data")
if (ni.getCallSite()
.getDeclaredTarget()
.getName()
.toString()
.equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)
&& ni.getException() != vn) {
sources.add(src);
logger.info("Added dataflow source " + src + ".");
logger.info("Added dataflow source from tensor generator: " + src + ".");
}
} else if (inst instanceof EachElementGetInstruction) {
// We are potentially pulling a tensor out of a tensor iterable.
EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst;

// Find the potential tensor iterable creation site.
SSAInstruction def = du.getDef(eachElementGetInstruction.getUse(0));

if (createsTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
sources.add(src);
logger.info("Added dataflow source from tensor iterable: " + src + ".");
}
}
}
}
return sources;
}

/**
 * Returns true iff the given {@link SSAInstruction} creates an iterable of tensors.
 *
 * @param instruction The {@link SSAInstruction} in question.
 * @param node The {@link CGNode} of the function containing the given {@link SSAInstruction}.
 * @param callGraph The {@link CallGraph} that includes a node corresponding to the given {@link
 *     SSAInstruction}.
 * @param pointerAnalysis The {@link PointerAnalysis} built from the given {@link CallGraph}.
 * @return True iff the given {@link SSAInstruction} creates an iterable over tensors.
 */
private static boolean createsTensorIterable(
    SSAInstruction instruction,
    CGNode node,
    CallGraph callGraph,
    PointerAnalysis<InstanceKey> pointerAnalysis) {
  if (instruction instanceof SSAAbstractInvokeInstruction) {
    SSAAbstractInvokeInstruction invocationInstruction =
        (SSAAbstractInvokeInstruction) instruction;

    if (invocationInstruction.getNumberOfUses() > 0) {
      // What function are we calling? Use 0 is the callee in this IR.
      int use = invocationInstruction.getUse(0);
      PointerKey pointerKeyForLocal =
          pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, use);
      OrdinalSet<InstanceKey> pointsToSet = pointerAnalysis.getPointsToSet(pointerKeyForLocal);

      // Examine every allocation the callee local may point to.
      for (InstanceKey ik : pointsToSet) {
        if (ik instanceof AllocationSiteInNode) {
          AllocationSiteInNode asin = (AllocationSiteInNode) ik;
          IClass concreteType = asin.getConcreteType();
          TypeReference reference = concreteType.getReference();
          MethodReference methodReference = fnReference(reference);

          // Call graph nodes representing the invoked function itself.
          Set<CGNode> iterableNodes = callGraph.getNodes(methodReference);

          // Walk the callees of each such node.
          for (CGNode itNode : iterableNodes)
            for (Iterator<CGNode> succNodes = callGraph.getSuccNodes(itNode);
                succNodes.hasNext(); ) {
              CGNode callee = succNodes.next();
              IMethod calledMethod = callee.getMethod();

              // Does this function call the synthetic "marker" method
              // ("read_dataset") that the XML summaries use to tag tensor
              // iterable producers?
              if (calledMethod
                  .getName()
                  .toString()
                  .equals(TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME)) {
                return true;
              }
            }
        }
      }
    }
  }
  return false;
}

@FunctionalInterface
interface SourceCallHandler {
void handleCall(CGNode src, SSAAbstractInvokeInstruction call);
Expand Down Expand Up @@ -165,7 +260,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
SlowSparseNumberedGraph.duplicate(
builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints());

Set<PointsToSetVariable> sources = getDataflowSources(dataflow);
Set<PointsToSetVariable> sources =
getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis());

TensorType mnistData = TensorType.mnistInput();
Map<PointsToSetVariable, TensorType> init = HashMapFactory.make();
Expand Down
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Test fixture for the WALA ML tensor analysis: iterates over a Python list
# literal containing two tensors; each element flowing into add() is expected
# to be identified as a tensor (see the corresponding testTf2 expectations).
import tensorflow as tf


def add(a, b):
    # Both operands are tensors in this fixture; returns their element-wise sum.
    return a + b


# NOTE: deliberately shadows the built-in `list`; part of the test input.
list = [tf.ones([1, 2]), tf.ones([2, 2])]

for element in list:
    c = add(element, element)
10 changes: 10 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_tensor_list2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Test fixture for the WALA ML tensor analysis: passes the whole list (not its
# elements) to add(), so the arguments are lists of tensors rather than
# tensors themselves.
import tensorflow as tf


def add(a, b):
    # Here `a` and `b` are lists; `+` is list concatenation, not tensor math.
    return a + b


# NOTE: deliberately shadows the built-in `list`; part of the test input.
list = [tf.ones([1, 2]), tf.ones([2, 2])]

c = add(list, list)
14 changes: 14 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_tensor_list3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Test fixture for the WALA ML tensor analysis: builds the tensor list
# imperatively via append() instead of a list literal, then iterates it.
import tensorflow as tf


def add(a, b):
    # Both operands are tensors in this fixture; returns their element-wise sum.
    return a + b


# Calls the built-in list() constructor, then shadows the name `list` with the
# result - intentional test input, not a typo.
list = list()

list.append(tf.ones([1, 2]))
list.append(tf.ones([2, 2]))

for element in list:
    c = add(element, element)
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_tensor_list4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Test fixture for the WALA ML tensor analysis: negative case - the list holds
# plain ints, not tensors, so no tensor dataflow should be reported for add().
import tensorflow as tf


def add(a, b):
    # Operands are ints here; ordinary integer addition.
    return a + b


my_list = list([1, 2])

for element in my_list:
    c = add(element, element)
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_tensor_list5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Test fixture for the WALA ML tensor analysis: constructs the tensor list via
# the list() constructor applied to a literal, then iterates its elements.
import tensorflow as tf


def add(a, b):
    # Both operands are tensors in this fixture; returns their element-wise sum.
    return a + b


my_list = list([tf.ones([1, 2]), tf.ones([2, 2])])

for element in my_list:
    c = add(element, element)
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

import com.ibm.wala.cast.ipa.callgraph.AstSSAPropagationCallGraphBuilder;
import com.ibm.wala.cast.ipa.callgraph.GlobalObjectKey;
import com.ibm.wala.cast.ir.ssa.AstPropertyRead;
import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
import com.ibm.wala.cast.python.ipa.summaries.BuiltinFunctions.BuiltinFunction;
import com.ibm.wala.cast.python.ir.PythonLanguage;
import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
import com.ibm.wala.cast.python.ssa.PythonPropertyRead;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IField;
Expand All @@ -37,6 +40,7 @@
import com.ibm.wala.ssa.SSAArrayStoreInstruction;
import com.ibm.wala.ssa.SSABinaryOpInstruction;
import com.ibm.wala.ssa.SSAGetInstruction;
import com.ibm.wala.ssa.SSAInstruction;
import com.ibm.wala.ssa.SymbolTable;
import com.ibm.wala.types.FieldReference;
import com.ibm.wala.types.TypeReference;
Expand All @@ -50,9 +54,13 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import java.util.logging.Logger;

public class PythonSSAPropagationCallGraphBuilder extends AstSSAPropagationCallGraphBuilder {

private static final Logger logger =
Logger.getLogger(PythonSSAPropagationCallGraphBuilder.class.getName());

public PythonSSAPropagationCallGraphBuilder(
IClassHierarchy cha,
AnalysisOptions options,
Expand Down Expand Up @@ -171,6 +179,42 @@ public String toString() {
super.visitGet(instruction);
}

/**
 * Visits a property read, additionally wiring up dataflow for Python for-each loops.
 *
 * <p>When the member reference of a {@link PythonPropertyRead} is defined by an {@link
 * EachElementGetInstruction}, this read is (most likely) the "property" access generated for a
 * for-each loop. In that case, add an assignment constraint from the member reference to each
 * definition of the read immediately: the traversal variable may never acquire a non-empty
 * points-to set, yet it can still be consumed by a downstream dataflow analysis.
 *
 * @param instruction The {@link AstPropertyRead} being visited.
 */
@Override
public void visitPropertyRead(AstPropertyRead instruction) {
  super.visitPropertyRead(instruction);

  if (instruction instanceof PythonPropertyRead) {
    PythonPropertyRead propertyRead = (PythonPropertyRead) instruction;

    // instanceof is null-safe, so no separate null check of the definition is needed.
    SSAInstruction memberRefDef = du.getDef(propertyRead.getMemberRef());

    if (memberRefDef instanceof EachElementGetInstruction) {
      // Most likely a for-each "property."
      final PointerKey memberRefKey = this.getPointerKeyForLocal(propertyRead.getMemberRef());

      // For each definition of the property read.
      for (int i = 0; i < propertyRead.getNumberOfDefs(); i++) {
        PointerKey defKey = this.getPointerKeyForLocal(propertyRead.getDef(i));

        // Add an assignment constraint straight away, as the traversal variable won't have a
        // non-empty points-to set but may still be used for a dataflow analysis.
        if (this.system.newConstraint(defKey, assignOperator, memberRefKey)) {
          logger.fine(
              () ->
                  "Added new system constraint for property read from: "
                      + defKey
                      + " to: "
                      + memberRefKey
                      + " for instruction: "
                      + instruction
                      + ".");
        } else {
          logger.fine(
              () -> "No constraint added for property read in instruction: " + instruction + ".");
        }
      }
    }
  }
}

@Override
public void visitPythonInvoke(PythonInvokeInstruction inst) {
visitInvokeInternal(inst, new DefaultInvariantComputer());
Expand Down

0 comments on commit 244697e

Please sign in to comment.