246 lines
8.0 KiB
Python
246 lines
8.0 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
from generators.generator import Generator
|
|
from utils.image import read_image_rgb
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
from six import raise_from
|
|
import csv
|
|
import sys
|
|
import os.path
|
|
from collections import OrderedDict
|
|
|
|
|
|
def _parse(value, function, fmt):
    """
    Convert a string with `function`, reporting failures with a readable message.

    Args
        value: The raw value to convert.
        function: Callable applied to `value` (for example `int` or `float`).
        fmt: Format string with one `{}` placeholder that receives the caught `ValueError`.

    Returns
        `function(value)`.

    Raises
        ValueError: With message `fmt.format(e)`, where `e` is the `ValueError`
            raised by `function(value)`.
    """
    try:
        parsed = function(value)
    except ValueError as error:
        message = fmt.format(error)
        # A cause of None raises the new error without chaining the original one.
        raise_from(ValueError(message), None)
    else:
        return parsed
|
|
|
|
|
|
def _read_classes(csv_reader):
    """
    Build the class-name to class-id mapping from a classes CSV.

    Args
        csv_reader: Iterable of rows, each expected to hold exactly `class_name, class_id`.

    Returns
        An OrderedDict mapping class name to integer class id, in row order.

    Raises
        ValueError: If a row does not unpack into two fields, the id is not an
            integer, or a class name appears more than once. Messages carry the
            1-based line number.
    """
    name_to_id = OrderedDict()
    for line_number, fields in enumerate(csv_reader, 1):
        try:
            class_name, raw_id = fields
        except ValueError:
            raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line_number)), None)

        # The doubled braces survive this first format() and are filled in by _parse on failure.
        id_error_fmt = 'line {}: malformed class ID: {{}}'.format(line_number)
        class_id = _parse(raw_id, int, id_error_fmt)

        if class_name in name_to_id:
            raise ValueError('line {}: duplicate class name: \'{}\''.format(line_number, class_name))
        name_to_id[class_name] = class_id
    return name_to_id
|
|
|
|
|
|
def _read_annotations(csv_reader, classes, base_dir):
    """
    Collect per-image bounding-box annotations from an annotations CSV.

    Args
        csv_reader: Iterable of rows `img_file,x1,y1,x2,y2,class_name`; only the
            first six fields of each row are used. The first row is skipped.
        classes: Mapping from class name to class id.
        base_dir: Prefix joined to every image path with a '/'.

    Returns
        An OrderedDict mapping each prefixed image path to a list of dicts with
        keys 'x1', 'x2', 'y1', 'y2' (floats) and 'class' (the class id). Images
        listed only as `img_file,,,,,` map to an empty list.

    Raises
        ValueError: If a row has too few fields, a coordinate is not a number,
            the box is not strictly positive in width and height, or the class
            name is not in `classes`. Messages carry the 1-based line number.
    """
    annotations_by_image = OrderedDict()
    for line_number, fields in enumerate(csv_reader, 1):
        # The first row is skipped without being inspected.
        if line_number == 1:
            continue

        try:
            img_file, x1, y1, x2, y2, class_name = fields[:6]
        except ValueError:
            raise_from(ValueError(
                'line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line_number)),
                None)

        img_file = base_dir + '/' + img_file
        boxes = annotations_by_image.setdefault(img_file, [])

        # A row with every box/class field empty registers the image without annotations.
        if all(field == '' for field in (x1, y1, x2, y2, class_name)):
            continue

        x1 = _parse(x1, float, 'line {}: malformed x1: {{}}'.format(line_number))
        y1 = _parse(y1, float, 'line {}: malformed y1: {{}}'.format(line_number))
        x2 = _parse(x2, float, 'line {}: malformed x2: {{}}'.format(line_number))
        y2 = _parse(y2, float, 'line {}: malformed y2: {{}}'.format(line_number))

        # Reject degenerate or inverted boxes.
        if x2 <= x1:
            raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line_number, x2, x1))
        if y2 <= y1:
            raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line_number, y2, y1))

        # The class name must be one of the known classes.
        if class_name not in classes:
            raise ValueError(
                'line {}: unknown class name: \'{}\' (classes: {})'.format(line_number, class_name, classes))

        boxes.append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': classes[class_name]})
    return annotations_by_image
|
|
|
|
|
|
def _open_for_csv(path):
    """
    Open `path` in the mode `csv.reader` expects on the running interpreter.

    On Python 2 the file is opened in binary mode ('rb'). On Python 3 it is
    opened in text mode with `newline=''`, so line endings reach the csv
    module untranslated.

    Args
        path: Path of the file to open.

    Returns
        The open file object; the caller is responsible for closing it.
    """
    running_python2 = sys.version_info[0] < 3
    if running_python2:
        return open(path, 'rb')
    return open(path, 'r', newline='')
