-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
332 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,320 @@ | ||
# -*- coding: utf-8 -*- | ||
'''Inception V3 model for Keras. | ||
Note that the ImageNet weights provided are from a model that had not fully converged. | ||
Inception v3 should be able to reach 6.9% top-5 error, but our model | ||
only gets to 7.8% (same as a fully-converged ResNet 50). | ||
For comparison, VGG16 only gets to 9.9%, quite a bit worse. | ||
Also, do note that the input image format for this model is different than for | ||
other models (299x299 instead of 224x224), and that the input preprocessing function | ||
is also different. | ||
# Reference: | ||
- [Rethinking the Inception Architecture for Computer Vision](http://arxiv.org/abs/1512.00567) | ||
''' | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
import warnings | ||
|
||
from keras.models import Model | ||
from keras.layers import Flatten, Dense, Input, BatchNormalization, merge | ||
from keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D | ||
from keras.preprocessing import image | ||
from keras.utils.layer_utils import convert_all_kernels_in_model | ||
from keras.utils.data_utils import get_file | ||
from keras import backend as K | ||
from imagenet_utils import decode_predictions | ||
|
||
|
||
TH_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_th_dim_ordering_th_kernels.h5' | ||
TF_WEIGHTS_PATH = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_tf_dim_ordering_tf_kernels.h5' | ||
TH_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_th_dim_ordering_th_kernels_notop.h5' | ||
TF_WEIGHTS_PATH_NO_TOP = 'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5' | ||
|
||
|
||
def conv2d_bn(x, nb_filter, nb_row, nb_col, | ||
border_mode='same', subsample=(1, 1), | ||
name=None): | ||
'''Utility function to apply conv + BN. | ||
''' | ||
if name is not None: | ||
bn_name = name + '_bn' | ||
conv_name = name + '_conv' | ||
else: | ||
bn_name = None | ||
conv_name = None | ||
if K.image_dim_ordering() == 'th': | ||
bn_axis = 1 | ||
else: | ||
bn_axis = 3 | ||
x = Convolution2D(nb_filter, nb_row, nb_col, | ||
subsample=subsample, | ||
activation='relu', | ||
border_mode=border_mode, | ||
name=conv_name)(x) | ||
x = BatchNormalization(axis=bn_axis, name=bn_name)(x) | ||
return x | ||
|
||
|
||
def InceptionV3(include_top=True, weights='imagenet', | ||
input_tensor=None): | ||
'''Instantiate the Inception v3 architecture, | ||
optionally loading weights pre-trained | ||
on ImageNet. Note that when using TensorFlow, | ||
for best performance you should set | ||
`image_dim_ordering="tf"` in your Keras config | ||
at ~/.keras/keras.json. | ||
The model and the weights are compatible with both | ||
TensorFlow and Theano. The dimension ordering | ||
convention used by the model is the one | ||
specified in your Keras config file. | ||
Note that the default input image size for this model is 299x299. | ||
# Arguments | ||
include_top: whether to include the 3 fully-connected | ||
layers at the top of the network. | ||
weights: one of `None` (random initialization) | ||
or "imagenet" (pre-training on ImageNet). | ||
input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) | ||
to use as image input for the model. | ||
# Returns | ||
A Keras model instance. | ||
''' | ||
if weights not in {'imagenet', None}: | ||
raise ValueError('The `weights` argument should be either ' | ||
'`None` (random initialization) or `imagenet` ' | ||
'(pre-training on ImageNet).') | ||
# Determine proper input shape | ||
if K.image_dim_ordering() == 'th': | ||
if include_top: | ||
input_shape = (3, 299, 299) | ||
else: | ||
input_shape = (3, None, None) | ||
else: | ||
if include_top: | ||
input_shape = (299, 299, 3) | ||
else: | ||
input_shape = (None, None, 3) | ||
|
||
if input_tensor is None: | ||
img_input = Input(shape=input_shape) | ||
else: | ||
if not K.is_keras_tensor(input_tensor): | ||
img_input = Input(tensor=input_tensor) | ||
else: | ||
img_input = input_tensor | ||
|
||
if K.image_dim_ordering() == 'th': | ||
channel_axis = 1 | ||
else: | ||
channel_axis = 3 | ||
|
||
x = conv2d_bn(img_input, 32, 3, 3, subsample=(2, 2), border_mode='valid') | ||
x = conv2d_bn(x, 32, 3, 3, border_mode='valid') | ||
x = conv2d_bn(x, 64, 3, 3) | ||
x = MaxPooling2D((3, 3), strides=(2, 2))(x) | ||
|
||
x = conv2d_bn(x, 80, 1, 1, border_mode='valid') | ||
x = conv2d_bn(x, 192, 3, 3, border_mode='valid') | ||
x = MaxPooling2D((3, 3), strides=(2, 2))(x) | ||
|
||
# mixed: 35 x 35 x 256 | ||
for i in range(3): | ||
branch1x1 = conv2d_bn(x, 64, 1, 1) | ||
|
||
branch5x5 = conv2d_bn(x, 48, 1, 1) | ||
branch5x5 = conv2d_bn(branch5x5, 64, 5, 5) | ||
|
||
branch3x3dbl = conv2d_bn(x, 64, 1, 1) | ||
branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) | ||
branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) | ||
|
||
branch_pool = AveragePooling2D( | ||
(3, 3), strides=(1, 1), border_mode='same')(x) | ||
branch_pool = conv2d_bn(branch_pool, 32, 1, 1) | ||
x = merge([branch1x1, branch5x5, branch3x3dbl, branch_pool], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
# mixed3: 17 x 17 x 768 | ||
branch3x3 = conv2d_bn(x, 384, 3, 3, subsample=(2, 2), border_mode='valid') | ||
|
||
branch3x3dbl = conv2d_bn(x, 64, 1, 1) | ||
branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3) | ||
branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3, | ||
subsample=(2, 2), border_mode='valid') | ||
|
||
branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x) | ||
x = merge([branch3x3, branch3x3dbl, branch_pool], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
# mixed4: 17 x 17 x 768 | ||
branch1x1 = conv2d_bn(x, 192, 1, 1) | ||
|
||
branch7x7 = conv2d_bn(x, 128, 1, 1) | ||
branch7x7 = conv2d_bn(branch7x7, 128, 1, 7) | ||
branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) | ||
|
||
branch7x7dbl = conv2d_bn(x, 128, 1, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) | ||
|
||
branch_pool = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(x) | ||
branch_pool = conv2d_bn(branch_pool, 192, 1, 1) | ||
x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
# mixed5: 17 x 17 x 768 | ||
for _ in range(2): | ||
branch1x1 = conv2d_bn(x, 192, 1, 1) | ||
|
||
branch7x7 = conv2d_bn(x, 160, 1, 1) | ||
branch7x7 = conv2d_bn(branch7x7, 160, 1, 7) | ||
branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) | ||
|
||
branch7x7dbl = conv2d_bn(x, 160, 1, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) | ||
|
||
branch_pool = AveragePooling2D( | ||
(3, 3), strides=(1, 1), border_mode='same')(x) | ||
branch_pool = conv2d_bn(branch_pool, 192, 1, 1) | ||
x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
# mixed7: 17 x 17 x 768 | ||
branch1x1 = conv2d_bn(x, 192, 1, 1) | ||
|
||
branch7x7 = conv2d_bn(x, 192, 1, 1) | ||
branch7x7 = conv2d_bn(branch7x7, 192, 1, 7) | ||
branch7x7 = conv2d_bn(branch7x7, 192, 7, 1) | ||
|
||
branch7x7dbl = conv2d_bn(x, 160, 1, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1) | ||
branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7) | ||
|
||
branch_pool = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(x) | ||
branch_pool = conv2d_bn(branch_pool, 192, 1, 1) | ||
x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
# mixed8: 8 x 8 x 1280 | ||
branch3x3 = conv2d_bn(x, 192, 1, 1) | ||
branch3x3 = conv2d_bn(branch3x3, 320, 3, 3, | ||
subsample=(2, 2), border_mode='valid') | ||
|
||
branch7x7x3 = conv2d_bn(x, 192, 1, 1) | ||
branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7) | ||
branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1) | ||
branch7x7x3 = conv2d_bn(branch7x7x3, 192, 3, 3, | ||
subsample=(2, 2), border_mode='valid') | ||
|
||
branch_pool = AveragePooling2D((3, 3), strides=(2, 2))(x) | ||
x = merge([branch3x3, branch7x7x3, branch_pool], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
# mixed9: 8 x 8 x 2048 | ||
for _ in range(2): | ||
branch1x1 = conv2d_bn(x, 320, 1, 1) | ||
|
||
branch3x3 = conv2d_bn(x, 384, 1, 1) | ||
branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3) | ||
branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1) | ||
branch3x3 = merge([branch3x3_1, branch3x3_2], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
branch3x3dbl = conv2d_bn(x, 448, 1, 1) | ||
branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3) | ||
branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3) | ||
branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1) | ||
branch3x3dbl = merge([branch3x3dbl_1, branch3x3dbl_2], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
branch_pool = AveragePooling2D( | ||
(3, 3), strides=(1, 1), border_mode='same')(x) | ||
branch_pool = conv2d_bn(branch_pool, 192, 1, 1) | ||
x = merge([branch1x1, branch3x3, branch3x3dbl, branch_pool], | ||
mode='concat', concat_axis=channel_axis) | ||
|
||
if include_top: | ||
# Classification block | ||
x = AveragePooling2D((8, 8), strides=(8, 8), name='avg_pool')(x) | ||
x = Flatten(name='flatten')(x) | ||
x = Dense(1000, activation='softmax', name='predictions')(x) | ||
|
||
# Create model | ||
model = Model(img_input, x) | ||
|
||
# load weights | ||
if weights == 'imagenet': | ||
print('K.image_dim_ordering:', K.image_dim_ordering()) | ||
if K.image_dim_ordering() == 'th': | ||
if include_top: | ||
weights_path = get_file('inception_v3_weights_th_dim_ordering_th_kernels.h5', | ||
TH_WEIGHTS_PATH, | ||
cache_subdir='models', | ||
md5_hash='b3baf3070cc4bf476d43a2ea61b0ca5f') | ||
else: | ||
weights_path = get_file('inception_v3_weights_th_dim_ordering_th_kernels_notop.h5', | ||
TH_WEIGHTS_PATH_NO_TOP, | ||
cache_subdir='models', | ||
md5_hash='79aaa90ab4372b4593ba3df64e142f05') | ||
model.load_weights(weights_path) | ||
if K.backend() == 'tensorflow': | ||
warnings.warn('You are using the TensorFlow backend, yet you ' | ||
'are using the Theano ' | ||
'image dimension ordering convention ' | ||
'(`image_dim_ordering="th"`). ' | ||
'For best performance, set ' | ||
'`image_dim_ordering="tf"` in ' | ||
'your Keras config ' | ||
'at ~/.keras/keras.json.') | ||
convert_all_kernels_in_model(model) | ||
else: | ||
if include_top: | ||
weights_path = get_file('inception_v3_weights_tf_dim_ordering_tf_kernels.h5', | ||
TF_WEIGHTS_PATH, | ||
cache_subdir='models', | ||
md5_hash='fe114b3ff2ea4bf891e9353d1bbfb32f') | ||
else: | ||
weights_path = get_file('inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5', | ||
TF_WEIGHTS_PATH_NO_TOP, | ||
cache_subdir='models', | ||
md5_hash='2f3609166de1d967d1a481094754f691') | ||
model.load_weights(weights_path) | ||
if K.backend() == 'theano': | ||
convert_all_kernels_in_model(model) | ||
return model | ||
|
||
|
||
def preprocess_input(x): | ||
x /= 255. | ||
x -= 0.5 | ||
x *= 2. | ||
return x | ||
|
||
|
||
if __name__ == '__main__': | ||
model = InceptionV3(include_top=True, weights='imagenet') | ||
|
||
img_path = 'elephant.jpg' | ||
img = image.load_img(img_path, target_size=(299, 299)) | ||
x = image.img_to_array(img) | ||
x = np.expand_dims(x, axis=0) | ||
|
||
x = preprocess_input(x) | ||
|
||
preds = model.predict(x) | ||
print('Predicted:', decode_predictions(preds)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters