diff --git a/inception_v3.py b/inception_v3.py index 0616ff0..4fca301 100644 --- a/inception_v3.py +++ b/inception_v3.py @@ -125,7 +125,7 @@ def InceptionV3(include_top=True, weights='imagenet', x = conv2d_bn(x, 192, 3, 3, border_mode='valid') x = MaxPooling2D((3, 3), strides=(2, 2))(x) - # mixed: 35 x 35 x 256 + # mixed 0, 1, 2: 35 x 35 x 256 for i in range(3): branch1x1 = conv2d_bn(x, 64, 1, 1) @@ -140,9 +140,10 @@ def InceptionV3(include_top=True, weights='imagenet', (3, 3), strides=(1, 1), border_mode='same')(x) branch_pool = conv2d_bn(branch_pool, 32, 1, 1) x = merge([branch1x1, branch5x5, branch3x3dbl, branch_pool], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed' + str(i)) - # mixed3: 17 x 17 x 768 + # mixed 3: 17 x 17 x 768 branch3x3 = conv2d_bn(x, 384, 3, 3, subsample=(2, 2), border_mode='valid') branch3x3dbl = conv2d_bn(x, 64, 1, 1) @@ -152,9 +153,10 @@ def InceptionV3(include_top=True, weights='imagenet', branch_pool = MaxPooling2D((3, 3), strides=(2, 2))(x) x = merge([branch3x3, branch3x3dbl, branch_pool], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed3') - # mixed4: 17 x 17 x 768 + # mixed 4: 17 x 17 x 768 branch1x1 = conv2d_bn(x, 192, 1, 1) branch7x7 = conv2d_bn(x, 128, 1, 1) @@ -170,10 +172,11 @@ def InceptionV3(include_top=True, weights='imagenet', branch_pool = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(x) branch_pool = conv2d_bn(branch_pool, 192, 1, 1) x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed4') - # mixed5: 17 x 17 x 768 - for _ in range(2): + # mixed 5, 6: 17 x 17 x 768 + for i in range(2): branch1x1 = conv2d_bn(x, 192, 1, 1) branch7x7 = conv2d_bn(x, 160, 1, 1) @@ -190,9 +193,10 @@ def InceptionV3(include_top=True, weights='imagenet', (3, 3), strides=(1, 1), border_mode='same')(x) branch_pool = conv2d_bn(branch_pool, 192, 1, 1) x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed' + str(5 + i)) - # mixed7: 17 x 17 x 768 + # mixed 7: 17 x 17 x 768 branch1x1 = conv2d_bn(x, 192, 1, 1) branch7x7 = conv2d_bn(x, 192, 1, 1) @@ -208,9 +212,10 @@ def InceptionV3(include_top=True, weights='imagenet', branch_pool = AveragePooling2D((3, 3), strides=(1, 1), border_mode='same')(x) branch_pool = conv2d_bn(branch_pool, 192, 1, 1) x = merge([branch1x1, branch7x7, branch7x7dbl, branch_pool], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed7') - # mixed8: 8 x 8 x 1280 + # mixed 8: 8 x 8 x 1280 branch3x3 = conv2d_bn(x, 192, 1, 1) branch3x3 = conv2d_bn(branch3x3, 320, 3, 3, subsample=(2, 2), border_mode='valid') @@ -223,17 +228,19 @@ def InceptionV3(include_top=True, weights='imagenet', branch_pool = AveragePooling2D((3, 3), strides=(2, 2))(x) x = merge([branch3x3, branch7x7x3, branch_pool], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed8') - # mixed9: 8 x 8 x 2048 - for _ in range(2): + # mixed 9: 8 x 8 x 2048 + for i in range(2): branch1x1 = conv2d_bn(x, 320, 1, 1) branch3x3 = conv2d_bn(x, 384, 1, 1) branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3) branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1) branch3x3 = merge([branch3x3_1, branch3x3_2], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed9_' + str(i)) branch3x3dbl = conv2d_bn(x, 448, 1, 1) branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3) @@ -246,7 +253,8 @@ def InceptionV3(include_top=True, weights='imagenet', (3, 3), strides=(1, 1), border_mode='same')(x) branch_pool = conv2d_bn(branch_pool, 192, 1, 1) x = merge([branch1x1, branch3x3, branch3x3dbl, branch_pool], - mode='concat', concat_axis=channel_axis) + mode='concat', concat_axis=channel_axis, + name='mixed' + str(9 + i)) if include_top: # Classification block @@ -259,7 +267,6 @@ def InceptionV3(include_top=True, weights='imagenet', # load weights if weights == 'imagenet': - print('K.image_dim_ordering:', K.image_dim_ordering()) if K.image_dim_ordering() == 'th': if include_top: weights_path = get_file('inception_v3_weights_th_dim_ordering_th_kernels.h5',