Skip to content

Commit

Permalink
Deleted unnecessary source_name argument, removed blank line, added t…
Browse files Browse the repository at this point in the history
…ests for input validation.
  • Loading branch information
Markus Nagel authored and aukejw committed Dec 4, 2015
1 parent ee76725 commit c0213e6
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 25 deletions.
40 changes: 17 additions & 23 deletions fuel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,34 +944,28 @@ class OneHotEncoding(SourcewiseTransformer):
The data stream.
num_classes : int
The number of classes.
source_name : str, default 'targets'
The name of the source that will be transformed.
"""
def __init__(self, data_stream, num_classes, source_name='targets',
**kwargs):
self.num_classes = num_classes
self.source_name = source_name

def __init__(self, data_stream, num_classes, **kwargs):
# Pass the wrapped stream's own produces_examples value through to the
# parent constructor, together with any extra keyword arguments.
super(OneHotEncoding, self).__init__(
data_stream, data_stream.produces_examples, **kwargs)
# Width of every one-hot vector produced by the transform methods.
self.num_classes = num_classes

def transform_source_example(self, source_example, source_name):
    """One-hot encode a single class label.

    Parameters
    ----------
    source_example : int
        Class index to encode. Must be lower than ``self.num_classes``.
    source_name : str
        Name of the source being transformed. Not used by this method.

    Returns
    -------
    output : numpy.ndarray
        Array of shape ``(1, num_classes)`` holding a one in column
        ``source_example`` and zeros everywhere else.

    Raises
    ------
    ValueError
        If ``source_example`` is greater than or equal to
        ``self.num_classes``.

    """
    if source_example >= self.num_classes:
        # The message must read self.num_classes: a bare `num_classes`
        # is not defined in this scope and would surface as a NameError
        # instead of the intended ValueError.
        raise ValueError("source_example ({}) must be lower than "
                         "num_classes ({})".format(source_example,
                                                   self.num_classes))
    output = numpy.zeros((1, self.num_classes))
    output[0, source_example] = 1
    return output

def transform_source_batch(self, source_batch, source_name):
    """One-hot encode a batch of class labels.

    Parameters
    ----------
    source_batch : numpy.ndarray
        Two-dimensional array whose first column holds the class
        indices. Every entry must be lower than ``self.num_classes``.
    source_name : str
        Name of the source being transformed. Not used by this method.

    Returns
    -------
    output : numpy.ndarray
        Array of shape ``(batch_size, num_classes)`` with the same dtype
        as ``source_batch``; row ``k`` holds a one in the column given by
        ``source_batch[k, 0]``.

    Raises
    ------
    ValueError
        If any entry of ``source_batch`` is greater than or equal to
        ``self.num_classes``.

    """
    if numpy.max(source_batch) >= self.num_classes:
        # self.num_classes, not a bare `num_classes`: the latter is
        # undefined here and would raise NameError while building the
        # message.
        raise ValueError("all entries in source_batch must be lower than "
                         "num_classes ({})".format(self.num_classes))
    output = numpy.zeros((source_batch.shape[0], self.num_classes),
                         dtype=source_batch.dtype)
    # Mark, class by class, the rows whose label equals that class.
    for i in range(self.num_classes):
        output[source_batch[:, 0] == i, i] = 1
    return output
20 changes: 18 additions & 2 deletions tests/transformers/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,27 +706,43 @@ def test_one_hot_examples(self):
DataStream(IndexableDataset(self.data),
iteration_scheme=SequentialExampleScheme(4)),
num_classes=4,
source_name='targets')
which_sources='targets')
assert_equal(
list(wrapper.get_epoch_iterator()),
[(numpy.ones((2, 2)), numpy.array([[1, 0, 0, 0]])),
(numpy.ones((2, 2)), numpy.array([[0, 1, 0, 0]])),
(numpy.ones((2, 2)), numpy.array([[0, 0, 1, 0]])),
(numpy.ones((2, 2)), numpy.array([[0, 0, 0, 1]]))])

def test_one_hot_examples_invalid_inputs(self):
    """Iterating examples with too small a num_classes raises ValueError."""
    # NOTE(review): presumably self.data contains targets >= 2, which is
    # what makes num_classes=2 invalid -- confirm against the fixture.
    example_stream = DataStream(
        IndexableDataset(self.data),
        iteration_scheme=SequentialExampleScheme(4))
    wrapper = OneHotEncoding(example_stream, num_classes=2,
                             which_sources='targets')
    epoch_iterator = wrapper.get_epoch_iterator()
    assert_raises(ValueError, list, epoch_iterator)

def test_one_hot_batches(self):
    """Batches of two targets are encoded as rows of a one-hot matrix."""
    # Only `which_sources` is passed: the constructor no longer takes a
    # `source_name` argument, so the stale keyword is dropped.
    wrapper = OneHotEncoding(
        DataStream(IndexableDataset(self.data),
                   iteration_scheme=SequentialScheme(4, 2)),
        num_classes=4,
        which_sources='targets')
    assert_equal(
        list(wrapper.get_epoch_iterator()),
        [(numpy.ones((2, 2, 2)),
          numpy.array([[1, 0, 0, 0], [0, 1, 0, 0]])),
         (numpy.ones((2, 2, 2)),
          numpy.array([[0, 0, 1, 0], [0, 0, 0, 1]]))])

def test_one_hot_batches_invalid_input(self):
    """Iterating batches with too small a num_classes raises ValueError."""
    # NOTE(review): presumably self.data contains targets >= 2, which is
    # what makes num_classes=2 invalid -- confirm against the fixture.
    batch_stream = DataStream(
        IndexableDataset(self.data),
        iteration_scheme=SequentialScheme(4, 2))
    wrapper = OneHotEncoding(batch_stream, num_classes=2,
                             which_sources='targets')
    epoch_iterator = wrapper.get_epoch_iterator()
    assert_raises(ValueError, list, epoch_iterator)


class VerifyWarningHandler(logging.Handler):
def __init__(self, *args, **kwargs):
Expand Down

0 comments on commit c0213e6

Please sign in to comment.