Skip to content

Commit

Permalink
New dataset APIs (#129)
Browse files Browse the repository at this point in the history
Adding `shuffle()` and `batch()`.
  • Loading branch information
khatchad authored Jan 2, 2024
1 parent 6e53776 commit c40ccba
Show file tree
Hide file tree
Showing 7 changed files with 111 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,10 @@ public void testTf2()
// treating it as one. But, in the literal case, it should be possible to model it like the list
// tests below.
testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset2.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset3.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset4.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset5.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
Expand Down
79 changes: 56 additions & 23 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,14 @@
<new def="estimator" class="Lobject" />
<putfield class="LRoot" field="estimator" fieldType="LRoot" ref="x" value="estimator" />

<new def="data" class="Lobject" />
<putfield class="LRoot" field="data" fieldType="LRoot" ref="x" value="data" />

<new def="distribute" class="Lobject" />
<putfield class="LRoot" field="distribute" fieldType="LRoot" ref="x" value="distribute" />

<new def="nn" class="Lobject" />
<putfield class="LRoot" field="nn" fieldType="LRoot" ref="x" value="nn" />
<new def="data" class="Lobject" />
<putfield class="LRoot" field="data" fieldType="LRoot" ref="x" value="data" />
<new def="Dataset" class="Lobject" />
<putfield class="LRoot" field="Dataset" fieldType="LRoot" ref="data" value="Dataset" />
<new def="random" class="Lobject" />
<putfield class="LRoot" field="random" fieldType="LRoot" ref="x" value="random" />
<new def="sparse" class="Lobject" />
Expand All @@ -65,6 +64,9 @@
<new def="Estimator" class="Ltensorflow/estimator/Estimator" />
<putfield class="LRoot" field="Estimator" fieldType="LRoot" ref="estimator" value="Estimator" />

<new def="Dataset" class="Ltensorflow/data/Dataset" />
<putfield class="LRoot" field="Dataset" fieldType="LRoot" ref="data" value="Dataset" />

<new def="MirroredStrategy" class="Ltensorflow/distribute/MirroredStrategy" />
<putfield class="LRoot" field="MirroredStrategy" fieldType="LRoot" ref="distribute" value="MirroredStrategy" />

Expand All @@ -74,6 +76,9 @@
<new def="numpy_input_fn" class="Ltensorflow/estimator/numpy_input_fn" />
<putfield class="LRoot" field="numpy_input_fn" fieldType="LRoot" ref="inputs" value="numpy_input_fn" />

<new def="from_tensor_slices" class="Ltensorflow/data/Dataset/from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />

<new def="reshape" class="Ltensorflow/functions/reshape" />
<putfield class="LRoot" field="reshape" fieldType="LRoot" ref="x" value="reshape" />

Expand Down Expand Up @@ -126,9 +131,6 @@
<new def="array_ops" class="Lobject" />
<putfield class="LRoot" field="array_ops" fieldType="LRoot" ref="ops" value="array_ops" />

<new def="data_ops" class="Lobject" />
<putfield class="LRoot" field="data_ops" fieldType="LRoot" ref="ops" value="data_ops" />

<new def="random_ops" class="Lobject" />
<putfield class="LRoot" field="random_ops" fieldType="LRoot" ref="ops" value="random_ops" />

Expand Down Expand Up @@ -174,10 +176,6 @@
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="x" value="ones" />
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="array_ops" value="ones" />

<new def="from_tensor_slices" class="Ltensorflow/functions/from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="data_ops" value="from_tensor_slices" />

<new def="zeros" class="Ltensorflow/functions/zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="x" value="zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="array_ops" value="zeros" />
Expand Down Expand Up @@ -410,18 +408,6 @@
</method>
</class>

<class name="from_tensor_slices" allocatable="true">
<!-- "read_dataset" means that this function reads a tensor iterable. -->
<method name="read_dataset" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/python/ops/data_ops/from_tensor_slices" />
<return value="x" />
</method>
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="tensors name">
<call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
<return value="x" />
</method>
</class>

<class name="Variable" allocatable="true">
<method name="read_data" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/python/ops/variables/Variable" />
Expand Down Expand Up @@ -804,6 +790,53 @@
</class>
</package>

<package name="tensorflow/data">
<class name="Dataset" allocatable="true">
<!-- "read_dataset" means that this function reads a tensor iterable. -->
<method name="read_dataset" descriptor="()LRoot;">
<new def="shuffle" class="Ltensorflow/data/shuffle" />
<putfield class="LRoot" field="shuffle" fieldType="LRoot" ref="arg0" value="shuffle" />
<new def="batch" class="Ltensorflow/data/batch" />
<putfield class="LRoot" field="batch" fieldType="LRoot" ref="arg0" value="batch" />
<return value="arg0" />
</method>
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="self variant_tensor">
<call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
<return value="x" />
</method>
</class>

<class name="shuffle" allocatable="true">
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/data/Dataset#shuffle -->
<method name="do" descriptor="()LRoot;" numArgs="5" paramNames="self buffer_size seed reshuffle_each_iteration name">
<!-- FIXME: Workaround for https://github.com/wala/ML/issues/127. This method (shuffle) doesn't really return a "new" dataset but rather a modified version of the receiver. But, the receiver isn't available without a trampoline AFAIK. -->
<new def="x" class="Ltensorflow/data/Dataset" />
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
<return value="xx" />
</method>
</class>

<class name="batch" allocatable="true">
<!-- https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/data/Dataset#batch -->
<method name="do" descriptor="()LRoot;" numArgs="6" paramNames="self batch_size drop_remainder num_parallel_calls deterministic name">
<!-- FIXME: Workaround for https://github.com/wala/ML/issues/127. -->
<new def="x" class="Ltensorflow/data/Dataset" />
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
<return value="xx" />
</method>
</class>
</package>

<package name="tensorflow/data/Dataset">
<class name="from_tensor_slices" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="tensors name">
<new def="x" class="Ltensorflow/data/Dataset" />
<call class="Ltensorflow/data/Dataset" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="x" def="xx" />
<return value="xx" />
</method>
</class>
</package>

<package name="tensorflow/estimator/train">
<class name="train" allocatable="true">
<method name="do" descriptor="()LRoot;" numArgs="3">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,11 @@ private static Set<PointsToSetVariable> getDataflowSources(
// We are potentially pulling a tensor out of a tensor iterable.
EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst;

// Find the potential tensor iterable creation site.
SSAInstruction def = du.getDef(eachElementGetInstruction.getUse(0));
// Find the potential tensor iterable definition.
int use = eachElementGetInstruction.getUse(0);
SSAInstruction def = du.getDef(use);

if (createsTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
sources.add(src);
logger.info("Added dataflow source from tensor iterable: " + src + ".");
}
Expand All @@ -133,16 +134,16 @@ private static Set<PointsToSetVariable> getDataflowSources(
}

/**
* Returns true iff the given {@link SSAInstruction} creates an iterable of tensors.
* Returns true iff the given {@link SSAInstruction} defines an iterable of tensors.
*
* @param instruction The {@link SSAInstruction} in question.
* @param node The {@link CGNode} of the function containing the given {@link SSAInstruction}.
* @param callGraph The {@link CallGraph} that includes a node corresponding to the given {@link
* SSAInstruction}.
* @param pointerAnalysis The {@link PointerAnalysis} built from the given {@link CallGraph}.
* @return True iff the given {@link SSAInstruction} creates an iterable over tensors.
* @return True iff the given {@link SSAInstruction} defines an iterable over tensors.
*/
private static boolean createsTensorIterable(
private static boolean definesTensorIterable(
SSAInstruction instruction,
CGNode node,
CallGraph callGraph,
Expand Down
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import tensorflow as tf


def add(a, b):
return a + b


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3)

for element in dataset:
c = add(element, element)
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import tensorflow as tf


def add(a, b):
return a + b


dataset = tf.data.Dataset(None) # This is actually illegal since this ctor is not publicly visible.

for element in dataset:
c = add(element, element)
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import tensorflow as tf


def add(a, b):
return a + b


dataset = tf.data.Dataset(None).shuffle(3) # This is actually illegal since this ctor is not publicly visible.

for element in dataset:
c = add(element, element)
11 changes: 11 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import tensorflow as tf


def add(a, b):
return a + b


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2)

for element in dataset:
c = add(element, element)

0 comments on commit c40ccba

Please sign in to comment.