Commit
Merge pull request #17 from keisen/features/improve-inefficient-processing-to-vectorization

Improve inefficient processing by vectorization
keisen authored Jun 22, 2020
2 parents 1607a84 + d42a75f commit 68ef380
Showing 12 changed files with 130 additions and 63 deletions.
12 changes: 12 additions & 0 deletions .github/workflows/python-package.yml
@@ -9,6 +9,9 @@ on:
  pull_request:
    branches: [ master ]

env:
  TF_KERAS_VIS_MAX_STEPS: 2

jobs:
  build:

@@ -37,3 +40,12 @@ jobs:
    - name: Test with pytest
      run: |
        PYTHONPATH=$PWD:$PYTHONPATH py.test
    - name: Test attentions.ipynb
      run: |
        jupyter-nbconvert --ExecutePreprocessor.timeout=600 --to notebook --execute examples/attentions.ipynb
    - name: Test visualize_dense_layer.ipynb
      run: |
        jupyter-nbconvert --ExecutePreprocessor.timeout=600 --to notebook --execute examples/visualize_dense_layer.ipynb
    - name: Test visualize_conv_filters.ipynb
      run: |
        jupyter-nbconvert --ExecutePreprocessor.timeout=600 --to notebook --execute examples/visualize_conv_filters.ipynb
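
The new `TF_KERAS_VIS_MAX_STEPS` variable presumably caps iterative work (SmoothGrad samples, filter-visualization steps, and so on) so the notebook jobs finish quickly in CI; the `check_steps()` call in the saliency diff below looks like its consumer. A minimal sketch of how such a cap could be read; the function body and default behaviour here are assumptions, not the library's actual implementation:

```python
import os

def check_steps(requested_steps, env_var="TF_KERAS_VIS_MAX_STEPS"):
    """Clamp an iteration count to an optional environment-variable limit.

    Hypothetical sketch; the real check_steps() in tf-keras-vis may differ.
    """
    max_steps = os.environ.get(env_var)
    if max_steps is None:
        return requested_steps
    return min(requested_steps, int(max_steps))

os.environ["TF_KERAS_VIS_MAX_STEPS"] = "2"  # what the workflow above sets in CI
print(check_steps(20))                      # -> 2: 20 requested samples are capped at 2
```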
2 changes: 1 addition & 1 deletion .github/workflows/python-publish.yml
@@ -5,7 +5,7 @@ name: Upload Python Package

on:
  release:
    types: [created]
    types: [published]

jobs:
  deploy:
25 changes: 0 additions & 25 deletions .travis.yml

This file was deleted.

23 changes: 14 additions & 9 deletions README.md
@@ -6,26 +6,32 @@

tf-keras-vis is a visualization toolkit for debugging `tf.keras` models in TensorFlow 2.0+.

These features are based on those of [keras-vis](https://github.com/raghakot/keras-vis), but the tf-keras-vis APIs are not compatible with keras-vis, because we prioritized the following features for our experiments.

- Support for processing multiple images at a time as a batch
- Support for tf.keras.Model with multiple inputs (and, of course, multiple outputs too)
- Allow use of the optimizers embedded in tf.keras
- Faster processing through optimized calculation


## Visualizations

### Visualize Dense Layer

<img src='examples/images/visualize-dense-layer.png' width='600px' />
<img src='https://github.com/keisen/tf-keras-vis/raw/master/examples/images/visualize-dense-layer.png' width='600px' />

### Visualize Convolutional Filter

<img src='examples/images/visualize-filters.png' width='800px' />
<img src='https://github.com/keisen/tf-keras-vis/raw/master/examples/images/visualize-filters.png' width='800px' />

### Saliency Map and GradCAM
### GradCAM

<img src='examples/images/gradcam.png' width='600px' />
<img src='https://github.com/keisen/tf-keras-vis/raw/master/examples/images/gradcam.png' width='600px' />

### Saliency Map (SmoothGrad)

These features are based on those of [keras-vis](https://github.com/raghakot/keras-vis), but the tf-keras-vis APIs are not compatible with keras-vis, because we prioritized the following features for our experiments.
<img src='https://github.com/keisen/tf-keras-vis/raw/master/examples/images/smoothgrad.png' width='600px' />

- Support for processing multiple images at a time as a batch
- Support for tf.keras.Model with multiple inputs (and, of course, multiple outputs too)
- Allow use of the optimizers embedded in tf.keras


## Requirements
@@ -77,4 +83,3 @@ Please see [examples/attentions.ipynb](https://github.com/keisen/tf-keras-vis/bl
- [ScoreCAM](https://arxiv.org/pdf/1910.01279.pdf)
- Deep Dream
- Style transfer
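
As a quick orientation for readers of this diff, the new tests later in the commit exercise the public API directly; below is a minimal Grad-CAM usage sketch assembled from those tests (the toy model and random input are placeholders, not code from the repository):

```python
import numpy as np
from tensorflow.keras.layers import Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential

from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.utils.losses import SmoothedLoss

# Toy model standing in for a real classifier.
model = Sequential([
    Conv2D(4, 3, activation='relu', input_shape=(8, 8, 3)),
    Flatten(),
    Dense(2, activation='softmax'),
])

gradcam = Gradcam(model)
# SmoothedLoss(1) mirrors the tests in this PR; the seed input is a batch of images.
cam = gradcam(SmoothedLoss(1), np.random.sample((1, 8, 8, 3)))
print(cam.shape)  # (1, 8, 8): expanded to the input's spatial size, per the new tests' assertions
```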

55 changes: 46 additions & 9 deletions examples/attentions.ipynb

Large diffs are not rendered by default.

Binary file modified examples/images/gradcam.png
Binary file added examples/images/smoothgrad.png
Binary file added examples/images/vanilla-saliency.png
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
    name="tf-keras-vis",
    version="0.3.2",
    version="0.3.3",
    author="keisen",
    author_email="[email protected]",
    description="Neural network visualization toolkit for tf.keras",
41 changes: 39 additions & 2 deletions tests/tf-keras-vis/test_gradcam.py
@@ -1,7 +1,8 @@
import numpy as np
import pytest
from tensorflow.keras.layers import Conv2D, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Conv2D, Input, Dense, Flatten
from tensorflow.keras.models import Sequential, Model

from tf_keras_vis.gradcam import Gradcam
from tf_keras_vis.utils.losses import SmoothedLoss
@@ -23,6 +24,17 @@ def cnn_model():
    ])


@pytest.fixture(scope="function", autouse=True)
def multiple_inputs_cnn_model():
    input_a = Input((8, 8, 3))
    input_b = Input((10, 10, 3))
    x_a = Conv2D(2, 5, activation='relu')(input_a)
    x_b = Conv2D(2, 5, activation='relu')(input_b)
    x = K.concatenate([Flatten()(x_a), Flatten()(x_b)], axis=-1)
    x = Dense(2, activation='softmax')(x)
    return Model(inputs=[input_a, input_b], outputs=x)


def test__call__if_loss_is_None(cnn_model):
    gradcam = Gradcam(cnn_model)
    try:
@@ -88,3 +100,28 @@ def test__call__if_model_has_only_dense_layer(dense_model):
        assert False
    except ValueError:
        assert True


def test__call__if_model_has_multiple_inputs(multiple_inputs_cnn_model):
    gradcam = Gradcam(multiple_inputs_cnn_model)
    result = gradcam(
        SmoothedLoss(1), [np.random.sample(
            (1, 8, 8, 3)), np.random.sample((1, 10, 10, 3))])
    assert len(result) == 2
    assert result[0].shape == (1, 8, 8)
    assert result[1].shape == (1, 10, 10)


def test__call__if_expand_cam_is_False(cnn_model):
    gradcam = Gradcam(cnn_model)
    result = gradcam(SmoothedLoss(1), np.random.sample((1, 8, 8, 3)), expand_cam=False)
    assert result.shape == (1, 6, 6)


def test__call__if_expand_cam_is_False_and_model_has_multiple_inputs(multiple_inputs_cnn_model):
    gradcam = Gradcam(multiple_inputs_cnn_model)
    result = gradcam(
        SmoothedLoss(1), [np.random.sample(
            (1, 8, 8, 3)), np.random.sample((1, 10, 10, 3))],
        expand_cam=False)
    assert result.shape == (1, 6, 6)
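
The shape assertions imply that `expand_cam=False` returns the raw class-activation map at the penultimate feature map's resolution (6×6 for this test model), while the default behaviour expands it back to each input's spatial size (8×8 and 10×10 above). A sketch of that expansion under the assumption that simple interpolation is used; `scipy.ndimage.zoom` here stands in for whatever upsampling the library actually applies:

```python
import numpy as np
from scipy.ndimage import zoom

cam = np.random.sample((1, 6, 6))   # raw CAM: (batch, feature-map height, feature-map width)
input_shape = (1, 8, 8)             # batch plus the model input's spatial size

# Zoom each axis by the ratio of target size to CAM size (batch axis ratio is 1).
factors = tuple(t / s for t, s in zip(input_shape, cam.shape))
expanded = zoom(cam, factors, order=1)
print(expanded.shape)  # (1, 8, 8)
```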
4 changes: 2 additions & 2 deletions tf_keras_vis/gradcam.py
@@ -61,8 +61,8 @@ def __call__(self,
        grads = tape.gradient(loss_values, penultimate_output)
        if normalize_gradient:
            grads = K.l2_normalize(grads)
        weights = K.mean(grads, axis=tuple(np.arange(len(grads.shape))[1:-1]))
        cam = np.asarray([np.sum(o * w, axis=-1) for o, w in zip(penultimate_output, weights)])
        weights = K.mean(grads, axis=tuple(range(grads.ndim)[1:-1]), keepdims=True)
        cam = np.sum(penultimate_output * weights, axis=-1)
        if activation_modifier is not None:
            cam = activation_modifier(cam)

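The core of this pull request is visible in the two replaced lines above: the per-sample Python loop over `zip(penultimate_output, weights)` becomes a single broadcasted multiply-and-sum by keeping the reduced axes in the weight tensor. A standalone NumPy check (with made-up shapes) that the two forms agree:

```python
import numpy as np

# Fake penultimate conv output and gradients: (batch, height, width, channels).
penultimate_output = np.random.sample((2, 6, 6, 4))
grads = np.random.sample((2, 6, 6, 4))

# Old form: per-sample channel weights, then a Python loop over the batch.
weights_old = np.mean(grads, axis=(1, 2))                    # (2, 4)
cam_old = np.asarray([np.sum(o * w, axis=-1)
                      for o, w in zip(penultimate_output, weights_old)])

# New form: keepdims=True lets broadcasting handle the whole batch at once.
weights_new = np.mean(grads, axis=(1, 2), keepdims=True)     # (2, 1, 1, 4)
cam_new = np.sum(penultimate_output * weights_new, axis=-1)  # (2, 6, 6)

assert np.allclose(cam_old, cam_new)
```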
29 changes: 15 additions & 14 deletions tf_keras_vis/saliency.py
@@ -41,20 +41,21 @@ def __call__(self,
        seed_inputs = self._get_seed_inputs_for_multiple_inputs(seed_input)
        # Processing saliency
        if smooth_samples > 0:
            axes = [tuple(range(1, len(X.shape))) for X in seed_inputs]
            sigmas = [
                smooth_noise * (np.max(X, axis=axis) - np.min(X, axis=axis))
                for X, axis in zip(seed_inputs, axes)
            ]
            total_gradients = (np.zeros_like(X) for X in seed_inputs)
            for i in range(check_steps(smooth_samples)):
                seed_inputs_plus_noise = [
                    tf.constant(
                        np.concatenate([
                            x + np.random.normal(0., s, (1, ) + x.shape) for x, s in zip(X, sigma)
                        ])) for X, sigma in zip(seed_inputs, sigmas)
                ]
                gradients = self._get_gradients(seed_inputs_plus_noise, losses, gradient_modifier)
            smooth_samples = check_steps(smooth_samples)
            seed_inputs = (tf.tile(X, (smooth_samples, ) + (1, ) * (X.ndim - 1))
                           for X in seed_inputs)
            seed_inputs = (tf.reshape(X, (smooth_samples, -1) + tuple(X.shape[1:]))
                           for X in seed_inputs)
            seed_inputs = ((X, tuple(range(X.ndim)[1:])) for X in seed_inputs)
            seed_inputs = ((X, smooth_noise * (tf.math.reduce_max(X, axis=axis, keepdims=True) -
                                               tf.math.reduce_min(X, axis=axis, keepdims=True)))
                           for X, axis in seed_inputs)
            seed_inputs = (X + np.random.normal(0., sigma, X.shape) for X, sigma in seed_inputs)
            seed_inputs = list(seed_inputs)
            total_gradients = (np.zeros_like(X[0]) for X in seed_inputs)
            for i in range(smooth_samples):
                sample = [X[i] for X in seed_inputs]
                gradients = self._get_gradients(sample, losses, gradient_modifier)
                total_gradients = (total + g for total, g in zip(total_gradients, gradients))
            grads = [g / smooth_samples for g in total_gradients]
        else:
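In the SmoothGrad branch, the old code rebuilt a noisy batch with nested NumPy list comprehensions on every iteration; the new code tiles each input `smooth_samples` times up front, adds Gaussian noise scaled by `smooth_noise * (max - min)` once, and then only indexes into the precomputed noisy copies inside the loop. A simplified single-input sketch of that idea (not the library's exact code):

```python
import numpy as np

def make_noisy_samples(x, smooth_samples=8, smooth_noise=0.2, seed=0):
    """Tile a batch `smooth_samples` times and add scaled Gaussian noise once.

    x has shape (batch, height, width, channels); the result has shape
    (smooth_samples, batch, height, width, channels).
    """
    rng = np.random.default_rng(seed)
    tiled = np.tile(x, (smooth_samples,) + (1,) * (x.ndim - 1))   # (S*B, H, W, C)
    tiled = tiled.reshape((smooth_samples, -1) + x.shape[1:])     # (S, B, H, W, C)
    axes = tuple(range(1, tiled.ndim))
    sigma = smooth_noise * (tiled.max(axis=axes, keepdims=True) -
                            tiled.min(axis=axes, keepdims=True))
    return tiled + rng.normal(0.0, 1.0, tiled.shape) * sigma

noisy = make_noisy_samples(np.random.sample((2, 8, 8, 3)))
print(noisy.shape)  # (8, 2, 8, 8, 3): one noisy copy of the batch per SmoothGrad sample
```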
