# Contents of mmdet/datasets/voc.py.
# NOTE: the accompanying change to mmdet/datasets/__init__.py adds
# `from .voc import VOCDataset` and lists 'VOCDataset' in `__all__`.
import os.path as osp
import xml.etree.ElementTree as ET

import mmcv
import numpy as np

from .custom import CustomDataset


class VOCDataset(CustomDataset):
    """Detection dataset that reads PASCAL VOC style XML annotations.

    Image ids are listed one per line in ``ann_file``; for each id the
    annotation is read from ``<img_prefix>/Annotations/<id>.xml`` and the
    image is referenced as ``JPEGImages/<id>.jpg``.

    Labels are 1-based indices into ``CLASSES``. Mask annotations are not
    supported.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Build the dataset; all keyword arguments go to ``CustomDataset``.

        Raises:
            AssertionError: if ``with_mask`` is passed with a truthy value.
        """
        assert not kwargs.get('with_mask', False), \
            'VOCDataset does not support with_mask'
        super(VOCDataset, self).__init__(**kwargs)
        # Class name -> label id, starting at 1.
        self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}

    def _load_xml_root(self, img_id):
        """Parse the annotation XML of ``img_id`` and return its root."""
        # NOTE(review): ``self.img_prefix`` is not assigned in this class;
        # presumably it is set by ``CustomDataset`` - confirm.
        xml_path = osp.join(self.img_prefix, 'Annotations',
                            '{}.xml'.format(img_id))
        return ET.parse(xml_path).getroot()

    @staticmethod
    def _to_arrays(bboxes, labels):
        """Convert box/label lists to ``(float32 (k, 4), int64 (k,))`` arrays.

        Box coordinates are shifted by -1. Empty input yields arrays of
        shape ``(0, 4)`` and ``(0,)``.
        """
        if not bboxes:
            return (np.zeros((0, 4), dtype=np.float32),
                    np.zeros((0, ), dtype=np.int64))
        return ((np.array(bboxes, ndmin=2) - 1).astype(np.float32),
                np.array(labels).astype(np.int64))

    def load_annotations(self, ann_file):
        """Read image ids from ``ann_file`` and collect per-image metadata.

        Args:
            ann_file: path of a text file with one image id per line.

        Returns:
            list[dict]: one dict per image with keys ``id``, ``filename``,
            ``width`` and ``height`` (also stored as ``self.img_infos``).
        """
        self.img_infos = []
        for img_id in mmcv.list_from_file(ann_file):
            size = self._load_xml_root(img_id).find('size')
            self.img_infos.append(
                dict(
                    id=img_id,
                    filename='JPEGImages/{}.jpg'.format(img_id),
                    width=int(size.find('width').text),
                    height=int(size.find('height').text)))
        return self.img_infos

    def get_ann_info(self, idx):
        """Load the box annotations of the image at position ``idx``.

        Objects flagged as difficult are returned separately under the
        ``*_ignore`` keys. An object without a ``difficult`` tag is treated
        as not difficult.

        Args:
            idx: index into ``self.img_infos``.

        Returns:
            dict: ``bboxes`` / ``bboxes_ignore`` as float32 arrays of shape
            (k, 4) in (xmin, ymin, xmax, ymax) order, shifted by -1, and
            ``labels`` / ``labels_ignore`` as int64 arrays of shape (k,).

        Raises:
            KeyError: if an object name is not in ``CLASSES``.
        """
        root = self._load_xml_root(self.img_infos[idx]['id'])
        bboxes, labels = [], []
        bboxes_ignore, labels_ignore = [], []
        for obj in root.findall('object'):
            label = self.cat2label[obj.find('name').text]
            # Compare against None explicitly: an Element without children
            # is falsy, so a truthiness test would be wrong here.
            difficult_node = obj.find('difficult')
            difficult = (0 if difficult_node is None else int(
                difficult_node.text))
            bnd_box = obj.find('bndbox')
            # Go through float() so coordinates written as e.g. "48.0"
            # are accepted instead of raising ValueError.
            bbox = [
                int(float(bnd_box.find(tag).text))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            if difficult:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        bboxes, labels = self._to_arrays(bboxes, labels)
        bboxes_ignore, labels_ignore = self._to_arrays(bboxes_ignore,
                                                       labels_ignore)
        return dict(
            bboxes=bboxes,
            labels=labels,
            bboxes_ignore=bboxes_ignore,
            labels_ignore=labels_ignore)