diff --git a/tests/tf-keras-vis/test_gradcam.py b/tests/tf-keras-vis/test_gradcam.py index f73d721..c795533 100644 --- a/tests/tf-keras-vis/test_gradcam.py +++ b/tests/tf-keras-vis/test_gradcam.py @@ -75,3 +75,14 @@ def test__call__if_penultimate_layer_is_noexist_name(cnn_model): assert False except ValueError: assert True + + +def test__call__if_model_has_only_dense_layer(dense_model): + gradcam = Gradcam(dense_model) + result = gradcam(SmoothedLoss(1), np.random.sample((1, 8, 8, 3)), seek_penultimate_layer=False) + assert result.shape == (1, 8, 8) + try: + gradcam(SmoothedLoss(1), np.random.sample((1, 8, 8, 3))) + assert False + except ValueError: + assert True diff --git a/tf_keras_vis/gradcam.py b/tf_keras_vis/gradcam.py index 968e37b..f9c1576 100644 --- a/tf_keras_vis/gradcam.py +++ b/tf_keras_vis/gradcam.py @@ -13,6 +13,7 @@ def __call__(self, loss, seed_input, penultimate_layer=-1, + seek_penultimate_layer=True, activation_modifier=lambda cam: K.relu(cam), normalize_gradient=True): """Generate a gradient based class activation map (CAM) by using positive gradient of @@ -45,7 +46,8 @@ def __call__(self, if len(seed_inputs) != len(self.model.inputs): raise ValueError('') - penultimate_output_tensor = self._find_penultimate_output(self.model, penultimate_layer) + penultimate_output_tensor = self._find_penultimate_output(self.model, penultimate_layer, + seek_penultimate_layer) model = tf.keras.Model(inputs=self.model.inputs, outputs=self.model.outputs + [penultimate_output_tensor]) with tf.GradientTape() as tape: @@ -70,7 +72,7 @@ def __call__(self, cams = cams[0] return cams - def _find_penultimate_output(self, model, layer): + def _find_penultimate_output(self, model, layer, seek): if not isinstance(layer, tf.keras.layers.Layer): if layer is None: layer = -1 @@ -80,7 +82,7 @@ def _find_penultimate_output(self, model, layer): layer = find_layer(model, lambda l: l.name == layer) else: raise ValueError('Invalid argument. `layer`=', layer) - if layer is not None: + if layer is not None and seek: layer = find_layer(model, lambda l: isinstance(l, Conv), offset=layer) if layer is None: raise ValueError('Unable to determine penultimate `Conv` layer.')