Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Add interprocedural dataset support #69

Merged
merged 10 commits into from
Jan 4, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ public void testTf2()
testTf2("tf2_test_dataset3.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset4.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset5.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset6.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset7.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
Expand All @@ -218,8 +220,8 @@ public void testTf2()
testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 3);
testTf2("tf2_test_callbacks.py", "replica_fn", 1, 3, 2);
testTf2("tf2_test_callbacks2.py", "replica_fn", 1, 4, 2);
testTf2("tensorflow_gan_tutorial.py", "train_step", 1, 10, 7);
testTf2("tensorflow_gan_tutorial2.py", "train_step", 1, 10, 7);
testTf2("tensorflow_gan_tutorial.py", "train_step", 1, 10, 2);
testTf2("tensorflow_gan_tutorial2.py", "train_step", 1, 10, 2);
}

private void testTf2(
Expand Down Expand Up @@ -302,27 +304,23 @@ private void testTf2(
methodSignatureToPointerKeys.getOrDefault(functionSignature, Collections.emptySet());

// check tensor parameters.
assertEquals(expectedNumberOfTensorParameters, functionPointerKeys.size());
assertEquals(
expectedNumberOfTensorParameters,
functionPointerKeys.stream().filter(LocalPointerKey::isParameter).count());

// check value numbers.
Set<Integer> actualValueNumberSet =
Set<Integer> actualParameterValueNumberSet =
functionPointerKeys.stream()
.filter(LocalPointerKey::isParameter)
.map(LocalPointerKey::getValueNumber)
.collect(Collectors.toSet());

assertEquals(expectedTensorParameterValueNumbers.length, actualValueNumberSet.size());
assertEquals(expectedTensorParameterValueNumbers.length, actualParameterValueNumberSet.size());
Arrays.stream(expectedTensorParameterValueNumbers)
.forEach(
ev ->
assertTrue(
"Expecting " + actualValueNumberSet + " to contain " + ev + ".",
actualValueNumberSet.contains(ev)));

// get the tensor variables for the function.
Set<TensorVariable> functionTensors =
methodSignatureToTensorVariables.getOrDefault(functionSignature, Collections.emptySet());

// check tensor parameters.
assertEquals(expectedNumberOfTensorParameters, functionTensors.size());
"Expecting " + actualParameterValueNumberSet + " to contain " + ev + ".",
actualParameterValueNumberSet.contains(ev)));
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.ibm.wala.cast.python.ml.client;

import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DATASET;
import static com.ibm.wala.cast.types.AstMethodReference.fnReference;

import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
Expand Down Expand Up @@ -123,7 +124,33 @@ private static Set<PointsToSetVariable> getDataflowSources(
int use = eachElementGetInstruction.getUse(0);
SSAInstruction def = du.getDef(use);

if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
if (def == null) {
logger.warning(
() ->
"Can't find potential tensor iterable definition for use: "
+ use
+ " of instruction: "
+ eachElementGetInstruction
+ ". Trying interprocedural analysis...");

// Look up the use in the pointer analysis to see if it points to a dataset.
PointerKey usePointerKey =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, use);

for (InstanceKey ik : pointerAnalysis.getPointsToSet(usePointerKey)) {
if (ik instanceof AllocationSiteInNode) {
AllocationSiteInNode asin = (AllocationSiteInNode) ik;
IClass concreteType = asin.getConcreteType();
TypeReference reference = concreteType.getReference();

if (reference.equals(DATASET)) {
sources.add(src);
logger.info("Added dataflow source from tensor dataset: " + src + ".");
break;
}
}
}
} else if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
sources.add(src);
logger.info("Added dataflow source from tensor iterable: " + src + ".");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.ibm.wala.cast.python.ml.types;

import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.types.TypeReference;

/**
 * Holder for {@link TypeReference}s that name types from the TensorFlow library.
 *
 * <p>The class only exposes static constants; its sole constructor is private.
 *
 * @author Raffi Khatchadourian
 */
public class TensorFlowTypes extends PythonTypes {

  /** Name of the TensorFlow dataset type, in the form given to {@link TypeName#findOrCreate}. */
  private static final String DATASET_TYPE_NAME = "Ltensorflow/data/Dataset";

  /** Reference to the type named {@code Ltensorflow/data/Dataset}, looked up via {@code pythonLoader}. */
  public static final TypeReference DATASET =
      TypeReference.findOrCreate(pythonLoader, TypeName.findOrCreate(DATASET_TYPE_NAME));

  /** Private so that this constants-only class is never instantiated from outside. */
  private TensorFlowTypes() {}
}
14 changes: 14 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import tensorflow as tf


# Plain, undecorated function that adds its two arguments.
def add(a, b):
return a + b


# Receives the dataset as a parameter instead of creating it, so the loop below
# iterates over a value that is defined in the caller.
def func(ds):
for element in ds:
# The same dataset element is passed for both parameters of add.
c = add(element, element)


# The dataset is built at module level and only reaches the loop through the call to func.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2)
func(dataset)
15 changes: 15 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tensorflow as tf


# Adds its two arguments; decorated with tf.function.
@tf.function
def add(a, b):
return a + b


# Receives the dataset as a parameter instead of creating it, so the loop below
# iterates over a value that is defined in the caller.
def func(ds):
for element in ds:
# The same dataset element is passed for both parameters of add.
c = add(element, element)


# The dataset is built at module level and only reaches the loop through the call to func.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2)
func(dataset)