diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java index cf1f42164..c2cae666b 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflowModel.java @@ -196,12 +196,15 @@ public void testTf2() testTf2("tf2_testing_decorator8.py", "returned", 1, 3, 2); testTf2("tf2_testing_decorator9.py", "returned", 1, 3, 2); testTf2("tf2_testing_decorator10.py", "returned", 1, 3, 2); - testTf2( - "tf2_test_dataset.py", - "add", - 0, - 0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once - // https://github.com/wala/ML/issues/89 is fixed. + // FIXME: Test tf2_test_dataset.py really has three tensors in its dataset. We are currently + // treating it as one. But, in the literal case, it should be possible to model it like the list + // tests below. + testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3); + testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3); + testTf2("tf2_test_tensor_list2.py", "add", 0, 2); + testTf2("tf2_test_tensor_list3.py", "add", 0, 2); + testTf2("tf2_test_tensor_list4.py", "add", 0, 0); + testTf2("tf2_test_tensor_list5.py", "add", 0, 2); testTf2("tf2_test_model_call.py", "SequentialModel.__call__", 1, 4, 3); testTf2("tf2_test_model_call2.py", "SequentialModel.call", 1, 4, 3); testTf2("tf2_test_model_call3.py", "SequentialModel.call", 1, 4, 3); diff --git a/com.ibm.wala.cast.python.ml/data/tensorflow.xml b/com.ibm.wala.cast.python.ml/data/tensorflow.xml index 1d2ccbc0d..b5860ff3a 100644 --- a/com.ibm.wala.cast.python.ml/data/tensorflow.xml +++ b/com.ibm.wala.cast.python.ml/data/tensorflow.xml @@ -41,6 +41,10 @@ + + + + @@ -122,6 +126,9 @@ + + + @@ -167,6 +174,10 @@ + + + + @@ -399,6 +410,18 @@ + + + + + + + + + + + + diff --git a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java index 97901bcef..0b2dc6d06 100644 --- a/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java +++ b/com.ibm.wala.cast.python.ml/source/com/ibm/wala/cast/python/ml/client/PythonTensorAnalysisEngine.java @@ -1,5 +1,8 @@ package com.ibm.wala.cast.python.ml.client; +import static com.ibm.wala.cast.types.AstMethodReference.fnReference; + +import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction; import com.ibm.wala.cast.lsp.AnalysisError; import com.ibm.wala.cast.python.client.PythonAnalysisEngine; import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis; @@ -7,9 +10,15 @@ import com.ibm.wala.cast.python.types.PythonTypes; import com.ibm.wala.cast.types.AstMethodReference; import com.ibm.wala.classLoader.CallSiteReference; +import com.ibm.wala.classLoader.IClass; +import com.ibm.wala.classLoader.IMethod; import com.ibm.wala.ipa.callgraph.AnalysisOptions; import com.ibm.wala.ipa.callgraph.CGNode; +import com.ibm.wala.ipa.callgraph.CallGraph; +import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode; +import com.ibm.wala.ipa.callgraph.propagation.InstanceKey; import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey; +import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis; import com.ibm.wala.ipa.callgraph.propagation.PointerKey; import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable; import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder; @@ -26,6 +35,7 @@ import com.ibm.wala.util.collections.HashSetFactory; import com.ibm.wala.util.graph.Graph; import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph; +import com.ibm.wala.util.intset.OrdinalSet; import java.util.Iterator; import java.util.Map; import java.util.Set; @@ -33,6 +43,14 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine { + /** A "fake" function name in the summaries that indicates that an API produces a new tensor. */ + private static final String TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME = "read_data"; + + /** + * A "fake" function name in the summaries that indicates that an API produces a tensor iterable. + */ + private static final String TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME = "read_dataset"; + private static final Logger logger = Logger.getLogger(PythonTensorAnalysisEngine.class.getName()); private static final MethodReference conv2d = @@ -69,7 +87,10 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine errorLog = HashMapFactory.make(); - private static Set getDataflowSources(Graph dataflow) { + private static Set getDataflowSources( + Graph dataflow, + CallGraph callGraph, + PointerAnalysis pointerAnalysis) { Set sources = HashSetFactory.make(); for (PointsToSetVariable src : dataflow) { PointerKey k = src.getPointerKey(); @@ -77,16 +98,33 @@ private static Set getDataflowSources(Graph getDataflowSources(Graph pointerAnalysis) { + if (instruction instanceof SSAAbstractInvokeInstruction) { + SSAAbstractInvokeInstruction invocationInstruction = + (SSAAbstractInvokeInstruction) instruction; + + if (invocationInstruction.getNumberOfUses() > 0) { + // What function are we calling? + int use = invocationInstruction.getUse(0); + PointerKey pointerKeyForLocal = + pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, use); + OrdinalSet pointsToSet = pointerAnalysis.getPointsToSet(pointerKeyForLocal); + + for (InstanceKey ik : pointsToSet) { + if (ik instanceof AllocationSiteInNode) { + AllocationSiteInNode asin = (AllocationSiteInNode) ik; + IClass concreteType = asin.getConcreteType(); + TypeReference reference = concreteType.getReference(); + MethodReference methodReference = fnReference(reference); + + // Get the nodes this method calls. + Set iterableNodes = callGraph.getNodes(methodReference); + + for (CGNode itNode : iterableNodes) + for (Iterator succNodes = callGraph.getSuccNodes(itNode); + succNodes.hasNext(); ) { + CGNode callee = succNodes.next(); + IMethod calledMethod = callee.getMethod(); + + // Does this method call the synthetic "marker?" + if (calledMethod + .getName() + .toString() + .equals(TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME)) { + return true; + } + } + } + } + } + } + return false; + } + @FunctionalInterface interface SourceCallHandler { void handleCall(CGNode src, SSAAbstractInvokeInstruction call); @@ -165,7 +260,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder) SlowSparseNumberedGraph.duplicate( builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints()); - Set sources = getDataflowSources(dataflow); + Set sources = + getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis()); TensorType mnistData = TensorType.mnistInput(); Map init = HashMapFactory.make(); diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py new file mode 100644 index 000000000..794a54cfb --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +list = [tf.ones([1, 2]), tf.ones([2, 2])] + +for element in list: + c = add(element, element) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list2.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list2.py new file mode 100644 index 000000000..33e191933 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list2.py @@ -0,0 +1,10 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +list = [tf.ones([1, 2]), tf.ones([2, 2])] + +c = add(list, list) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list3.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list3.py new file mode 100644 index 000000000..b884ea582 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list3.py @@ -0,0 +1,14 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +list = list() + +list.append(tf.ones([1, 2])) +list.append(tf.ones([2, 2])) + +for element in list: + c = add(element, element) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list4.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list4.py new file mode 100644 index 000000000..5abef44a4 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list4.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +my_list = list([1, 2]) + +for element in my_list: + c = add(element, element) diff --git a/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list5.py b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list5.py new file mode 100644 index 000000000..75782b54b --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/tf2_test_tensor_list5.py @@ -0,0 +1,11 @@ +import tensorflow as tf + + +def add(a, b): + return a + b + + +my_list = list([tf.ones([1, 2]), tf.ones([2, 2])]) + +for element in my_list: + c = add(element, element) diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java index 870c3491d..5c0bf6308 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonSSAPropagationCallGraphBuilder.java @@ -12,10 +12,13 @@ import com.ibm.wala.cast.ipa.callgraph.AstSSAPropagationCallGraphBuilder; import com.ibm.wala.cast.ipa.callgraph.GlobalObjectKey; +import com.ibm.wala.cast.ir.ssa.AstPropertyRead; +import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction; import com.ibm.wala.cast.python.ipa.summaries.BuiltinFunctions.BuiltinFunction; import com.ibm.wala.cast.python.ir.PythonLanguage; import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor; import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction; +import com.ibm.wala.cast.python.ssa.PythonPropertyRead; import com.ibm.wala.cast.python.types.PythonTypes; import com.ibm.wala.classLoader.IClass; import com.ibm.wala.classLoader.IField; @@ -37,6 +40,7 @@ import com.ibm.wala.ssa.SSAArrayStoreInstruction; import com.ibm.wala.ssa.SSABinaryOpInstruction; import com.ibm.wala.ssa.SSAGetInstruction; +import com.ibm.wala.ssa.SSAInstruction; import com.ibm.wala.ssa.SymbolTable; import com.ibm.wala.types.FieldReference; import com.ibm.wala.types.TypeReference; @@ -50,9 +54,13 @@ import java.util.Arrays; import java.util.Collection; import java.util.Map; +import java.util.logging.Logger; public class PythonSSAPropagationCallGraphBuilder extends AstSSAPropagationCallGraphBuilder { + private static final Logger logger = + Logger.getLogger(PythonSSAPropagationCallGraphBuilder.class.getName()); + public PythonSSAPropagationCallGraphBuilder( IClassHierarchy cha, AnalysisOptions options, @@ -171,6 +179,42 @@ public String toString() { super.visitGet(instruction); } + @Override + public void visitPropertyRead(AstPropertyRead instruction) { + super.visitPropertyRead(instruction); + + if (instruction instanceof PythonPropertyRead) { + PythonPropertyRead ppr = (PythonPropertyRead) instruction; + SSAInstruction memberRefDef = du.getDef(ppr.getMemberRef()); + + if (memberRefDef != null && memberRefDef instanceof EachElementGetInstruction) { + // most likely a for each "property." + final PointerKey memberRefKey = this.getPointerKeyForLocal(ppr.getMemberRef()); + + // for each def of the property read. + for (int i = 0; i < ppr.getNumberOfDefs(); i++) { + PointerKey defKey = this.getPointerKeyForLocal(ppr.getDef(i)); + + // add an assignment constraint straight away as the traversal variable won't have a + // non-empty points-to set but still may be used for a dataflow analysis. + if (this.system.newConstraint(defKey, assignOperator, memberRefKey)) + logger.fine( + () -> + "Added new system constraint for global read from: " + + defKey + + " to: " + + memberRefKey + + " for instruction: " + + instruction + + "."); + else + logger.fine( + () -> "No constraint added for global read in instruction: " + instruction + "."); + } + } + } + } + @Override public void visitPythonInvoke(PythonInvokeInstruction inst) { visitInvokeInternal(inst, new DefaultInvariantComputer());