Skip to content

Commit

Permalink
Merge pull request #227 from carlini/remove_keras
Browse files Browse the repository at this point in the history
Remove keras from tutorials
  • Loading branch information
npapernot authored Aug 17, 2017
2 parents 65f6917 + fc2f4cc commit 6dfd0f9
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 242 deletions.
69 changes: 21 additions & 48 deletions tutorials/mnist_blackbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,47 +6,33 @@
import numpy as np
from six.moves import xrange

import keras
from keras import backend
from keras.utils.np_utils import to_categorical
from keras.models import Sequential
from keras.layers import Dense, Flatten, Activation, Dropout

import logging
import tensorflow as tf
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

from cleverhans.utils_keras import cnn_model
from cleverhans.utils_mnist import data_mnist
from cleverhans.utils import to_categorical
from cleverhans.utils import set_log_level
from cleverhans.utils_tf import model_train, model_eval, batch_eval
from cleverhans.attacks import FastGradientMethod
from cleverhans.attacks_tf import jacobian_graph, jacobian_augmentation
from cleverhans.utils_keras import KerasModelWrapper
from cleverhans.utils import set_log_level

from tutorials.tutorial_models import make_basic_cnn, MLP
from tutorials.tutorial_models import Flatten, Linear, ReLU, Softmax

FLAGS = flags.FLAGS


def setup_tutorial():
"""
Helper function to check correct configuration of tf and keras for tutorial
Helper function to check correct configuration of tf for tutorial
:return: True if setup checks completed
"""

# Set TF random seed to improve reproducibility
tf.set_random_seed(1234)

if not hasattr(backend, "tf"):
raise RuntimeError("This tutorial requires keras to be configured"
" to use the TensorFlow backend.")

# Image dimensions ordering should follow the Theano convention
if keras.backend.image_dim_ordering() != 'tf':
keras.backend.set_image_dim_ordering('tf')
print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' "
"to 'th', temporarily setting to 'tf'")

return True


