Skip to content

Commit

Permalink
Fix wala#140 (#82)
Browse files Browse the repository at this point in the history
* Add test cases for wala#140.

* Fix wala#140.
  • Loading branch information
khatchad committed Feb 5, 2024
1 parent f41b5f5 commit 440f391
Show file tree
Hide file tree
Showing 8 changed files with 499 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,71 @@ public void testDataset10()
test("tf2_test_dataset10.py", "add", 2, 2, 2, 3);
}

/**
 * Tests enumerating a dataset (https://github.com/wala/ML/issues/140), where the first element of
 * the returned tuple isn't a tensor. Runs the checks for {@code "f"} and {@code "g"} against the
 * same input script.
 */
@Test
public void testDataset11()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset11.py";

  test(script, "f", 0, 0);
  test(script, "g", 1, 1, 2);
}

/**
 * Tests enumerating a dataset (https://github.com/wala/ML/issues/140), where the first element of
 * the returned tuple isn't a tensor. Runs the checks for {@code "f"} and {@code "g"} against the
 * same input script.
 */
@Test
public void testDataset12()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset12.py";

  test(script, "f", 0, 0);
  test(script, "g", 1, 1, 2);
}

/**
 * Tests enumerating a dataset (https://github.com/wala/ML/issues/140), where the first element of
 * the returned tuple isn't a tensor. Runs the checks for {@code "f"} and {@code "g"} against the
 * same input script.
 */
@Test
public void testDataset13()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset13.py";

  test(script, "f", 0, 0);
  test(script, "g", 1, 1, 2);
}

/**
 * Tests enumerating a dataset (https://github.com/wala/ML/issues/140), where the first element of
 * the returned tuple isn't a tensor. Runs the checks for {@code "f"} and {@code "g"} against the
 * same input script.
 */
@Test
public void testDataset14()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tf2_test_dataset14.py";

  test(script, "f", 0, 0);
  test(script, "g", 1, 1, 2);
}

/**
 * Test enumerating a dataset (https://github.com/wala/ML/issues/140). The first element of the
 * tuple returned isn't a tensor.
 *
 * <p>This test analyzes its own input script, {@code tf2_test_dataset15.py}, rather than re-running
 * the script already covered by {@code testDataset14}.
 */
@Test
public void testDataset15()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  test("tf2_test_dataset15.py", "f", 0, 0);
  test("tf2_test_dataset15.py", "g", 1, 1, 2);
}

/**
 * Tests enumerating a dataset (https://github.com/wala/ML/issues/140), where the first element of
 * the returned tuple isn't a tensor. Runs the check for {@code "summarize_weights"} against the
 * TensorBoard example script.
 */
@Test
public void testTensorboardExample()
    throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
  final String script = "tensorboard_example.py";
  final String target = "summarize_weights";

  test(script, target, 0, 12);
}

@Test
public void testTensorList()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
import com.ibm.wala.cast.python.ml.types.TensorType;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
import com.ibm.wala.cast.python.ssa.PythonPropertyRead;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.types.AstMethodReference;
Expand All @@ -19,6 +20,8 @@
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
import com.ibm.wala.ipa.callgraph.propagation.ConcreteTypeKey;
import com.ibm.wala.ipa.callgraph.propagation.ConstantKey;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
Expand Down Expand Up @@ -88,6 +91,12 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeA
TypeName.string2TypeName("Ltensorflow/functions/set_shape")),
AstMethodReference.fnSelector);

/**
 * Reference to the function selector of the {@code Lwala/builtin/enumerate} type in the Python
 * loader. Its declaring class is compared against the type of an invoked function to recognize
 * calls to {@code enumerate()}.
 */
private static final MethodReference enumerate =
MethodReference.findOrCreate(
TypeReference.findOrCreate(
PythonTypes.pythonLoader, TypeName.string2TypeName("Lwala/builtin/enumerate")),
AstMethodReference.fnSelector);

/** Per-instance map from a {@link PointerKey} to the {@link AnalysisError} associated with it. */
private final Map<PointerKey, AnalysisError> errorLog = HashMapFactory.make();

