diff --git a/luminoth/tools/dataset/readers/object_detection/coco.py b/luminoth/tools/dataset/readers/object_detection/coco.py
index cb6d801b..3090cb61 100644
--- a/luminoth/tools/dataset/readers/object_detection/coco.py
+++ b/luminoth/tools/dataset/readers/object_detection/coco.py
@@ -45,14 +45,24 @@ def __init__(self, data_dir, split, year=DEFAULT_YEAR,
         for annotation in annotations_json['annotations']:
             image_id = annotation['image_id']
             x, y, width, height = annotation['bbox']
+            if not self.merge_classes:
+                try:
+                    label_id = self.classes.index(
+                        category_to_name[annotation['category_id']]
+                    )
+                except ValueError:
+                    # Class may have been filtered out by
+                    # --only-classes or --limit-classes.
+                    continue
+            else:
+                label_id = 0
+
             self._image_to_bboxes.setdefault(image_id, []).append({
                 'xmin': x,
                 'ymin': y,
                 'xmax': x + width,
                 'ymax': y + height,
-                'label': self.classes.index(
-                    category_to_name[annotation['category_id']]
-                ),
+                'label': label_id,
             })
 
         self._image_to_details = {}
diff --git a/luminoth/tools/dataset/readers/object_detection/object_detection_reader.py b/luminoth/tools/dataset/readers/object_detection/object_detection_reader.py
index 9f58042e..7fdadf54 100644
--- a/luminoth/tools/dataset/readers/object_detection/object_detection_reader.py
+++ b/luminoth/tools/dataset/readers/object_detection/object_detection_reader.py
@@ -23,7 +23,8 @@ class ObjectDetectionReader(BaseReader):
     Iterate over all records.
     """
     def __init__(self, only_classes=None, only_images=None,
-                 limit_examples=None, limit_classes=None, seed=None, **kwargs):
+                 limit_examples=None, limit_classes=None, merge_classes=False,
+                 seed=None, **kwargs):
         """
         Args:
         - only_classes: string or list of strings used as a class
@@ -47,6 +48,7 @@ def __init__(self, only_classes=None, only_images=None,
         self._limit_examples = limit_examples
         self._limit_classes = limit_classes
+        self.merge_classes = merge_classes
 
         random.seed(seed)
 
         self._total = None
diff --git a/luminoth/tools/dataset/readers/object_detection/pascalvoc.py b/luminoth/tools/dataset/readers/object_detection/pascalvoc.py
index 0d6bf32e..1482589f 100644
--- a/luminoth/tools/dataset/readers/object_detection/pascalvoc.py
+++ b/luminoth/tools/dataset/readers/object_detection/pascalvoc.py
@@ -97,10 +97,15 @@ def iterate(self):
 
             gt_boxes = []
             for b in annotation['object']:
-                try:
-                    label_id = self.classes.index(b['name'])
-                except ValueError:
-                    continue
+                if not self.merge_classes:
+                    try:
+                        label_id = self.classes.index(b['name'])
+                    except ValueError:
+                        # Class may have been filtered out by
+                        # --only-classes or --limit-classes.
+                        continue
+                else:
+                    label_id = 0
 
                 gt_boxes.append({
                     'label': label_id,
diff --git a/luminoth/tools/dataset/transform.py b/luminoth/tools/dataset/transform.py
index 89ed278d..b139011e 100644
--- a/luminoth/tools/dataset/transform.py
+++ b/luminoth/tools/dataset/transform.py
@@ -14,7 +14,7 @@ def get_output_subfolder(only_classes, only_images, limit_examples,
     Returns: subfolder name for records
     """
     if only_classes is not None:
-        return 'classes-{}'.format(only_classes)
+        return 'classes-{}'.format('-'.join(only_classes))
     elif only_images is not None:
         return 'only-{}'.format(only_images)
     elif limit_examples is not None and limit_classes is not None:
@@ -30,7 +30,8 @@ def get_output_subfolder(only_classes, only_images, limit_examples,
 @click.option('--data-dir', help='Where to locate the original data.')
 @click.option('--output-dir', help='Where to save the transformed data.')
 @click.option('splits', '--split', required=True, multiple=True, help='Which splits to transform.')  # noqa
-@click.option('--only-classes', help='Whitelist of classes.')
+@click.option('--only-classes', multiple=True, help='Whitelist of classes.')
+@click.option('--merge-classes', help='Merge all classes into a single class.')
 @click.option('--only-images', help='Create dataset with specific examples.')
 @click.option('--limit-examples', type=int, help='Limit dataset with to the first `N` examples.')  # noqa
 @click.option('--limit-classes', type=int, help='Limit dataset with `N` random classes.')  # noqa
@@ -38,8 +39,8 @@ def get_output_subfolder(only_classes, only_images, limit_examples,
 @click.option('overrides', '--override', '-o', multiple=True, help='Custom parameters for readers.')  # noqa
 @click.option('--debug', is_flag=True, help='Set level logging to DEBUG.')
 def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
-              only_images, limit_examples, limit_classes, seed, overrides,
-              debug):
+              merge_classes, only_images, limit_examples, limit_classes, seed,
+              overrides, debug):
     """
     Prepares dataset for ingestion.
 
@@ -67,6 +68,8 @@ def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
     # All splits must have a consistent set of classes.
     classes = None
 
+    merge_classes = merge_classes in ('True', 'true', 'TRUE')
+
    reader_kwargs = parse_override(overrides)
 
     try:
@@ -74,9 +77,9 @@ def transform(dataset_reader, data_dir, output_dir, splits, only_classes,
             # Create instance of reader.
             split_reader = reader(
                 data_dir, split,
-                only_classes=only_classes, only_images=only_images,
-                limit_examples=limit_examples, limit_classes=limit_classes,
-                seed=seed, **reader_kwargs
+                only_classes=only_classes, merge_classes=merge_classes,
+                only_images=only_images, limit_examples=limit_examples,
+                limit_classes=limit_classes, seed=seed, **reader_kwargs
             )
 
             if classes is None:
diff --git a/luminoth/tools/dataset/writers/object_detection_writer.py b/luminoth/tools/dataset/writers/object_detection_writer.py
index e826b446..6dbf2d82 100644
--- a/luminoth/tools/dataset/writers/object_detection_writer.py
+++ b/luminoth/tools/dataset/writers/object_detection_writer.py
@@ -54,7 +54,12 @@ def save(self):
 
         # Save classes in simple json format for later use.
        classes_file = os.path.join(self._output_dir, CLASSES_FILENAME)
-        json.dump(self._reader.classes, tf.gfile.GFile(classes_file, 'w'))
+        if self._reader.merge_classes:
+            # Don't assign a name to the class if it's a merge of several.
+            json.dump([''], tf.gfile.GFile(classes_file, 'w'))
+        else:
+            json.dump(self._reader.classes, tf.gfile.GFile(classes_file, 'w'))
+
         record_file = os.path.join(
             self._output_dir, '{}.tfrecords'.format(self._split))
         writer = tf.python_io.TFRecordWriter(record_file)
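
Usage sketch for the new flag (the `--type pascal` reader selector and the dataset paths below are assumptions; they are not part of these hunks). Note that `--merge-classes` is declared without `is_flag=True` and is string-compared against `('True', 'true', 'TRUE')`, so it must be given an explicit value:

    lumi dataset transform --type pascal \
        --data-dir datasets/voc/raw --output-dir datasets/voc/tf \
        --split train --split val \
        --only-classes person --only-classes dog \
        --merge-classes True

With merging enabled, every ground-truth box is written with label 0 and the classes file contains a single unnamed class (`[""]`); since `--only-classes` is now `multiple=True`, the records land in a subfolder named after the joined whitelist (here `classes-person-dog`).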