diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index b9d27c2b896b0ef8240fd169875614320b25d3ff..6e266650421d1cbc1355b77894e618daa8e83401 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -1,4 +1,5 @@ from .custom import CustomDataset +from .xml_style import XMLDataset from .coco import CocoDataset from .voc import VOCDataset from .loader import GroupSampler, DistributedGroupSampler, build_dataloader @@ -7,7 +8,7 @@ from .concat_dataset import ConcatDataset from .repeat_dataset import RepeatDataset __all__ = [ - 'CustomDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler', + 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler', 'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale', 'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset' ] diff --git a/mmdet/datasets/voc.py b/mmdet/datasets/voc.py index 152a13c0228e91e6b6900b68085a1b2e9e92311c..ba1c77273ba57a2b8a59b9dc46544e4e71290b5a 100644 --- a/mmdet/datasets/voc.py +++ b/mmdet/datasets/voc.py @@ -1,13 +1,7 @@ -import os.path as osp -import xml.etree.ElementTree as ET +from .xml_style import XMLDataset -import mmcv -import numpy as np -from .custom import CustomDataset - - -class VOCDataset(CustomDataset): +class VOCDataset(XMLDataset): CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', @@ -15,68 +9,10 @@ class VOCDataset(CustomDataset): 'tvmonitor') def __init__(self, **kwargs): - assert not kwargs.get('with_mask', False) super(VOCDataset, self).__init__(**kwargs) - self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)} - - def load_annotations(self, ann_file): - img_infos = [] - img_ids = mmcv.list_from_file(ann_file) - for img_id in img_ids: - filename = 'JPEGImages/{}.jpg'.format(img_id) - xml_path = osp.join(self.img_prefix, 'Annotations', - '{}.xml'.format(img_id)) - tree = ET.parse(xml_path) - root = tree.getroot() - size = 
root.find('size') - width = int(size.find('width').text) - height = int(size.find('height').text) - img_infos.append( - dict(id=img_id, filename=filename, width=width, height=height)) - return img_infos - - def get_ann_info(self, idx): - img_id = self.img_infos[idx]['id'] - xml_path = osp.join(self.img_prefix, 'Annotations', - '{}.xml'.format(img_id)) - tree = ET.parse(xml_path) - root = tree.getroot() - bboxes = [] - labels = [] - bboxes_ignore = [] - labels_ignore = [] - for obj in root.findall('object'): - name = obj.find('name').text - label = self.cat2label[name] - difficult = int(obj.find('difficult').text) - bnd_box = obj.find('bndbox') - bbox = [ - int(bnd_box.find('xmin').text), - int(bnd_box.find('ymin').text), - int(bnd_box.find('xmax').text), - int(bnd_box.find('ymax').text) - ] - if difficult: - bboxes_ignore.append(bbox) - labels_ignore.append(label) - else: - bboxes.append(bbox) - labels.append(label) - if not bboxes: - bboxes = np.zeros((0, 4)) - labels = np.zeros((0, )) - else: - bboxes = np.array(bboxes, ndmin=2) - 1 - labels = np.array(labels) - if not bboxes_ignore: - bboxes_ignore = np.zeros((0, 4)) - labels_ignore = np.zeros((0, )) + if 'VOC2007' in self.img_prefix: + self.year = 2007 + elif 'VOC2012' in self.img_prefix: + self.year = 2012 else: - bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1 - labels_ignore = np.array(labels_ignore) - ann = dict( - bboxes=bboxes.astype(np.float32), - labels=labels.astype(np.int64), - bboxes_ignore=bboxes_ignore.astype(np.float32), - labels_ignore=labels_ignore.astype(np.int64)) - return ann + raise ValueError('Cannot infer dataset year from img_prefix') diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py new file mode 100644 index 0000000000000000000000000000000000000000..40e3374e6dc17be5c31b238acb47e93fceb56a56 --- /dev/null +++ b/mmdet/datasets/xml_style.py @@ -0,0 +1,76 @@ +import os.path as osp +import xml.etree.ElementTree as ET + +import mmcv +import numpy as np + +from .custom 
class XMLDataset(CustomDataset):
    """Detection dataset whose annotations are stored as one XML file per image.

    Files are looked up relative to ``self.img_prefix``:

    * annotations: ``<img_prefix>/Annotations/<img_id>.xml``
    * images:      ``JPEGImages/<img_id>.jpg`` (stored as the relative
      ``filename`` of each image info entry)

    Subclasses are expected to provide a ``CLASSES`` sequence of category
    names; it is used to build ``self.cat2label``.
    """

    def __init__(self, **kwargs):
        super(XMLDataset, self).__init__(**kwargs)
        # Category name -> label index, starting at 1.
        # NOTE(review): 0 is presumably reserved for background - confirm.
        self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}

    def _xml_path(self, img_id):
        """Return the path of the annotation XML file for ``img_id``."""
        return osp.join(self.img_prefix, 'Annotations',
                        '{}.xml'.format(img_id))

    @staticmethod
    def _parse_bbox(bnd_box):
        """Read ``[xmin, ymin, xmax, ymax]`` from a ``<bndbox>`` element.

        Values are parsed through ``float`` first so that coordinates
        written as e.g. ``"48.0"`` are accepted; any fractional part is
        truncated by ``int``.
        """
        return [
            int(float(bnd_box.find(tag).text))
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ]

    @staticmethod
    def _to_arrays(bboxes, labels):
        """Convert parallel bbox/label lists to ``(float32, int64)`` arrays.

        Empty input yields arrays of shape ``(0, 4)`` and ``(0, )``.
        Otherwise 1 is subtracted from every box coordinate.
        """
        if not bboxes:
            bbox_arr = np.zeros((0, 4))
            label_arr = np.zeros((0, ))
        else:
            # NOTE(review): the "- 1" presumably converts 1-based pixel
            # coordinates to 0-based - confirm against the data source.
            bbox_arr = np.array(bboxes, ndmin=2) - 1
            label_arr = np.array(labels)
        return bbox_arr.astype(np.float32), label_arr.astype(np.int64)

    def load_annotations(self, ann_file):
        """Build the list of image infos from a file of image ids.

        Args:
            ann_file: Path to a text file read with
                ``mmcv.list_from_file``; each entry is used as an image id.

        Returns:
            list[dict]: One dict per image with keys ``id``, ``filename``,
            ``width`` and ``height``. Width and height are read from the
            ``<size>`` element of the image's annotation XML.
        """
        img_infos = []
        for img_id in mmcv.list_from_file(ann_file):
            root = ET.parse(self._xml_path(img_id)).getroot()
            size = root.find('size')
            img_infos.append(
                dict(
                    id=img_id,
                    filename='JPEGImages/{}.jpg'.format(img_id),
                    width=int(size.find('width').text),
                    height=int(size.find('height').text)))
        return img_infos

    def get_ann_info(self, idx):
        """Parse the annotation XML of the image at position ``idx``.

        Objects flagged as difficult go to the ``*_ignore`` arrays; all
        others go to ``bboxes`` / ``labels``. An object without a
        ``<difficult>`` tag is treated as not difficult.

        Args:
            idx: Index into ``self.img_infos``.

        Returns:
            dict: ``bboxes`` and ``bboxes_ignore`` (float32, shape
            ``(n, 4)``), ``labels`` and ``labels_ignore`` (int64, shape
            ``(n, )``).

        Raises:
            KeyError: If an object's ``<name>`` is not in ``CLASSES``.
        """
        img_id = self.img_infos[idx]['id']
        root = ET.parse(self._xml_path(img_id)).getroot()

        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            label = self.cat2label[obj.find('name').text]

            # ``find`` returns None for a missing tag. Compare with
            # ``is None`` explicitly: an Element without children is falsy.
            difficult_node = obj.find('difficult')
            if difficult_node is None:
                difficult = 0
            else:
                difficult = int(difficult_node.text)

            bbox = self._parse_bbox(obj.find('bndbox'))
            if difficult:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)

        bboxes, labels = self._to_arrays(bboxes, labels)
        bboxes_ignore, labels_ignore = self._to_arrays(bboxes_ignore,
                                                       labels_ignore)
        return dict(
            bboxes=bboxes,
            labels=labels,
            bboxes_ignore=bboxes_ignore,
            labels_ignore=labels_ignore)