Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcBS committed Mar 21, 2017
2 parents 7f8cca3 + 02c3b89 commit d936091
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions keras_wrapper/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,8 @@ def setRawInput(self, path_list, set_name, type='file-name', id='raw-text', over
if id not in self.ids_inputs or overwrite_split:
self.ids_inputs.append(id)
self.types_inputs.append(type)
self.optional_inputs.append(id) # This is always optional
if id not in self.optional_inputs:
self.optional_inputs.append(id) # This is always optional
elif id in keys_X_set and not overwrite_split:
raise Exception('An input with id "' + id + '" is already loaded into the Database.')

Expand Down Expand Up @@ -723,7 +724,7 @@ def setInput(self, path_list, set_name, type='raw-image', id='image', repeat_set
elif id in keys_X_set and not overwrite_split:
raise Exception('An input with id "' + id + '" is already loaded into the Database.')

if not required:
if not required and id not in self.optional_inputs:
self.optional_inputs.append(id)

if type not in self.__accepted_types_inputs:
Expand Down Expand Up @@ -819,7 +820,8 @@ def setRawOutput(self, path_list, set_name, type='file-name', id='raw-text', ove
if id not in self.ids_inputs or overwrite_split:
self.ids_inputs.append(id)
self.types_inputs.append(type)
self.optional_inputs.append(id) # This is always optional
if id not in self.optional_inputs:
self.optional_inputs.append(id) # This is always optional

elif id in keys_Y_set and not overwrite_split:
raise Exception('An input with id "' + id + '" is already loaded into the Database.')
Expand Down Expand Up @@ -3377,17 +3379,18 @@ def __checkLengthSet(self, set_name):
"""
if eval('self.loaded_' + set_name + '[0] and self.loaded_' + set_name + '[1]'):
lengths = []
plot_ids_in = []
for id_in in self.ids_inputs:
if id_in not in self.optional_inputs:
exec ('lengths.append(len(self.X_' + set_name + '[id_in]))')
plot_ids_in.append(id_in)
exec('lengths.append(len(self.X_' + set_name + '[id_in]))')
for id_out in self.ids_outputs:
exec ('lengths.append(len(self.Y_' + set_name + '[id_out]))')
if lengths[1:] != lengths[:-1]:
raise Exception('Inputs and outputs size '
'(' + str(lengths) + ') for "' + set_name + '" set do not match.\n'
'\t Inputs:' + str(self.ids_inputs) + ''
'\t Outputs:' + str(
self.ids_outputs))
'\t Inputs:' + str(plot_ids_in) + ''
'\t Outputs:' + str(self.ids_outputs))

def __getNextSamples(self, k, set_name):
"""
Expand Down

0 comments on commit d936091

Please sign in to comment.