Skip to content

Commit

Permalink
Fix tf.range (#152)
Browse files Browse the repository at this point in the history
Not only does `tf.range()` "generate" a new tensor, i.e., the result
value itself is a tensor, inside that tensor are more, newly generated
tensors. Here, we model it as a one element sequence, which should
suffice for analysis purposes.
  • Loading branch information
khatchad authored Feb 27, 2024
1 parent cc6d83c commit ce82d05
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,18 @@ public void testRelu()
test("tf2_test_relu.py", "f", 1, 1, 2);
}

@Test
public void testTFRange()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test("tf2_test_tf_range.py", "f", 1, 1, 2);
}

@Test
public void testTFRange2()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test("tf2_test_tf_range2.py", "f", 1, 1, 2);
}

private void test(
String filename,
String functionName,
Expand Down
4 changes: 3 additions & 1 deletion com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,9 @@
</class>
<class name="range" allocatable="true">
<method name="read_data" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/python/ops/math_ops/range" />
<new def="x" class="Llist" />
<call class="Ltensorflow/functions/constant" name="do" descriptor="()LRoot;" type="virtual" arg0="1" def="y" />
<putfield class="LRoot" field="0" fieldType="LRoot" ref="x" value="y" />
<return value="x" />
</method>
<method name="do" descriptor="()LRoot;" numArgs="4" paramNames="limit delta dtype name">
Expand Down
17 changes: 17 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_tf_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# From: https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/range#for_example

import tensorflow as tf


def f(a):
pass


start = 3
limit = 18
delta = 3

r = tf.range(start, limit, delta)

for i in r:
f(i)
13 changes: 13 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_tf_range2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# From: https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/range#for_example

import tensorflow as tf


def f(a):
pass


r = [tf.constant(3), tf.constant(6), tf.constant(9), tf.constant(12), tf.constant(15)]

for i in r:
f(i)

0 comments on commit ce82d05

Please sign in to comment.