diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java
index cf1f42164..c2cae666b 100644
--- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java
+++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java
@@ -196,12 +196,15 @@ public void testTf2()
testTf2("tf2_testing_decorator8.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator9.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator10.py", "returned", 1, 3, 2);
- testTf2(
- "tf2_test_dataset.py",
- "add",
- 0,
- 0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once
- // https://github.com/wala/ML/issues/89 is fixed.
+ // FIXME: Test tf2_test_dataset.py really has three tensors in its dataset. We are currently
+ // treating it as one. But, in the literal case, it should be possible to model it like the list
+ // tests below.
+ testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3);
+ testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
+ testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
+ testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
+ testTf2("tf2_test_tensor_list4.py", "add", 0, 0);
+ testTf2("tf2_test_tensor_list5.py", "add", 0, 2);
testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 3);
testTf2("tf2_test_model_call2.py", "SequentialModel.call", 1, 4, 3);
testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 3);
diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml
index 1d2ccbc0d..b5860ff3a 100644
--- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml
+++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml
@@ -41,6 +41,10 @@
+
+
+
+
@@ -122,6 +126,9 @@
+
+
+
@@ -167,6 +174,10 @@
+
+
+
+
@@ -399,6 +410,18 @@
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java
index 97901bcef..0b2dc6d06 100644
--- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java
+++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java
@@ -1,5 +1,8 @@
package com.ibm.wala.cast.python.ml.client;
+import static com.ibm.wala.cast.types.AstMethodReference.fnReference;
+
+import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
import com.ibm.wala.cast.lsp.AnalysisError;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
@@ -7,9 +10,15 @@
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.CallSiteReference;
+import com.ibm.wala.classLoader.IClass;
+import com.ibm.wala.classLoader.IMethod;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
+import com.ibm.wala.ipa.callgraph.CallGraph;
+import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
+import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
+import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
@@ -26,6 +35,7 @@
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.graph.Graph;
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
+import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
@@ -33,6 +43,14 @@
public class PythonTensorAnalysisEngine extends PythonAnalysisEngine {
+ /** A "fake" function name in the summaries that indicates that an API produces a new tensor. */
+ private static final String TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME = "read_data";
+
+ /**
+ * A "fake" function name in the summaries that indicates that an API produces a tensor iterable.
+ */
+ private static final String TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME = "read_dataset";
+
private static final Logger logger = Logger.getLogger(PythonTensorAnalysisEngine.class.getName());
private static final MethodReference conv2d =
@@ -69,7 +87,10 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine errorLog = HashMapFactory.make();
- private static Set getDataflowSources(Graph dataflow) {
+ private static Set getDataflowSources(
+ Graph dataflow,
+ CallGraph callGraph,
+ PointerAnalysis pointerAnalysis) {
Set sources = HashSetFactory.make();
for (PointsToSetVariable src : dataflow) {
PointerKey k = src.getPointerKey();
@@ -77,16 +98,33 @@ private static Set getDataflowSources(Graph getDataflowSources(Graph pointerAnalysis) {
+ if (instruction instanceof SSAAbstractInvokeInstruction) {
+ SSAAbstractInvokeInstruction invocationInstruction =
+ (SSAAbstractInvokeInstruction) instruction;
+
+ if (invocationInstruction.getNumberOfUses() > 0) {
+ // What function are we calling?
+ int use = invocationInstruction.getUse(0);
+ PointerKey pointerKeyForLocal =
+ pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, use);
+ OrdinalSet pointsToSet = pointerAnalysis.getPointsToSet(pointerKeyForLocal);
+
+ for (InstanceKey ik : pointsToSet) {
+ if (ik instanceof AllocationSiteInNode) {
+ AllocationSiteInNode asin = (AllocationSiteInNode) ik;
+ IClass concreteType = asin.getConcreteType();
+ TypeReference reference = concreteType.getReference();
+ MethodReference methodReference = fnReference(reference);
+
+ // Get the nodes this method calls.
+ Set iterableNodes = callGraph.getNodes(methodReference);
+
+ for (CGNode itNode : iterableNodes)
+ for (Iterator succNodes = callGraph.getSuccNodes(itNode);
+ succNodes.hasNext(); ) {
+ CGNode callee = succNodes.next();
+ IMethod calledMethod = callee.getMethod();
+
+ // Does this method call the synthetic "marker?"
+ if (calledMethod
+ .getName()
+ .toString()
+ .equals(TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME)) {
+ return true;
+ }
+ }
+ }
+ }
+ }
+ }
+ return false;
+ }
+
@FunctionalInterface
interface SourceCallHandler {
void handleCall(CGNode src, SSAAbstractInvokeInstruction call);
@@ -165,7 +260,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
SlowSparseNumberedGraph.duplicate(
builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints());
- Set sources = getDataflowSources(dataflow);
+ Set sources =
+ getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis());
TensorType mnistData = TensorType.mnistInput();
Map init = HashMapFactory.make();
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py
new file mode 100644
index 000000000..794a54cfb
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py
@@ -0,0 +1,11 @@
+import tensorflow as tf
+
+
+def add(a, b):
+ return a + b
+
+
+list = [tf.ones([1, 2]), tf.ones([2, 2])]
+
+for element in list:
+ c = add(element, element)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list2.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list2.py
new file mode 100644
index 000000000..33e191933
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list2.py
@@ -0,0 +1,10 @@
+import tensorflow as tf
+
+
+def add(a, b):
+ return a + b
+
+
+list = [tf.ones([1, 2]), tf.ones([2, 2])]
+
+c = add(list, list)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list3.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list3.py
new file mode 100644
index 000000000..b884ea582
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list3.py
@@ -0,0 +1,14 @@
+import tensorflow as tf
+
+
+def add(a, b):
+ return a + b
+
+
+list = list()
+
+list.append(tf.ones([1, 2]))
+list.append(tf.ones([2, 2]))
+
+for element in list:
+ c = add(element, element)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list4.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list4.py
new file mode 100644
index 000000000..5abef44a4
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list4.py
@@ -0,0 +1,11 @@
+import tensorflow as tf
+
+
+def add(a, b):
+ return a + b
+
+
+my_list = list([1, 2])
+
+for element in my_list:
+ c = add(element, element)
diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list5.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list5.py
new file mode 100644
index 000000000..75782b54b
--- /dev/null
+++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list5.py
@@ -0,0 +1,11 @@
+import tensorflow as tf
+
+
+def add(a, b):
+ return a + b
+
+
+my_list = list([tf.ones([1, 2]), tf.ones([2, 2])])
+
+for element in my_list:
+ c = add(element, element)
diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java
index 870c3491d..5c0bf6308 100644
--- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java
+++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java
@@ -12,10 +12,13 @@
import com.ibm.wala.cast.ipa.callgraph.AstSSAPropagationCallGraphBuilder;
import com.ibm.wala.cast.ipa.callgraph.GlobalObjectKey;
+import com.ibm.wala.cast.ir.ssa.AstPropertyRead;
+import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
import com.ibm.wala.cast.python.ipa.summaries.BuiltinFunctions.BuiltinFunction;
import com.ibm.wala.cast.python.ir.PythonLanguage;
import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
+import com.ibm.wala.cast.python.ssa.PythonPropertyRead;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IField;
@@ -37,6 +40,7 @@
import com.ibm.wala.ssa.SSAArrayStoreInstruction;
import com.ibm.wala.ssa.SSABinaryOpInstruction;
import com.ibm.wala.ssa.SSAGetInstruction;
+import com.ibm.wala.ssa.SSAInstruction;
import com.ibm.wala.ssa.SymbolTable;
import com.ibm.wala.types.FieldReference;
import com.ibm.wala.types.TypeReference;
@@ -50,9 +54,13 @@
import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
+import java.util.logging.Logger;
public class PythonSSAPropagationCallGraphBuilder extends AstSSAPropagationCallGraphBuilder {
+ private static final Logger logger =
+ Logger.getLogger(PythonSSAPropagationCallGraphBuilder.class.getName());
+
public PythonSSAPropagationCallGraphBuilder(
IClassHierarchy cha,
AnalysisOptions options,
@@ -171,6 +179,42 @@ public String toString() {
super.visitGet(instruction);
}
+ @Override
+ public void visitPropertyRead(AstPropertyRead instruction) {
+ super.visitPropertyRead(instruction);
+
+ if (instruction instanceof PythonPropertyRead) {
+ PythonPropertyRead ppr = (PythonPropertyRead) instruction;
+ SSAInstruction memberRefDef = du.getDef(ppr.getMemberRef());
+
+ if (memberRefDef != null && memberRefDef instanceof EachElementGetInstruction) {
+ // most likely a for each "property."
+ final PointerKey memberRefKey = this.getPointerKeyForLocal(ppr.getMemberRef());
+
+ // for each def of the property read.
+ for (int i = 0; i < ppr.getNumberOfDefs(); i++) {
+ PointerKey defKey = this.getPointerKeyForLocal(ppr.getDef(i));
+
+ // add an assignment constraint straight away as the traversal variable won't have a
+ // non-empty points-to set but still may be used for a dataflow analysis.
+ if (this.system.newConstraint(defKey, assignOperator, memberRefKey))
+ logger.fine(
+ () ->
+ "Added new system constraint for global read from: "
+ + defKey
+ + " to: "
+ + memberRefKey
+ + " for instruction: "
+ + instruction
+ + ".");
+ else
+ logger.fine(
+ () -> "No constraint added for global read in instruction: " + instruction + ".");
+ }
+ }
+ }
+ }
+
@Override
public void visitPythonInvoke(PythonInvokeInstruction inst) {
visitInvokeInternal(inst, new DefaultInvariantComputer());