From 9574624fa6ec7cdcb3917138baf613eb6fc68efd Mon Sep 17 00:00:00 2001 From: Raffi Khatchadourian Date: Fri, 19 Jul 2024 10:50:44 -0400 Subject: [PATCH] Deal with multiple possible callables (#121) * Return `null` when there are multiple possible callables. * Add test to exercise call string imprecision. Based on the call string length. See https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680. * Expect the test to fail. In the past, we could add 0's to the parameters, but since we are not enforcing the existence of the node in the CG, we can no longer do that. Still, this test should now fail if https://github.com/wala/ML/issues/207 is fixed. --- .../python/ml/test/TestTensorflow2Model.java | 35 +++++++++++++++ com.ibm.wala.cast.python.test/.pydevproject | 1 + .../data/proj66/src/__init__.py | 1 + .../data/proj66/src/tf2_test_model_call5b.py | 9 ++++ .../data/proj66/tf2_test_model_call5.py | 44 +++++++++++++++++++ .../data/proj66/tf2_test_model_call5a.py | 44 +++++++++++++++++++ ...nstanceMethodTrampolineTargetSelector.java | 24 +++++++++- 7 files changed, 157 insertions(+), 1 deletion(-) create mode 100644 com.ibm.wala.cast.python.test/data/proj66/src/__init__.py create mode 100644 com.ibm.wala.cast.python.test/data/proj66/src/tf2_test_model_call5b.py create mode 100644 com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5.py create mode 100644 com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5a.py diff --git a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java index 086a80fb0..6f62c7862 100644 --- a/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java +++ b/com.ibm.wala.cast.python.ml.test/source/com/ibm/wala/cast/python/ml/test/TestTensorflow2Model.java @@ -1181,6 +1181,41 @@ public void testModelCall4() 
test("tf2_test_model_call4.py", "SequentialModel.__call__", 1, 1, 3); } + /** + * Test call string imprecision as described in + * https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680. This should fail due + * to https://github.com/wala/ML/issues/207. + */ + @Test(expected = java.lang.AssertionError.class) + public void testModelCall5() + throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { + test( + new String[] { + "proj66/src/tf2_test_model_call5b.py", + "proj66/tf2_test_model_call5.py", + "proj66/tf2_test_model_call5a.py" + }, + "tf2_test_model_call5.py", + "SequentialModel.__call__", + "proj66", + 1, + 1, + 3); + + test( + new String[] { + "proj66/src/tf2_test_model_call5b.py", + "proj66/tf2_test_model_call5.py", + "proj66/tf2_test_model_call5a.py" + }, + "tf2_test_model_call5a.py", + "SequentialModel.__call__", + "proj66", + 1, + 1, + 3); + } + @Test public void testModelAttributes() throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException { diff --git a/com.ibm.wala.cast.python.test/.pydevproject b/com.ibm.wala.cast.python.test/.pydevproject index 47c85339d..eeda77dea 100644 --- a/com.ibm.wala.cast.python.test/.pydevproject +++ b/com.ibm.wala.cast.python.test/.pydevproject @@ -20,5 +20,6 @@ /${PROJECT_DIR_NAME}/data/proj52 /${PROJECT_DIR_NAME}/data/proj55 /${PROJECT_DIR_NAME}/data/proj56 + /${PROJECT_DIR_NAME}/data/proj66 diff --git a/com.ibm.wala.cast.python.test/data/proj66/src/__init__.py b/com.ibm.wala.cast.python.test/data/proj66/src/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/proj66/src/__init__.py @@ -0,0 +1 @@ + diff --git a/com.ibm.wala.cast.python.test/data/proj66/src/tf2_test_model_call5b.py b/com.ibm.wala.cast.python.test/data/proj66/src/tf2_test_model_call5b.py new file mode 100644 index 000000000..59bea179a --- /dev/null +++ 
b/com.ibm.wala.cast.python.test/data/proj66/src/tf2_test_model_call5b.py @@ -0,0 +1,9 @@ +# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680. + + +def f(m, d): + return m.predict(d) + + +def g(m, d): + return f(m, d) diff --git a/com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5.py b/com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5.py new file mode 100644 index 000000000..64274e909 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5.py @@ -0,0 +1,44 @@ +# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680. + +import tensorflow as tf +from src.tf2_test_model_call5b import g + +# Create an override model to classify pictures + + +class SequentialModel(tf.keras.Model): + + def __init__(self, **kwargs): + super(SequentialModel, self).__init__(**kwargs) + + self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28)) + + # Add a lot of small layers + num_layers = 100 + self.my_layers = [ + tf.keras.layers.Dense(64, activation="relu") for n in range(num_layers) + ] + + self.dropout = tf.keras.layers.Dropout(0.2) + self.dense_2 = tf.keras.layers.Dense(10) + + def __call__(self, x): + print("Raffi 1") + x = self.flatten(x) + + for layer in self.my_layers: + x = layer(x) + + x = self.dropout(x) + x = self.dense_2(x) + + return x + + def predict(self, x): + return self(x) + + +input_data = tf.random.uniform([20, 28, 28]) + +model = SequentialModel() +result = g(model, input_data) diff --git a/com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5a.py b/com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5a.py new file mode 100644 index 000000000..842edd5b8 --- /dev/null +++ b/com.ibm.wala.cast.python.test/data/proj66/tf2_test_model_call5a.py @@ -0,0 +1,44 @@ +# Test https://github.com/wala/WALA/discussions/1417#discussioncomment-10085680. 
+ +import tensorflow as tf +from src.tf2_test_model_call5b import g + +# Create an override model to classify pictures + + +class SequentialModel(tf.keras.Model): + + def __init__(self, **kwargs): + super(SequentialModel, self).__init__(**kwargs) + + self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28)) + + # Add a lot of small layers + num_layers = 100 + self.my_layers = [ + tf.keras.layers.Dense(64, activation="relu") for n in range(num_layers) + ] + + self.dropout = tf.keras.layers.Dropout(0.2) + self.dense_2 = tf.keras.layers.Dense(10) + + def __call__(self, x): + print("Raffi 2") + x = self.flatten(x) + + for layer in self.my_layers: + x = layer(x) + + x = self.dropout(x) + x = self.dense_2(x) + + return x + + def predict(self, x): + return self(x) + + +input_data = tf.random.uniform([20, 28, 28]) + +model = SequentialModel() +result = g(model, input_data) diff --git a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonInstanceMethodTrampolineTargetSelector.java b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonInstanceMethodTrampolineTargetSelector.java index 5b9788705..41ed1ee38 100644 --- a/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonInstanceMethodTrampolineTargetSelector.java +++ b/com.ibm.wala.cast.python/source/com/ibm/wala/cast/python/ipa/callgraph/PythonInstanceMethodTrampolineTargetSelector.java @@ -42,6 +42,7 @@ import com.ibm.wala.util.collections.HashMapFactory; import com.ibm.wala.util.collections.Pair; import com.ibm.wala.util.intset.OrdinalSet; +import java.util.HashMap; import java.util.Map; import java.util.logging.Logger; @@ -222,6 +223,8 @@ private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstr PointerKey receiver = pkf.getPointerKeyForLocal(caller, call.getUse(0)); OrdinalSet objs = builder.getPointerAnalysis().getPointsToSet(receiver); + Map instanceToCallable = new HashMap<>(); + for (InstanceKey o : objs) { 
AllocationSiteInNode instanceKey = getAllocationSiteInNode(o); if (instanceKey != null) { @@ -253,10 +256,29 @@ private IClass getCallable(CGNode caller, IClassHierarchy cha, PythonInvokeInstr LOGGER.info("Applying callable workaround for https://github.com/wala/ML/issues/118."); } - if (callable != null) return callable; + if (callable != null) { + if (instanceToCallable.containsKey(instanceKey)) + throw new IllegalStateException("Exisitng mapping found for: " + instanceKey); + + IClass previousValue = instanceToCallable.put(instanceKey, callable); + assert previousValue == null : "Not expecting a previous mapping."; + } } } + // if there's only one possible option. + if (instanceToCallable.values().size() == 1) { + IClass callable = instanceToCallable.values().iterator().next(); + assert callable != null : "Callable should be non-null."; + return callable; + } + + // if we have multiple candidates. + if (instanceToCallable.values().size() > 1) + // we cannot accurately select one. + LOGGER.warning( + "Multiple (" + instanceToCallable.values().size() + ") callable targets found."); + return null; }