Skip to content
Snippets Groups Projects
Commit 9d38a278 authored by Kai Chen's avatar Kai Chen
Browse files

add XMLDataset and make VOCDataset a child of it

parent 178b5baa
No related branches found
No related tags found
No related merge requests found
from .custom import CustomDataset from .custom import CustomDataset
from .xml_style import XMLDataset
from .coco import CocoDataset from .coco import CocoDataset
from .voc import VOCDataset from .voc import VOCDataset
from .loader import GroupSampler, DistributedGroupSampler, build_dataloader from .loader import GroupSampler, DistributedGroupSampler, build_dataloader
...@@ -7,7 +8,7 @@ from .concat_dataset import ConcatDataset ...@@ -7,7 +8,7 @@ from .concat_dataset import ConcatDataset
from .repeat_dataset import RepeatDataset from .repeat_dataset import RepeatDataset
__all__ = [ __all__ = [
'CustomDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler', 'CustomDataset', 'XMLDataset', 'CocoDataset', 'VOCDataset', 'GroupSampler',
'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale', 'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale',
'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset' 'show_ann', 'get_dataset', 'ConcatDataset', 'RepeatDataset'
] ]
import os.path as osp from .xml_style import XMLDataset
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from .custom import CustomDataset class VOCDataset(XMLDataset):
class VOCDataset(CustomDataset):
CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
...@@ -15,68 +9,10 @@ class VOCDataset(CustomDataset): ...@@ -15,68 +9,10 @@ class VOCDataset(CustomDataset):
'tvmonitor') 'tvmonitor')
def __init__(self, **kwargs): def __init__(self, **kwargs):
assert not kwargs.get('with_mask', False)
super(VOCDataset, self).__init__(**kwargs) super(VOCDataset, self).__init__(**kwargs)
self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)} if 'VOC2007' in self.img_prefix:
self.year = 2007
def load_annotations(self, ann_file): elif 'VOC2012' in self.img_prefix:
img_infos = [] self.year = 2012
img_ids = mmcv.list_from_file(ann_file)
for img_id in img_ids:
filename = 'JPEGImages/{}.jpg'.format(img_id)
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
size = root.find('size')
width = int(size.find('width').text)
height = int(size.find('height').text)
img_infos.append(
dict(id=img_id, filename=filename, width=width, height=height))
return img_infos
def get_ann_info(self, idx):
img_id = self.img_infos[idx]['id']
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
bboxes = []
labels = []
bboxes_ignore = []
labels_ignore = []
for obj in root.findall('object'):
name = obj.find('name').text
label = self.cat2label[name]
difficult = int(obj.find('difficult').text)
bnd_box = obj.find('bndbox')
bbox = [
int(bnd_box.find('xmin').text),
int(bnd_box.find('ymin').text),
int(bnd_box.find('xmax').text),
int(bnd_box.find('ymax').text)
]
if difficult:
bboxes_ignore.append(bbox)
labels_ignore.append(label)
else:
bboxes.append(bbox)
labels.append(label)
if not bboxes:
bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))
else:
bboxes = np.array(bboxes, ndmin=2) - 1
labels = np.array(labels)
if not bboxes_ignore:
bboxes_ignore = np.zeros((0, 4))
labels_ignore = np.zeros((0, ))
else: else:
bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1 raise ValueError('Cannot infer dataset year from img_prefix')
labels_ignore = np.array(labels_ignore)
ann = dict(
bboxes=bboxes.astype(np.float32),
labels=labels.astype(np.int64),
bboxes_ignore=bboxes_ignore.astype(np.float32),
labels_ignore=labels_ignore.astype(np.int64))
return ann
import os.path as osp
import xml.etree.ElementTree as ET
import mmcv
import numpy as np
from .custom import CustomDataset
class XMLDataset(CustomDataset):
def __init__(self, **kwargs):
super(XMLDataset, self).__init__(**kwargs)
self.cat2label = {cat: i + 1 for i, cat in enumerate(self.CLASSES)}
def load_annotations(self, ann_file):
img_infos = []
img_ids = mmcv.list_from_file(ann_file)
for img_id in img_ids:
filename = 'JPEGImages/{}.jpg'.format(img_id)
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
size = root.find('size')
width = int(size.find('width').text)
height = int(size.find('height').text)
img_infos.append(
dict(id=img_id, filename=filename, width=width, height=height))
return img_infos
def get_ann_info(self, idx):
img_id = self.img_infos[idx]['id']
xml_path = osp.join(self.img_prefix, 'Annotations',
'{}.xml'.format(img_id))
tree = ET.parse(xml_path)
root = tree.getroot()
bboxes = []
labels = []
bboxes_ignore = []
labels_ignore = []
for obj in root.findall('object'):
name = obj.find('name').text
label = self.cat2label[name]
difficult = int(obj.find('difficult').text)
bnd_box = obj.find('bndbox')
bbox = [
int(bnd_box.find('xmin').text),
int(bnd_box.find('ymin').text),
int(bnd_box.find('xmax').text),
int(bnd_box.find('ymax').text)
]
if difficult:
bboxes_ignore.append(bbox)
labels_ignore.append(label)
else:
bboxes.append(bbox)
labels.append(label)
if not bboxes:
bboxes = np.zeros((0, 4))
labels = np.zeros((0, ))
else:
bboxes = np.array(bboxes, ndmin=2) - 1
labels = np.array(labels)
if not bboxes_ignore:
bboxes_ignore = np.zeros((0, 4))
labels_ignore = np.zeros((0, ))
else:
bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
labels_ignore = np.array(labels_ignore)
ann = dict(
bboxes=bboxes.astype(np.float32),
labels=labels.astype(np.int64),
bboxes_ignore=bboxes_ignore.astype(np.float32),
labels_ignore=labels_ignore.astype(np.int64))
return ann
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment