Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Caltech101 #255

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fuel/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"""
from fuel.converters import adult
from fuel.converters import binarized_mnist
from fuel.converters import caltech101
from fuel.converters import caltech101_silhouettes
from fuel.converters import cifar10
from fuel.converters import cifar100
Expand All @@ -23,6 +24,7 @@
all_converters = (
('adult', adult.fill_subparser),
('binarized_mnist', binarized_mnist.fill_subparser),
('caltech101', caltech101.fill_subparser),
('caltech101_silhouettes', caltech101_silhouettes.fill_subparser),
('cifar10', cifar10.fill_subparser),
('cifar100', cifar100.fill_subparser),
Expand Down
196 changes: 196 additions & 0 deletions fuel/converters/caltech101.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
import os
import h5py

import numpy
import scipy.misc

from six.moves import range

from fuel.converters.base import fill_hdf5_file, MissingInputFiles

# Category names of the dataset, in label order.  Each name is used as a
# sub-directory of `101_ObjectCategories`, and a category's position in
# this tuple is the integer target written for its images.  The tuple has
# 102 entries because it includes `BACKGROUND_Google`.
CATEGORIES = (
    'Leopards', 'emu', 'hedgehog', 'binocular', 'cougar_body', 'buddha',
    'Faces_easy', 'beaver', 'windsor_chair', 'yin_yang', 'anchor', 'pagoda',
    'mayfly', 'flamingo_head', 'headphone', 'joshua_tree', 'wrench',
    'platypus', 'dollar_bill', 'dalmatian', 'mandolin', 'llama',
    'electric_guitar', 'panda', 'lamp', 'pyramid', 'kangaroo', 'strawberry',
    'stop_sign', 'flamingo', 'gerenuk', 'crayfish', 'ketch',
    'crocodile_head', 'chandelier', 'cellphone', 'brain', 'car_side',
    'ferry', 'nautilus', 'BACKGROUND_Google', 'metronome', 'water_lilly',
    'dolphin', 'euphonium', 'crocodile', 'Faces', 'sunflower', 'garfield',
    'soccer_ball', 'stapler', 'scorpion', 'wheelchair', 'saxophone',
    'starfish', 'lotus', 'okapi', 'octopus', 'hawksbill', 'chair', 'crab',
    'menorah', 'helicopter', 'accordion', 'rhino', 'ant', 'bass', 'bonsai',
    'butterfly', 'ewer', 'pizza', 'umbrella', 'revolver', 'airplanes',
    'grand_piano', 'trilobite', 'cannon', 'wild_cat', 'inline_skate',
    'watch', 'pigeon', 'rooster', 'cup', 'dragonfly', 'barrel',
    'ceiling_fan', 'lobster', 'minaret', 'schooner', 'cougar_face',
    'Motorbikes', 'laptop', 'elephant', 'sea_horse', 'snoopy',
    'brontosaurus', 'gramophone', 'camera', 'stegosaurus', 'tick',
    'scissors', 'ibis',
)

# Number of images taken from each category for the training split; these
# are the files `image_0001.jpg` .. `image_0020.jpg`.
NUM_TRAIN = 20
# Number of images taken from each category for the test split; these are
# the next files after the training ones.
NUM_TEST = 10


def read_image(imfile, size=(256, 256)):
    """Read an image file, resize it and return it channel-first.

    Parameters
    ----------
    imfile : str
        Path of the image file to read.
    size : tuple of int, optional
        Output size handed to `scipy.misc.imresize`. Defaults to
        `(256, 256)`, which is the size the converter stores.

    Returns
    -------
    image : :class:`numpy.ndarray`
        The resized image with the channel axis first. A two-dimensional
        (single-channel) image gets a leading axis of length one;
        otherwise the last axis is moved to the front.

    Notes
    -----
    Review feedback on this function: resizing inside the converter
    hard-codes part of the pre-processing into the stored data. The
    suggestion was to store variable-size images and resize in the
    dataset's default transformer instead; the author replied that no
    resizing transformer existed at the time. The `size` parameter keeps
    the current behaviour by default without pinning it.

    """
    im = scipy.misc.imresize(scipy.misc.imread(imfile), size)
    if im.ndim == 2:
        # No channel axis: add one in front so callers always get 3 dims,
        # whatever `size` was requested.
        return im[numpy.newaxis]
    return numpy.rollaxis(im, 2, 0)


def _read_split(input_dir, first_index, num_images):
    """Load `num_images` consecutive images per category for one split.

    Parameters
    ----------
    input_dir : str
        Directory containing one sub-directory per entry of `CATEGORIES`.
    first_index : int
        Number of the first image file to read in every category
        (files are named `image_NNNN.jpg`).
    num_images : int
        How many consecutive image files to read per category.

    Returns
    -------
    features : :class:`numpy.ndarray`
        `uint8` array of shape `(len(CATEGORIES) * num_images, 3, 256, 256)`.
    targets : :class:`numpy.ndarray`
        Column vector holding, for each image, the index of its category
        in `CATEGORIES`.

    """
    features = numpy.empty(
        (len(CATEGORIES) * num_images, 3, 256, 256), dtype='uint8')
    for i, category in enumerate(CATEGORIES):
        for j in range(num_images):
            imfile = os.path.join(
                input_dir, category,
                'image_{:04d}.jpg'.format(first_index + j))
            # Single-channel images come back as (1, 256, 256); the
            # assignment broadcasts them over the three channel slots.
            features[i * num_images + j] = read_image(imfile)
    targets = numpy.repeat(numpy.arange(len(CATEGORIES)), num_images)
    return features, targets.reshape(-1, 1)


def convert_caltech101(directory, output_directory, output_file=None):
    """Convert the CalTech 101 dataset to HDF5.

    Reads the first `NUM_TRAIN` images of every category as the training
    split and the following `NUM_TEST` images as the test split, then
    writes 'features' and 'targets' for both splits to a single file.

    Parameters
    ----------
    directory : str
        Directory in which the required input files reside. It must
        contain the `101_ObjectCategories` directory.
    output_directory : str
        Directory in which to save the converted dataset.
    output_file : str, optional
        File name of the converted dataset inside `output_directory`.
        Defaults to 'caltech101.hdf5'.

    Returns
    -------
    output_paths : tuple of str
        Single-element tuple containing the path of the converted dataset.

    Raises
    ------
    MissingInputFiles
        If `101_ObjectCategories` is not a directory inside `directory`.

    """
    if output_file is None:
        output_file = 'caltech101.hdf5'
    output_file = os.path.join(output_directory, output_file)

    input_dir = os.path.join(directory, '101_ObjectCategories')
    if not os.path.isdir(input_dir):
        raise MissingInputFiles('Required files missing', [input_dir])

    # Load every image before touching the output: opening with mode "w"
    # truncates, so a failed read must not leave an empty or partial file
    # behind (or wipe out a previously converted one).
    train_features, train_targets = _read_split(input_dir, 1, NUM_TRAIN)
    test_features, test_targets = _read_split(
        input_dir, NUM_TRAIN + 1, NUM_TEST)

    data = (
        ('train', 'features', train_features),
        ('train', 'targets', train_targets),
        ('test', 'features', test_features),
        ('test', 'targets', test_targets),
    )

    with h5py.File(output_file, mode="w") as h5file:
        fill_hdf5_file(h5file, data)

        for i, label in enumerate(('batch', 'channel', 'height', 'width')):
            h5file['features'].dims[i].label = label

        for i, label in enumerate(('batch', 'index')):
            h5file['targets'].dims[i].label = label

    return (output_file,)


# Backward-compatible alias: the function was originally published under
# this name, copied from the silhouettes converter, and is still
# referenced by it.
convert_silhouettes = convert_caltech101


def fill_subparser(subparser):
    """Sets up a subparser to convert the CalTech 101 dataset files.

    No arguments are added to the subparser; only the conversion function
    is returned.

    Parameters
    ----------
    subparser : :class:`argparse.ArgumentParser`
        Subparser handling the `caltech101` command.

    Returns
    -------
    callable
        The function that performs the conversion.

    """
    return convert_silhouettes
1 change: 1 addition & 0 deletions fuel/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fuel.datasets.binarized_mnist import BinarizedMNIST
from fuel.datasets.cifar10 import CIFAR10
from fuel.datasets.cifar100 import CIFAR100
from fuel.datasets.caltech101 import CalTech101
from fuel.datasets.caltech101_silhouettes import CalTech101Silhouettes
from fuel.datasets.iris import Iris
from fuel.datasets.mnist import MNIST
Expand Down
27 changes: 27 additions & 0 deletions fuel/datasets/caltech101.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
from fuel.utils import find_in_data_path
from fuel.datasets import H5PYDataset
from fuel.transformers.defaults import uint8_pixels_to_floatX


class CalTech101(H5PYDataset):
    u"""CalTech 101 dataset.

    Parameters
    ----------
    which_sets : tuple of str
        Which split to load. Valid values are 'train' and 'test'.
    load_in_memory : bool, optional
        Forwarded to :class:`H5PYDataset`. Defaults to `True`.

    """
    filename = 'caltech101.hdf5'
    # The argument must be a one-element tuple of source names.  Written as
    # `('features')` it is just the string 'features', because parentheses
    # without a trailing comma inside them do not create a tuple.
    default_transformers = uint8_pixels_to_floatX(('features',))

    def __init__(self, which_sets, load_in_memory=True, **kwargs):
        super(CalTech101, self).__init__(file_or_path=self.data_path,
                                         which_sets=which_sets,
                                         load_in_memory=load_in_memory,
                                         **kwargs)

    @property
    def data_path(self):
        """Path of the HDF5 file, as resolved by `find_in_data_path`."""
        return find_in_data_path(self.filename)