private static Set<PointsToSetVariable> getDataflowSources(
Expand Down Expand Up @@ -275,7 +284,7 @@ private static boolean processInstructionInterprocedurally(
IClass concreteType = asin.getConcreteType();
TypeReference reference = concreteType.getReference();

if (reference.equals(DATASET)) {
if (reference.equals(DATASET) && isDatasetTensorElement(src, use, node, pointerAnalysis)) {
sources.add(src);
logger.info("Added dataflow source from tensor dataset: " + src + ".");
return true;
Expand All @@ -286,6 +295,77 @@ private static boolean processInstructionInterprocedurally(
return false;
}

/**
 * Returns true iff the given {@link PointsToSetVariable} refers to a tensor dataset element of
 * the dataset defined by the given value number in the given {@link CGNode}.
 *
 * <p>The only case rejected is the enumeration number produced by {@code enumerate()}: the value
 * must be defined by a call to {@code enumerate()}, and the source must be a read of index {@code
 * 0} from a tuple that was itself read from that value. Every other source is accepted. A source
 * whose shape cannot be inspected (its pointer key is not a local, it is not defined by a property
 * read, or its member reference is not a constant) is treated as a dataset element instead of
 * failing with a {@link ClassCastException} or {@link NullPointerException}.
 *
 * @param src The {@link PointsToSetVariable} to consider.
 * @param val The value in the given {@link CGNode} representing the tensor dataset.
 * @param node The {@link CGNode} containing the given {@link PointsToSetVariable} and value.
 * @param pointerAnalysis The {@link PointerAnalysis} that includes points-to information for the
 *     given {@link CGNode}.
 * @return True iff src refers to a tensor dataset element defined by the dataset represented by
 *     val in node.
 */
private static boolean isDatasetTensorElement(
    PointsToSetVariable src, int val, CGNode node, PointerAnalysis<InstanceKey> pointerAnalysis) {
  SSAInstruction def = node.getDU().getDef(val);

  // Only the result of an invocation can come from enumerate(). This also covers values without a
  // defining instruction, for which instanceof is false.
  if (!(def instanceof PythonInvokeInstruction)) {
    return true;
  }

  // enumerate() returns an iterator over tuples; any other call yields plain dataset elements.
  if (!isEnumerateInvocation((PythonInvokeInstruction) def, node, pointerAnalysis)) {
    return true;
  }

  // Each tuple consists of the enumeration number and the dataset element. The enumeration number
  // can only be reached through a property read on a local, so anything else is an element.
  PointerKey srcPointerKey = src.getPointerKey();

  if (!(srcPointerKey instanceof LocalPointerKey)) {
    return true;
  }

  SSAInstruction srcDef = node.getDU().getDef(((LocalPointerKey) srcPointerKey).getValueNumber());

  if (!(srcDef instanceof PythonPropertyRead)) {
    return true;
  }

  PythonPropertyRead srcRead = (PythonPropertyRead) srcDef;

  // To know whether this is the first tuple, i.e., the one returned by enumerate(), examine the
  // object being read: it must itself have been read from the enumerated dataset value.
  SSAInstruction objRefDef = node.getDU().getDef(srcRead.getObjectRef());
  boolean readsEnumerateTuple =
      objRefDef instanceof PythonPropertyRead
          && ((PythonPropertyRead) objRefDef).getObjectRef() == val;

  if (!readsEnumerateTuple) {
    return true;
  }

  // Index 0 of the tuple returned by enumerate() is the enumeration number, not a tensor.
  return !readsTupleIndexZero(srcRead, node, pointerAnalysis);
}

/**
 * Returns true iff the function invoked by the given instruction may point to {@code
 * enumerate()}.
 */
private static boolean isEnumerateInvocation(
    PythonInvokeInstruction invokeInstruction,
    CGNode node,
    PointerAnalysis<InstanceKey> pointerAnalysis) {
  // Use 0 of the invocation is the invoked function.
  PointerKey functionPointerKey =
      pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, invokeInstruction.getUse(0));

  for (InstanceKey functionInstance : pointerAnalysis.getPointsToSet(functionPointerKey)) {
    if (functionInstance instanceof ConcreteTypeKey) {
      TypeReference typeReference = ((ConcreteTypeKey) functionInstance).getType().getReference();

      if (typeReference.equals(enumerate.getDeclaringClass())) {
        return true;
      }
    }
  }

  return false;
}

/**
 * Returns true iff the member reference of the given property read may point to the integer
 * constant {@code 0}. Non-constant and null-valued members are ignored.
 */
private static boolean readsTupleIndexZero(
    PythonPropertyRead read, CGNode node, PointerAnalysis<InstanceKey> pointerAnalysis) {
  // What does the member reference point to?
  PointerKey memberRefPointerKey =
      pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, read.getMemberRef());

  for (InstanceKey memberInstance : pointerAnalysis.getPointsToSet(memberRefPointerKey)) {
    if (memberInstance instanceof ConstantKey) {
      Object value = ((ConstantKey<?>) memberInstance).getValue();

      // Null-safe form of value.equals(0).
      if (Integer.valueOf(0).equals(value)) {
        return true;
      }
    }
  }

  return false;
}

/**
* True iff the given {@link SSAInstruction} constitutes individual elements.
*
Expand Down
Loading

0 comments on commit 440f391

Please sign in to comment.