179 lines
6.1 KiB
Python
179 lines
6.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
import xml.etree.ElementTree as ET
|
|
|
|
import mmcv
|
|
import numpy as np
|
|
from PIL import Image
|
|
|
|
from .builder import DATASETS
|
|
from .custom import CustomDataset
|
|
|
|
|
|
@DATASETS.register_module()
class XMLDataset(CustomDataset):
    """XML dataset for detection.

    Each line of ``ann_file`` is an image id. For an id ``img_id`` the image
    filename is ``<img_subdir>/<img_id>.jpg`` (relative to ``img_prefix``)
    and its annotation is read from ``<img_prefix>/<ann_subdir>/<img_id>.xml``.

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the width or height of a bounding box is
            less than ``min_size``, it is added to the ignored field.
            Default: None (no size-based ignoring).
        img_subdir (str): Subdir where images are stored. Default: JPEGImages.
        ann_subdir (str): Subdir where annotations are. Default: Annotations.
        **kwargs: Forwarded to the parent ``CustomDataset`` constructor.
    """

    def __init__(self,
                 min_size=None,
                 img_subdir='JPEGImages',
                 ann_subdir='Annotations',
                 **kwargs):
        assert self.CLASSES or kwargs.get(
            'classes', None), 'CLASSES in `XMLDataset` can not be None.'
        # Set before ``super().__init__``: ``load_annotations`` and
        # ``_filter_imgs`` read these and are presumably invoked from the
        # parent constructor.
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        super().__init__(**kwargs)
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size

    def _load_xml_root(self, img_id):
        """Parse the XML annotation file of ``img_id`` and return its root.

        Only depends on ``img_prefix`` and ``ann_subdir``, so it is safe to
        call while the parent constructor is still running.
        """
        xml_path = osp.join(self.img_prefix, self.ann_subdir, f'{img_id}.xml')
        return ET.parse(xml_path).getroot()

    @staticmethod
    def _boxes_to_arrays(bboxes, labels):
        """Convert parallel box/label lists to numpy arrays.

        Args:
            bboxes (list[list[int]]): Boxes as ``[xmin, ymin, xmax, ymax]``.
            labels (list[int]): Label of each box.

        Returns:
            tuple[np.ndarray, np.ndarray]: Boxes of shape (N, 4) shifted by
                -1, and labels of shape (N,). Empty input yields arrays of
                shape (0, 4) and (0,).
        """
        if not bboxes:
            return np.zeros((0, 4)), np.zeros((0, ))
        # Annotation coordinates are assumed to be 1-based (PASCAL VOC
        # convention); subtract 1 to obtain 0-based pixel coordinates.
        return np.array(bboxes, ndmin=2) - 1, np.array(labels)

    def load_annotations(self, ann_file):
        """Load annotation from XML style ann_file.

        Args:
            ann_file (str): Path of a text file listing one image id per
                line. Each id's XML file is read from ``ann_subdir``.

        Returns:
            list[dict]: One dict per image with keys ``id``, ``filename``
                (relative to ``img_prefix``), ``width`` and ``height``.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(ann_file)
        for img_id in img_ids:
            filename = osp.join(self.img_subdir, f'{img_id}.jpg')
            root = self._load_xml_root(img_id)
            size = root.find('size')
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                # No <size> element: read the dimensions from the image
                # itself. The context manager releases the file handle that
                # ``Image.open`` would otherwise keep open.
                img_path = osp.join(self.img_prefix, filename)
                with Image.open(img_path) as img:
                    width, height = img.size
            data_infos.append(
                dict(id=img_id, filename=filename, width=width, height=height))
        return data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation.

        Args:
            min_size (int): Images whose shorter side is below this value
                are dropped. Default: 32.

        Returns:
            list[int]: Indices into ``self.data_infos`` of the kept images.
        """
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if self.filter_empty_gt:
                # Keep the image only if it has at least one object of a
                # known class. ``self.CLASSES`` is used instead of
                # ``self.cat2label`` because this may run before
                # ``cat2label`` is assigned in ``__init__``.
                root = self._load_xml_root(img_info['id'])
                if not any(obj.find('name').text in self.CLASSES
                           for obj in root.findall('object')):
                    continue
            valid_inds.append(i)
        return valid_inds

    def get_ann_info(self, idx):
        """Get annotation from XML file by index.

        Objects whose class is not in ``self.CLASSES`` are skipped. Objects
        marked difficult, or (when ``min_size`` is set) whose box width or
        height is below ``min_size``, go into the ``*_ignore`` fields.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index with keys ``bboxes``
                (float32, (N, 4)), ``labels`` (int64, (N,)),
                ``bboxes_ignore`` (float32, (M, 4)) and ``labels_ignore``
                (int64, (M,)).
        """
        root = self._load_xml_root(self.data_infos[idx]['id'])
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            label = self.cat2label.get(obj.find('name').text)
            if label is None:
                continue
            difficult = obj.find('difficult')
            difficult_text = '' if difficult is None else (difficult.text or '')
            # A missing or empty <difficult> tag means "not difficult".
            difficult = int(difficult_text) if difficult_text.strip() else 0
            bnd_box = obj.find('bndbox')
            # TODO: check whether it is necessary to use int
            # Coordinates may be float type
            bbox = [
                int(float(bnd_box.find(key).text))
                for key in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            ignore = False
            if self.min_size:
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                ignore = w < self.min_size or h < self.min_size
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        bboxes, labels = self._boxes_to_arrays(bboxes, labels)
        bboxes_ignore, labels_ignore = self._boxes_to_arrays(
            bboxes_ignore, labels_ignore)
        ann = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann

    def get_cat_ids(self, idx):
        """Get category ids in XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index, one
                entry per object of a known class (duplicates kept).
        """
        root = self._load_xml_root(self.data_infos[idx]['id'])
        names = (obj.find('name').text for obj in root.findall('object'))
        return [self.cat2label[name] for name in names
                if name in self.cat2label]
|