diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java
index 95ed16e89..c30a8e189 100644
--- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java
+++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java
@@ -199,12 +199,10 @@ public void testTf2()
testTf2("tf2_testing_decorator8.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator9.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator10.py", "returned", 1, 3, 2);
- testTf2(
- "tf2_test_dataset.py",
- "add",
- 0,
- 0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once
- // https://github.com/wala/ML/issues/89 is fixed.
+ // FIXME: Test tf2_test_dataset.py really has three tensors in its dataset. We are currently
+ // treating it as one. But, in the literal case, it should be possible to model it like the list
+ // tests below.
+ testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml
index 1d2ccbc0d..b5860ff3a 100644
--- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml
+++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml
@@ -41,6 +41,10 @@
+
+
+
+
@@ -122,6 +126,9 @@
+
+
+
@@ -167,6 +174,10 @@
+
+
+
+
@@ -399,6 +410,18 @@
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java
index 97901bcef..bbca9c2e5 100644
--- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java
+++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java
@@ -1,5 +1,8 @@
package com.ibm.wala.cast.python.ml.client;
+import static com.ibm.wala.cast.types.AstMethodReference.fnReference;
+
+import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
import com.ibm.wala.cast.lsp.AnalysisError;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
@@ -7,9 +10,15 @@
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.CallSiteReference;
+import com.ibm.wala.classLoader.IClass;
+import com.ibm.wala.classLoader.IMethod;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
+import com.ibm.wala.ipa.callgraph.CallGraph;
+import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
+import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
+import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
@@ -26,6 +35,7 @@
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.graph.Graph;
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
+import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
@@ -33,6 +43,14 @@
public class PythonTensorAnalysisEngine extends PythonAnalysisEngine {
+ /** A "fake" function name in the summaries that indicates that an API produces a new tensor. */
+ private static final String TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME = "read_data";
+
+ /**
+ * A "fake" function name in the summaries that indicates that an API produces a tensor iterable.
+ */
+ private static final String TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME = "read_dataset";
+
private static final Logger logger = Logger.getLogger(PythonTensorAnalysisEngine.class.getName());
private static final MethodReference conv2d =
@@ -69,7 +87,10 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine errorLog = HashMapFactory.make();
- private static Set getDataflowSources(Graph dataflow) {
+ private static Set getDataflowSources(
+ Graph dataflow,
+ CallGraph callGraph,
+ PointerAnalysis pointerAnalysis) {
Set sources = HashSetFactory.make();
for (PointsToSetVariable src : dataflow) {
PointerKey k = src.getPointerKey();
@@ -81,12 +102,63 @@ private static Set getDataflowSources(Graph pointsToSet =
+ pointerAnalysis.getPointsToSet(pointerKeyForLocal);
+
+ for (InstanceKey ik : pointsToSet) {
+ if (ik instanceof AllocationSiteInNode) {
+ AllocationSiteInNode asin = (AllocationSiteInNode) ik;
+ IClass concreteType = asin.getConcreteType();
+ TypeReference reference = concreteType.getReference();
+ MethodReference methodReference = fnReference(reference);
+
+ // Get the nodes this method calls.
+ Set iterableNodes = callGraph.getNodes(methodReference);
+
+ for (CGNode itNode : iterableNodes)
+ for (Iterator succNodes = callGraph.getSuccNodes(itNode);
+ succNodes.hasNext(); ) {
+ CGNode callee = succNodes.next();
+ IMethod calledMethod = callee.getMethod();
+
+ // Does this method call the sythetic "marker?"
+ if (calledMethod
+ .getName()
+ .toString()
+ .equals(TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME)) {
+ sources.add(src);
+ logger.info("Added dataflow source from tensor iterable: " + src + ".");
+ }
+ }
+ }
+ }
}
}
}
@@ -165,7 +237,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
SlowSparseNumberedGraph.duplicate(
builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints());
- Set sources = getDataflowSources(dataflow);
+ Set sources =
+ getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis());
TensorType mnistData = TensorType.mnistInput();
Map init = HashMapFactory.make();