diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index 8dfe8325e..6724a83b9 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -208,6 +208,7 @@ public void testTf2() testTf2("tf2_test_dataset4.py", "add", 2, 2, 2, 3); testTf2("tf2_test_dataset5.py", "add", 2, 2, 2, 3); testTf2("tf2_test_dataset6.py", "add", 2, 2, 2, 3); + testTf2("tf2_test_dataset7.py", "add", 2, 2, 2, 3); testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3); testTf2("tf2_test_tensor_list2.py", "add", 0, 2); testTf2("tf2_test_tensor_list3.py", "add", 0, 2); diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py new file mode 100644 index 000000000..c1f8aebf3 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py @@ -0,0 +1,15 @@ +import tensorflow as tf + + +@tf.function +def add(a, b): + return a + b + + +def func(ds): + for element in ds: + c = add(element, element) + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2) +func(dataset)