Expand All @@ -69,7 +55,7 @@ def prep_bbox(sess, x, y, X_train, Y_train, X_test, Y_test,
"""

# Define TF model graph (for the black-box model)
model = cnn_model()
model = make_basic_cnn()
predictions = model(x)
print("Defined TensorFlow model graph.")

Expand All @@ -94,35 +80,25 @@ def prep_bbox(sess, x, y, X_train, Y_train, X_test, Y_test,

def substitute_model(img_rows=28, img_cols=28, nb_classes=10):
"""
Defines the model architecture to be used by the substitute.
Defines the model architecture to be used by the substitute. Use
the example model interface.
:param img_rows: number of rows in input
:param img_cols: number of columns in input
:param nb_classes: number of classes in output
:return: keras model
:return: tensorflow model
"""
model = Sequential()

# Find out the input shape ordering
if keras.backend.image_dim_ordering() == 'th':
input_shape = (1, img_rows, img_cols)
else:
input_shape = (img_rows, img_cols, 1)
input_shape = (None, img_rows, img_cols, 1)

# Define a fully connected model (it's different than the black-box)
layers = [Flatten(input_shape=input_shape),
Dense(200),
Activation('relu'),
Dropout(0.5),
Dense(200),
Activation('relu'),
Dropout(0.5),
Dense(nb_classes),
Activation('softmax')]

for layer in layers:
model.add(layer)
layers = [Flatten(),
Linear(200),
ReLU(),
Linear(200),
ReLU(),
Linear(nb_classes),
Softmax()]

return model
return MLP(layers, input_shape)


def train_sub(sess, x, y, bbox_preds, X_sub, Y_sub, nb_classes,
Expand Down Expand Up @@ -200,7 +176,6 @@ def mnist_blackbox(train_start=0, train_end=60000, test_start=0,
* black-box model accuracy on adversarial examples transferred
from the substitute model
"""
keras.layers.core.K.set_learning_phase(0)

# Set logging level to see debug information
set_log_level(logging.DEBUG)
Expand All @@ -211,9 +186,8 @@ def mnist_blackbox(train_start=0, train_end=60000, test_start=0,
# Perform tutorial setup
assert setup_tutorial()

# Create TF session and set as Keras backend session
# Create TF session
sess = tf.Session()
keras.backend.set_session(sess)

# Get MNIST data
X_train, Y_train, X_test, Y_test = data_mnist(train_start=train_start,
Expand Down Expand Up @@ -254,8 +228,7 @@ def mnist_blackbox(train_start=0, train_end=60000, test_start=0,

# Initialize the Fast Gradient Sign Method (FGSM) attack object.
fgsm_par = {'eps': 0.3, 'ord': np.inf, 'clip_min': 0., 'clip_max': 1.}
wrap = KerasModelWrapper(model_sub)
fgsm = FastGradientMethod(wrap, sess=sess)
fgsm = FastGradientMethod(model_sub, sess=sess)

# Craft adversarial examples using the substitute
eval_params = {'batch_size': batch_size}
Expand Down
25 changes: 7 additions & 18 deletions tutorials/mnist_tutorial_cw.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import print_function
from __future__ import unicode_literals

import keras
import numpy as np
from six.moves import xrange
import tensorflow as tf
Expand All @@ -13,10 +12,11 @@
import logging
import os
from cleverhans.attacks import CarliniWagnerL2
from cleverhans.utils import grid_visual, AccuracyReport, set_log_level
from cleverhans.utils_keras import cnn_model, KerasModelWrapper
from cleverhans.utils import pair_visual, grid_visual, AccuracyReport
from cleverhans.utils import set_log_level
from cleverhans.utils_mnist import data_mnist
from cleverhans.utils_tf import model_train, model_eval, tf_model_load
from tutorials.tutorial_models import make_basic_cnn

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -51,22 +51,12 @@ def mnist_tutorial_cw(train_start=0, train_end=60000, test_start=0,
img_cols = 28
channels = 1

# Disable Keras learning phase since we will be serving through tensorflow
keras.layers.core.K.set_learning_phase(0)

# Set TF random seed to improve reproducibility
tf.set_random_seed(1234)

# Image dimensions ordering should follow the TensorFlow convention
if keras.backend.image_dim_ordering() != 'tf':
keras.backend.set_image_dim_ordering('tf')
print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' "
"to 'th', temporarily setting to 'tf'")

# Create TF session and set as Keras backend session
# Create TF session
sess = tf.Session()
keras.backend.set_session(sess)
print("Created TensorFlow session and set Keras backend.")
print("Created TensorFlow session.")

set_log_level(logging.DEBUG)

Expand All @@ -81,7 +71,7 @@ def mnist_tutorial_cw(train_start=0, train_end=60000, test_start=0,
y = tf.placeholder(tf.float32, shape=(None, 10))

# Define TF model graph
model = cnn_model()
model = make_basic_cnn()
preds = model(x)
print("Defined TensorFlow model graph.")

Expand Down Expand Up @@ -120,8 +110,7 @@ def mnist_tutorial_cw(train_start=0, train_end=60000, test_start=0,
print("This could take some time ...")

# Instantiate a CW attack object
wrap = KerasModelWrapper(model)
cw = CarliniWagnerL2(wrap, back='tf', sess=sess)
cw = CarliniWagnerL2(model, back='tf', sess=sess)

idxs = [np.where(np.argmax(Y_test, axis=1) == i)[0][0] for i in range(10)]
if targeted:
Expand Down
19 changes: 4 additions & 15 deletions tutorials/mnist_tutorial_jsma.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import print_function
from __future__ import unicode_literals

import keras
import numpy as np
from six.moves import xrange
import tensorflow as tf
Expand All @@ -17,6 +16,7 @@
from cleverhans.utils_mnist import data_mnist
from cleverhans.utils_tf import model_train, model_eval, model_argmax
from cleverhans.utils_keras import KerasModelWrapper, cnn_model
from tutorials.tutorial_models import make_basic_cnn

FLAGS = flags.FLAGS

Expand Down Expand Up @@ -47,22 +47,12 @@ def mnist_tutorial_jsma(train_start=0, train_end=60000, test_start=0,
img_cols = 28
channels = 1

# Disable Keras learning phase since we will be serving through tensorflow
keras.layers.core.K.set_learning_phase(0)

# Set TF random seed to improve reproducibility
tf.set_random_seed(1234)

# Image dimensions ordering should follow the TensorFlow convention
if keras.backend.image_dim_ordering() != 'tf':
keras.backend.set_image_dim_ordering('tf')
print("INFO: '~/.keras/keras.json' sets 'image_dim_ordering' "
"to 'th', temporarily setting to 'tf'")

# Create TF session and set as Keras backend session
sess = tf.Session()
keras.backend.set_session(sess)
print("Created TensorFlow session and set Keras backend.")
print("Created TensorFlow session.")

set_log_level(logging.DEBUG)

Expand All @@ -77,7 +67,7 @@ def mnist_tutorial_jsma(train_start=0, train_end=60000, test_start=0,
y = tf.placeholder(tf.float32, shape=(None, 10))

# Define TF model graph
model = cnn_model()
model = make_basic_cnn()
preds = model(x)
print("Defined TensorFlow model graph.")

Expand Down Expand Up @@ -118,8 +108,7 @@ def mnist_tutorial_jsma(train_start=0, train_end=60000, test_start=0,
grid_viz_data = np.zeros(grid_shape, dtype='f')

# Instantiate a SaliencyMapMethod attack object
wrap = KerasModelWrapper(model)
jsma = SaliencyMapMethod(wrap, back='tf', sess=sess)
jsma = SaliencyMapMethod(model, back='tf', sess=sess)
jsma_params = {'theta': 1., 'gamma': 0.1,
'clip_min': 0., 'clip_max': 1.,
'y_target': None}
Expand Down
Loading

0 comments on commit 6dfd0f9

Please sign in to comment.