|
|
|
|
|
|
class CSVGenerator(Generator):
    """
    Generate data from one or more custom CSV datasets.

    Each dataset is described by an annotations CSV, a class mapping (a CSV
    path or a ready-made dict) and a base directory that is prefixed to the
    image paths read from the annotations file.
    """

    def __init__(
            self,
            data_files,
            class_files,
            base_dirs,
            **kwargs
    ):
        """
        Initialize a CSV data generator.

        Args
            data_files: Path, or list of paths, to the CSV annotations file(s).
            class_files: Class mapping(s), one per annotations file. Each entry is
                either a path to a CSV classes file or a dict mapping class name
                to class id. A single str or dict is treated as a one-item list.
            base_dirs: Directory, or list of directories (one per annotations
                file), prefixed to the image paths of the matching file.
            **kwargs: Forwarded unchanged to the base Generator.

        Raises
            ValueError: If fewer class_files or base_dirs than data_files are
                given, or if a classes/annotations CSV is malformed.
        """
        # Accept a single item wherever a list is expected.
        if isinstance(data_files, str):
            data_files = [data_files]
        if isinstance(class_files, (str, dict)):
            class_files = [class_files]
        if isinstance(base_dirs, str):
            base_dirs = [base_dirs]

        n_dataset = len(data_files)
        if len(class_files) < n_dataset or len(base_dirs) < n_dataset:
            # Fail early with a clear message instead of an IndexError mid-loop.
            raise ValueError(
                'expected one class file and one base dir per data file: '
                'got {} data_files, {} class_files, {} base_dirs'.format(
                    n_dataset, len(class_files), len(base_dirs)))

        # One {class_name: class_id} mapping per dataset, in dataset order.
        self.classes = []
        # {'img_path': [{'x1': .., 'x2': .., 'y1': .., 'y2': .., 'class': ..}, ...], ...}
        self.image_data = {}
        self.image_names = []

        # zip stops after n_dataset items because the other lists are at least that long.
        for data_file, class_file, base_dir in zip(data_files, class_files, base_dirs):
            if isinstance(class_file, dict):
                classes = class_file
            else:
                try:
                    with _open_for_csv(class_file) as file:
                        classes = _read_classes(csv.reader(file, delimiter=','))
                except ValueError as e:
                    raise_from(ValueError('invalid CSV class file: {}: {}'.format(class_file, e)), None)
            self.classes.append(classes)

            # csv with img_path, x1, y1, x2, y2, class_name
            try:
                with _open_for_csv(data_file) as file:
                    image_data = _read_annotations(csv.reader(file, delimiter=','), classes, base_dir)
            except ValueError as e:
                raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(data_file, e)), None)
            # Images that appear in several datasets keep the annotations of the last one read.
            self.image_data.update(image_data)

        self.image_names = list(self.image_data.keys())

        super(CSVGenerator, self).__init__(**kwargs)

    def size(self):
        """
        Size of the dataset (number of distinct image paths).
        """
        return len(self.image_names)

    def num_classes(self):
        """
        Number of classes in the dataset.

        Computed as the largest class id over all datasets plus one; returns 0
        when no class is defined.
        """
        # Flatten the ids explicitly: np.max over a list of dict views fails on Python 3.
        class_ids = [class_id for mapping in self.classes for class_id in mapping.values()]
        if not class_ids:
            return 0
        return max(class_ids) + 1

    def has_label(self, label):
        """
        Return True if label is a known label.
        """
        # No class_id --> class_name dict is kept; a label counts as known
        # when it is below num_classes().
        return label < self.num_classes()

    def has_name(self, name):
        """
        Returns True if name is a known class in any of the datasets.
        """
        return any(name in mapping for mapping in self.classes)

    def name_to_label(self, name):
        """
        Not supported for multiple datasets: prints a notice and returns name unchanged.
        """
        print('do not support for multiple dataset')
        return name

    def label_to_name(self, label):
        """
        Not supported for multiple datasets: prints a notice and returns label unchanged.
        """
        print('do not support for multiple dataset')
        return label

    def image_path(self, image_index):
        """
        Returns the image path for image_index.
        """
        return self.image_names[image_index]

    def image_aspect_ratio(self, image_index):
        """
        Compute the aspect ratio (width / height) for an image with image_index.
        """
        # PIL is fast for metadata; the context manager closes the file handle.
        with Image.open(self.image_path(image_index)) as image:
            return float(image.width) / float(image.height)

    def load_image(self, image_index):
        """
        Load an image at the image_index.
        """
        return read_image_rgb(self.image_path(image_index))

    def load_annotations(self, image_index):
        """
        Load annotations for an image_index.

        Returns
            A dict with 'labels' (array of class ids, one per box) and 'bboxes'
            (array with one `x1, y1, x2, y2` row per box). Both are empty
            arrays when the image has no annotations.
        """
        path = self.image_names[image_index]
        image_annotations = self.image_data[path]
        annotations = {'labels': np.empty((0,), dtype=np.int32), 'bboxes': np.empty((0, 4))}

        # Skip concatenation for unannotated images so the empty arrays keep their dtype and shape.
        if image_annotations:
            labels = [annot['class'] for annot in image_annotations]
            bboxes = [
                [
                    float(annot['x1']),
                    float(annot['y1']),
                    float(annot['x2']),
                    float(annot['y2']),
                ]
                for annot in image_annotations
            ]
            # One concatenate per array instead of one per annotation.
            annotations['labels'] = np.concatenate((annotations['labels'], labels))
            annotations['bboxes'] = np.concatenate((annotations['bboxes'], bboxes))

        return annotations
|