Skip to content

Commit

Permalink
Add interprocedural dataset support (#69)
Browse files Browse the repository at this point in the history
* Warn if we can't find the iterable definition.

* Fix comment.

* Actually check parameters.

* Add interprocedural dataset test.

* Fall back to interprocedural analysis for datasets.

If the dataset iterable can't be found using intraprocedural analysis, fall back to interprocedural analysis.

* Apply spotless.

* Add a hybrid dataset test case.

* Actually check the parameter value numbers.
  • Loading branch information
khatchad committed Jan 29, 2024
1 parent d0aa2fa commit 7eed945
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ public void testTf2()
testTf2("tf2_test_dataset3.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset4.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset5.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset6.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_dataset7.py", "add", 2, 2, 2, 3);
testTf2("tf2_test_tensor_list.py", "add", 2, 3, 2, 3);
testTf2("tf2_test_tensor_list2.py", "add", 0, 2);
testTf2("tf2_test_tensor_list3.py", "add", 0, 2);
Expand All @@ -218,8 +220,8 @@ public void testTf2()
testTf2("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 4, 3);
testTf2("tf2_test_callbacks.py", "replica_fn", 1, 3, 2);
testTf2("tf2_test_callbacks2.py", "replica_fn", 1, 4, 2);
testTf2("tensorflow_gan_tutorial.py", "train_step", 1, 10, 7);
testTf2("tensorflow_gan_tutorial2.py", "train_step", 1, 10, 7);
testTf2("tensorflow_gan_tutorial.py", "train_step", 1, 10, 2);
testTf2("tensorflow_gan_tutorial2.py", "train_step", 1, 10, 2);
}

private void testTf2(
Expand Down Expand Up @@ -302,27 +304,23 @@ private void testTf2(
methodSignatureToPointerKeys.getOrDefault(functionSignature, Collections.emptySet());

// check tensor parameters.
assertEquals(expectedNumberOfTensorParameters, functionPointerKeys.size());
assertEquals(
expectedNumberOfTensorParameters,
functionPointerKeys.stream().filter(LocalPointerKey::isParameter).count());

// check value numbers.
Set<Integer> actualValueNumberSet =
Set<Integer> actualParameterValueNumberSet =
functionPointerKeys.stream()
.filter(LocalPointerKey::isParameter)
.map(LocalPointerKey::getValueNumber)
.collect(Collectors.toSet());

assertEquals(expectedTensorParameterValueNumbers.length, actualValueNumberSet.size());
assertEquals(expectedTensorParameterValueNumbers.length, actualParameterValueNumberSet.size());
Arrays.stream(expectedTensorParameterValueNumbers)
.forEach(
ev ->
assertTrue(
"Expecting " + actualValueNumberSet + " to contain " + ev + ".",
actualValueNumberSet.contains(ev)));

// get the tensor variables for the function.
Set<TensorVariable> functionTensors =
methodSignatureToTensorVariables.getOrDefault(functionSignature, Collections.emptySet());

// check tensor parameters.
assertEquals(expectedNumberOfTensorParameters, functionTensors.size());
"Expecting " + actualParameterValueNumberSet + " to contain " + ev + ".",
actualParameterValueNumberSet.contains(ev)));
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.ibm.wala.cast.python.ml.client;

import static com.ibm.wala.cast.python.ml.types.TensorFlowTypes.DATASET;
import static com.ibm.wala.cast.types.AstMethodReference.fnReference;

import com.ibm.wala.cast.ir.ssa.EachElementGetInstruction;
Expand Down Expand Up @@ -123,7 +124,33 @@ private static Set<PointsToSetVariable> getDataflowSources(
int use = eachElementGetInstruction.getUse(0);
SSAInstruction def = du.getDef(use);

if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
if (def == null) {
logger.warning(
() ->
"Can't find potential tensor iterable definition for use: "
+ use
+ " of instruction: "
+ eachElementGetInstruction
+ ". Trying interprocedural analysis...");

// Look up the use in the pointer analysis to see if it points to a dataset.
PointerKey usePointerKey =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, use);

for (InstanceKey ik : pointerAnalysis.getPointsToSet(usePointerKey)) {
if (ik instanceof AllocationSiteInNode) {
AllocationSiteInNode asin = (AllocationSiteInNode) ik;
IClass concreteType = asin.getConcreteType();
TypeReference reference = concreteType.getReference();

if (reference.equals(DATASET)) {
sources.add(src);
logger.info("Added dataflow source from tensor dataset: " + src + ".");
break;
}
}
}
} else if (definesTensorIterable(def, localPointerKeyNode, callGraph, pointerAnalysis)) {
sources.add(src);
logger.info("Added dataflow source from tensor iterable: " + src + ".");
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package com.ibm.wala.cast.python.ml.types;

import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.types.TypeName;
import com.ibm.wala.types.TypeReference;

/**
* Types found in the TensorFlow library.
*
* @author <a href="mailto:[email protected]">Raffi Khatchadourian</a>
*/
public class TensorFlowTypes extends PythonTypes {

  /** Type name under which the TensorFlow dataset class is registered. */
  private static final TypeName DATASET_TYPE_NAME =
      TypeName.findOrCreate("Ltensorflow/data/Dataset");

  /**
   * Reference to the TensorFlow dataset type, created (or looked up) against the inherited {@code
   * pythonLoader}.
   */
  public static final TypeReference DATASET =
      TypeReference.findOrCreate(pythonLoader, DATASET_TYPE_NAME);

  /** Constant holder only; never instantiated. */
  private TensorFlowTypes() {}
}
14 changes: 14 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset6.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import tensorflow as tf


# Returns the sum of its two arguments. In this fixture it is only called from
# func(), which passes the same dataset element for both a and b.
def add(a, b):
    return a + b


# Iterates over the iterable received as a parameter and passes each element
# to add(). The iterable is created by the caller, not inside this function,
# so its definition is not visible from within func() itself.
def func(ds):
    for element in ds:
        # c is not read afterwards in this file; only the call to add() matters.
        c = add(element, element)


# Build a tf.data.Dataset from a Python list, then chain shuffle(3) and
# batch(2) on it.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2)
# Hand the dataset to func() so that it is iterated in a different function
# from the one that created it.
func(dataset)
15 changes: 15 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import tensorflow as tf


# Returns the sum of its two arguments. Unlike the undecorated variant of this
# fixture, add() is wrapped with @tf.function here. It is only called from
# func(), which passes the same dataset element for both a and b.
@tf.function
def add(a, b):
    return a + b


# Iterates over the iterable received as a parameter and passes each element
# to the @tf.function-decorated add(). The iterable is created by the caller,
# not inside this function.
def func(ds):
    for element in ds:
        # c is not read afterwards in this file; only the call to add() matters.
        c = add(element, element)


# Build a tf.data.Dataset from a Python list, then chain shuffle(3) and
# batch(2) on it.
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3]).shuffle(3).batch(2)
# Hand the dataset to func() so that it is iterated in a different function
# from the one that created it.
func(dataset)

0 comments on commit 7eed945

Please sign in to comment.