diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index 8a8ed6442..d81beec70 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -207,6 +207,8 @@ public void testTf2() testTf2("tf2_test_dataset3.py", "add", 2, 2, 2, 3); testTf2("tf2_test_dataset4.py", "add", 2, 2, 2, 3); testTf2("tf2_test_dataset5.py", "add", 2, 2, 2, 3); + testTf2("tf2_test_dataset6.py", "add", 2, 2, 2, 3); + testTf2("tf2_test_dataset7.py", "add", 2, 2, 2, 3); testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3); testTf2("tf2_test_tensor_list2.py", "add", 0, 2); testTf2("tf2_test_tensor_list3.py", "add", 0, 2); @@ -218,8 +220,8 @@ public void testTf2() testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 3); testTf2("tf2_test_callbacks.py", "replica_fn", 1, 3, 2); testTf2("tf2_test_callbacks2.py", "replica_fn", 1, 4, 2); - testTf2("tensorflow_gan_tutorial.py", "train_step", 1, 10, 7); - testTf2("tensorflow_gan_tutorial2.py", "train_step", 1, 10, 7); + testTf2("tensorflow_gan_tutorial.py", "train_step", 1, 10, 2); + testTf2("tensorflow_gan_tutorial2.py", "train_step", 1, 10, 2); } private void testTf2( @@ -302,27 +304,23 @@ private void testTf2( methodSignatureToPointerKeys.getOrDefault(functionSignature, Collections.emptySet()); // check tensor parameters. - assertEquals(expectedNumberOfTensorParameters, functionPointerKeys.size()); + assertEquals( + expectedNumberOfTensorParameters, + functionPointerKeys.stream().filter(LocalPointerKey::isParameter).count()); // check value numbers. - Set actualValueNumberSet = + Set actualParameterValueNumberSet = functionPointerKeys.stream() + .filter(LocalPointerKey::isParameter) .map(LocalPointerKey::getValueNumber) .collect(Collectors.toSet()); - assertEquals(expectedTensorParameterValueNumbers.length, actualValueNumberSet.size()); + assertEquals(expectedTensorParameterValueNumbers.length, actualParameterValueNumberSet.size()); Arrays.stream(expectedTensorParameterValueNumbers) .forEach( ev -> assertTrue( - "Expecting " + actualValueNumberSet + " to contain " + ev + ".", - actualValueNumberSet.contains(ev))); - - // get the tensor variables for the function. - Set functionTensors = - methodSignatureToTensorVariables.getOrDefault(functionSignature, Collections.emptySet()); - - // check tensor parameters. - assertEquals(expectedNumberOfTensorParameters, functionTensors.size()); + "Expecting " + actualParameterValueNumberSet + " to contain " + ev + ".", + actualParameterValueNumberSet.contains(ev))); } } diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index 1fbac142c..7761ea4a6 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -1,5 +1,6 @@ package com.ibm.wala.cast.python.ml.client; +import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DATASET; import static com.ibm.wala.cast.types.AstMethodReference.fnReference; import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction; @@ -123,7 +124,33 @@ private static Set getDataflowSources( int use = eachElementGetInstruction.getUse(0); SSAInstruction def = du.getDef(use); - if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) { + if (def == null) { + logger.warning( + () -> + "Can't find potential tensor iterable definition for use: " + + use + + " of instruction: " + + eachElementGetInstruction + + ". Trying interprocedural analysis..."); + + // Look up the use in the pointer analysis to see if it points to a dataset. + PointerKey usePointerKey = + pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, use); + + for (InstanceKey ik : pointerAnalysis.getPointsToSet(usePointerKey)) { + if (ik instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = (AllocationSiteInNode) ik; + IClass concreteType = asin.getConcreteType(); + TypeReference reference = concreteType.getReference(); + + if (reference.equals(DATASET)) { + sources.add(src); + logger.info("Added dataflow source from tensor dataset: " + src + "."); + break; + } + } + } + } else if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) { sources.add(src); logger.info("Added dataflow source from tensor iterable: " + src + "."); } diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java new file mode 100644 index 000000000..d906363d0 --- /dev/null +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/types/TensorFlowTypes.java @@ -0,0 +1,18 @@ +package com.ibm.wala.cast.python.ml.types; + +import com.ibm.wala.cast.python.types.PythonTypes; +import com.ibm.wala.types.TypeName; +import com.ibm.wala.types.TypeReference; + +/** + * Types found in the TensorFlow library. + * + * @author Raffi Khatchadourian + */ +public class TensorFlowTypes extends PythonTypes { + + public static final TypeReference DATASET = + TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate("Ltensorflow/data/Dataset")); + + private TensorFlowTypes() {} +} diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset6.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset6.py new file mode 100644 index 000000000..f5fc00ec4 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset6.py @@ -0,0 +1,14 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +def func(ds): + for element in ds: + c = add(element, element) + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2) +func(dataset) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py b/com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py new file mode 100644 index 000000000..c1f8aebf3 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py @@ -0,0 +1,15 @@ +import tensorflow as tf + + +@tf.function +def add(a, b): + return a + b + + +def func(ds): + for element in ds: + c = add(element, element) + + +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2) +func(dataset)