Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

89 losing tensors in datasets #63

Merged
merged 3 commits into from
Dec 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,10 @@ public void testTf2()
testTf2("tf2_testing_decorator8.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator9.py", "returned", 1, 3, 2);
testTf2("tf2_testing_decorator10.py", "returned", 1, 3, 2);
testTf2(
"tf2_test_dataset.py",
"add",
0,
0); // NOTE: Change to testTf2("tf2_test_dataset.py", "add", 2, 3, 2, 3) once
// https://github.com/wala/ML/issues/89 is fixed.
// FIXME: Test tf2_test_dataset.py really has three tensors in its dataset. We are currently
// treating it as one. But, in the literal case, it should be possible to model it like the list
// tests below.
testTf2("tf2_test_dataset.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
Expand Down
23 changes: 23 additions & 0 deletions com.ibm.wala.cast.python.ml/data/tensorflow.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@

<new def="nn" class="Lobject" />
<putfield class="LRoot" field="nn" fieldType="LRoot" ref="x" value="nn" />
<new def="data" class="Lobject" />
<putfield class="LRoot" field="data" fieldType="LRoot" ref="x" value="data" />
<new def="Dataset" class="Lobject" />
<putfield class="LRoot" field="Dataset" fieldType="LRoot" ref="data" value="Dataset" />
<new def="random" class="Lobject" />
<putfield class="LRoot" field="random" fieldType="LRoot" ref="x" value="random" />
<new def="sparse" class="Lobject" />
Expand Down Expand Up @@ -122,6 +126,9 @@
<new def="array_ops" class="Lobject" />
<putfield class="LRoot" field="array_ops" fieldType="LRoot" ref="ops" value="array_ops" />

<new def="data_ops" class="Lobject" />
<putfield class="LRoot" field="data_ops" fieldType="LRoot" ref="ops" value="data_ops" />

<new def="random_ops" class="Lobject" />
<putfield class="LRoot" field="random_ops" fieldType="LRoot" ref="ops" value="random_ops" />

Expand Down Expand Up @@ -167,6 +174,10 @@
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="x" value="ones" />
<putfield class="LRoot" field="ones" fieldType="LRoot" ref="array_ops" value="ones" />

<new def="from_tensor_slices" class="Ltensorflow/functions/from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="Dataset" value="from_tensor_slices" />
<putfield class="LRoot" field="from_tensor_slices" fieldType="LRoot" ref="data_ops" value="from_tensor_slices" />

<new def="zeros" class="Ltensorflow/functions/zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="x" value="zeros" />
<putfield class="LRoot" field="zeros" fieldType="LRoot" ref="array_ops" value="zeros" />
Expand Down Expand Up @@ -399,6 +410,18 @@
</method>
</class>

<class name="from_tensor_slices" allocatable="true">
<!-- "read_dataset" means that this function reads a tensor iterable. -->
<method name="read_dataset" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/python/ops/data_ops/from_tensor_slices" />
<return value="x" />
</method>
<method name="do" descriptor="()LRoot;" numArgs="2" paramNames="tensors name">
<call class="LRoot" name="read_dataset" descriptor="()LRoot;" type="virtual" arg0="arg0" def="x" />
<return value="x" />
</method>
</class>

<class name="Variable" allocatable="true">
<method name="read_data" descriptor="()LRoot;">
<new def="x" class="Ltensorflow/python/ops/variables/Variable" />
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
package com.ibm.wala.cast.python.ml.client;

import static com.ibm.wala.cast.types.AstMethodReference.fnReference;

import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
import com.ibm.wala.cast.lsp.AnalysisError;
import com.ibm.wala.cast.python.client.PythonAnalysisEngine;
import com.ibm.wala.cast.python.ml.analysis.TensorTypeAnalysis;
import com.ibm.wala.cast.python.ml.types.TensorType;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.types.AstMethodReference;
import com.ibm.wala.classLoader.CallSiteReference;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IMethod;
import com.ibm.wala.ipa.callgraph.AnalysisOptions;
import com.ibm.wala.ipa.callgraph.CGNode;
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.LocalPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
import com.ibm.wala.ipa.callgraph.propagation.PointsToSetVariable;
import com.ibm.wala.ipa.callgraph.propagation.PropagationCallGraphBuilder;
Expand All @@ -26,13 +35,22 @@
import com.ibm.wala.util.collections.HashSetFactory;
import com.ibm.wala.util.graph.Graph;
import com.ibm.wala.util.graph.impl.SlowSparseNumberedGraph;
import com.ibm.wala.util.intset.OrdinalSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeAnalysis> {

/** A "fake" function name in the summaries that indicates that an API produces a new tensor. */
private static final String TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME = "read_data";

/**
* A "fake" function name in the summaries that indicates that an API produces a tensor iterable.
*/
private static final String TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME = "read_dataset";

private static final Logger logger = Logger.getLogger(PythonTensorAnalysisEngine.class.getName());

private static final MethodReference conv2d =
Expand Down Expand Up @@ -69,7 +87,10 @@ public class PythonTensorAnalysisEngine extends PythonAnalysisEngine<TensorTypeA

private final Map<PointerKey, AnalysisError> errorLog = HashMapFactory.make();

private static Set<PointsToSetVariable> getDataflowSources(Graph<PointsToSetVariable> dataflow) {
private static Set<PointsToSetVariable> getDataflowSources(
Graph<PointsToSetVariable> dataflow,
CallGraph callGraph,
PointerAnalysis<InstanceKey> pointerAnalysis) {
Set<PointsToSetVariable> sources = HashSetFactory.make();
for (PointsToSetVariable src : dataflow) {
PointerKey k = src.getPointerKey();
Expand All @@ -81,12 +102,63 @@ private static Set<PointsToSetVariable> getDataflowSources(Graph<PointsToSetVari
SSAInstruction inst = du.getDef(vn);

if (inst instanceof SSAAbstractInvokeInstruction) {
// We potentially have a function call that generates a tensor.
SSAAbstractInvokeInstruction ni = (SSAAbstractInvokeInstruction) inst;

if (ni.getCallSite().getDeclaredTarget().getName().toString().equals("read_data")
if (ni.getCallSite()
.getDeclaredTarget()
.getName()
.toString()
.equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)
&& ni.getException() != vn) {
sources.add(src);
logger.info("Added dataflow source " + src + ".");
logger.info("Added dataflow source from tensor generator: " + src + ".");
}
} else if (inst instanceof EachElementGetInstruction) {
// We are potentially pulling a tensor out of a tensor iterable.
EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst;

// Find the potential tensor iterable creation site.
SSAInstruction iterableDef = du.getDef(eachElementGetInstruction.getUse(0));

if (iterableDef instanceof SSAAbstractInvokeInstruction) {
SSAAbstractInvokeInstruction iterableGenInvocationInstruction =
(SSAAbstractInvokeInstruction) iterableDef;

// What function are we calling?
int use = iterableGenInvocationInstruction.getUse(0);
PointerKey pointerKeyForLocal =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(kk.getNode(), use);
OrdinalSet<InstanceKey> pointsToSet =
pointerAnalysis.getPointsToSet(pointerKeyForLocal);

for (InstanceKey ik : pointsToSet) {
if (ik instanceof AllocationSiteInNode) {
AllocationSiteInNode asin = (AllocationSiteInNode) ik;
IClass concreteType = asin.getConcreteType();
TypeReference reference = concreteType.getReference();
MethodReference methodReference = fnReference(reference);

// Get the nodes this method calls.
Set<CGNode> iterableNodes = callGraph.getNodes(methodReference);

for (CGNode itNode : iterableNodes)
for (Iterator<CGNode> succNodes = callGraph.getSuccNodes(itNode);
succNodes.hasNext(); ) {
CGNode callee = succNodes.next();
IMethod calledMethod = callee.getMethod();

// Does this method call the sythetic "marker?"
if (calledMethod
.getName()
.toString()
.equals(TENSOR_ITERABLE_SYNTHETIC_FUNCTION_NAME)) {
sources.add(src);
logger.info("Added dataflow source from tensor iterable: " + src + ".");
}
}
}
}
}
}
}
Expand Down Expand Up @@ -165,7 +237,8 @@ public TensorTypeAnalysis performAnalysis(PropagationCallGraphBuilder builder)
SlowSparseNumberedGraph.duplicate(
builder.getPropagationSystem().getFlowGraphIncludingImplicitConstraints());

Set<PointsToSetVariable> sources = getDataflowSources(dataflow);
Set<PointsToSetVariable> sources =
getDataflowSources(dataflow, builder.getCallGraph(), builder.getPointerAnalysis());

TensorType mnistData = TensorType.mnistInput();
Map<PointsToSetVariable, TensorType> init = HashMapFactory.make();
Expand Down