Skip to content

Commit

Permalink
Add seek_penultimate_layer argument to Gradcam#__call__() to fix issues/
Browse files Browse the repository at this point in the history
  • Loading branch information
keisen committed Apr 30, 2020
1 parent 8e3ff1c commit 46d3872
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
11 changes: 11 additions & 0 deletions tests/tf-keras-vis/test_gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,14 @@ def test__call__if_penultimate_layer_is_noexist_name(cnn_model):
assert False
except ValueError:
assert True


def test__call__if_model_has_only_dense_layer(dense_model):
gradcam = Gradcam(dense_model)
result = gradcam(SmoothedLoss(1), np.random.sample((1, 8, 8, 3)), seek_penultimate_layer=False)
assert result.shape == (1, 8, 8)
try:
gradcam(SmoothedLoss(1), np.random.sample((1, 8, 8, 3)))
assert False
except ValueError:
assert True
8 changes: 5 additions & 3 deletions tf_keras_vis/gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __call__(self,
loss,
seed_input,
penultimate_layer=-1,
seek_penultimate_layer=True,
activation_modifier=lambda cam: K.relu(cam),
normalize_gradient=True):
"""Generate a gradient based class activation map (CAM) by using positive gradient of
Expand Down Expand Up @@ -45,7 +46,8 @@ def __call__(self,
if len(seed_inputs) != len(self.model.inputs):
raise ValueError('')

penultimate_output_tensor = self._find_penultimate_output(self.model, penultimate_layer)
penultimate_output_tensor = self._find_penultimate_output(self.model, penultimate_layer,
seek_penultimate_layer)
model = tf.keras.Model(inputs=self.model.inputs,
outputs=self.model.outputs + [penultimate_output_tensor])
with tf.GradientTape() as tape:
Expand All @@ -70,7 +72,7 @@ def __call__(self,
cams = cams[0]
return cams

def _find_penultimate_output(self, model, layer):
def _find_penultimate_output(self, model, layer, seek):
if not isinstance(layer, tf.keras.layers.Layer):
if layer is None:
layer = -1
Expand All @@ -80,7 +82,7 @@ def _find_penultimate_output(self, model, layer):
layer = find_layer(model, lambda l: l.name == layer)
else:
raise ValueError('Invalid argument. `layer`=', layer)
if layer is not None:
if layer is not None and seek:
layer = find_layer(model, lambda l: isinstance(l, Conv), offset=layer)
if layer is None:
raise ValueError('Unable to determine penultimate `Conv` layer.')
Expand Down

0 comments on commit 46d3872

Please sign in to comment.