# -*- coding: utf-8 -*-
from generators.generator import Generator
from utils.image import read_image_rgb
import numpy as np
from PIL import Image
from six import raise_from
import csv
import sys
import os.path
from collections import OrderedDict


def _parse(value, function, fmt):
    """
    Parse a string into a value, and format a nice ValueError if it fails.

    Returns `function(value)`.
    Any `ValueError` raised is caught and a new `ValueError` is raised
    with message `fmt.format(e)`, where `e` is the caught `ValueError`.
    """
    try:
        return function(value)
    except ValueError as e:
        raise_from(ValueError(fmt.format(e)), None)


def _read_classes(csv_reader):
    """
    Parse the classes file given by csv_reader.

    Each row must be `class_name,class_id`.

    Returns an `OrderedDict` mapping class name to integer class id, in file
    order. Raises `ValueError` on a malformed row, a non-integer class id or
    a duplicate class name.
    """
    result = OrderedDict()
    for line, row in enumerate(csv_reader, 1):
        try:
            class_name, class_id = row
        except ValueError:
            raise_from(ValueError('line {}: format should be \'class_name,class_id\''.format(line)), None)
        class_id = _parse(class_id, int, 'line {}: malformed class ID: {{}}'.format(line))

        if class_name in result:
            raise ValueError('line {}: duplicate class name: \'{}\''.format(line, class_name))
        result[class_name] = class_id
    return result


def _read_annotations(csv_reader, classes, base_dir):
    """
    Read annotations from the csv_reader.

    The first row is skipped (treated as a header). Every other row must be
    `img_file,x1,y1,x2,y2,class_name`, or `img_file,,,,,` for an image without
    annotations. Columns past the sixth are ignored.

    Args
        csv_reader: Iterable of rows (lists of strings).
        classes: Mapping from class name to class id.
        base_dir: Prefix joined to every image path with a '/'.

    Returns an `OrderedDict` mapping image path to a list of dicts with the
    keys 'x1', 'x2', 'y1', 'y2' and 'class'. Raises `ValueError` on malformed
    rows, invalid boxes or unknown class names.
    """
    result = OrderedDict()
    for line, row in enumerate(csv_reader, 1):
        # The first row is a header, not an annotation.
        if line == 1:
            continue

        try:
            img_file, x1, y1, x2, y2, class_name = row[:6]
        except ValueError:
            raise_from(ValueError(
                'line {}: format should be \'img_file,x1,y1,x2,y2,class_name\' or \'img_file,,,,,\''.format(line)),
                None)

        img_file = base_dir + '/' + img_file
        # Register the image even when it carries no annotation.
        if img_file not in result:
            result[img_file] = []

        # If a row contains only an image path, it's an image without annotations.
        if (x1, y1, x2, y2, class_name) == ('', '', '', '', ''):
            continue

        x1 = _parse(x1, float, 'line {}: malformed x1: {{}}'.format(line))
        y1 = _parse(y1, float, 'line {}: malformed y1: {{}}'.format(line))
        x2 = _parse(x2, float, 'line {}: malformed x2: {{}}'.format(line))
        y2 = _parse(y2, float, 'line {}: malformed y2: {{}}'.format(line))

        # Check that the bounding box is valid.
        if x2 <= x1:
            raise ValueError('line {}: x2 ({}) must be higher than x1 ({})'.format(line, x2, x1))
        if y2 <= y1:
            raise ValueError('line {}: y2 ({}) must be higher than y1 ({})'.format(line, y2, y1))

        # Check that the class name is a known one.
        if class_name not in classes:
            raise ValueError('line {}: unknown class name: \'{}\' (classes: {})'.format(line, class_name, classes))

        result[img_file].append({'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'class': classes[class_name]})
    return result


def _open_for_csv(path):
    """
    Open a file with flags suitable for csv.reader.

    This is different for python2 it means with mode 'rb',
    for python3 this means 'r' with "universal newlines".
    """
    if sys.version_info[0] < 3:
        return open(path, 'rb')
    else:
        return open(path, 'r', newline='')


class CSVGenerator(Generator):
    """
    Generate data from one or more CSV datasets.
    """

    def __init__(
            self,
            data_files,
            class_files,
            base_dirs,
            **kwargs
    ):
        """
        Initialize a CSV data generator.

        Each of the three arguments may be a single item or a list; entries
        at the same index describe one dataset.

        Args
            data_files: Path (or list of paths) to the CSV annotations files.
            class_files: Path to a CSV classes file, or a dict mapping class
                name to class id, or a list of either.
            base_dirs: Directory (or list of directories) prefixed to the
                image paths found in the matching annotations file.
            **kwargs: Forwarded to the `Generator` base class.

        Raises `ValueError` if there are fewer class files or base dirs than
        data files, or if a CSV file is invalid.
        """
        if isinstance(data_files, str):
            data_files = [data_files]
        if isinstance(class_files, (str, dict)):
            class_files = [class_files]
        if isinstance(base_dirs, str):
            base_dirs = [base_dirs]

        n_dataset = len(data_files)
        # Fail early with a clear message instead of an IndexError mid-way.
        if len(class_files) < n_dataset or len(base_dirs) < n_dataset:
            raise ValueError(
                'expected a class file and a base dir for each of the {} data files, got {} and {}'.format(
                    n_dataset, len(class_files), len(base_dirs)))

        # One class_name --> class_id mapping per dataset, e.g. [{'person': 0, ...}].
        self.classes = []
        # {'img_path1': [{'x1': .., 'y1': .., 'x2': .., 'y2': .., 'class': ..}, ...], ...}
        self.image_data = {}
        self.image_names = []

        for set_idx in range(n_dataset):
            class_file = class_files[set_idx]
            if isinstance(class_file, dict):
                # Already a class_name --> class_id mapping.
                classes = class_file
            else:
                try:
                    with _open_for_csv(class_file) as file:
                        classes = _read_classes(csv.reader(file, delimiter=','))
                except ValueError as e:
                    raise_from(ValueError('invalid CSV class file: {}: {}'.format(class_file, e)), None)
            self.classes.append(classes)

            data_file = data_files[set_idx]
            base_dir = base_dirs[set_idx]
            # csv with img_path, x1, y1, x2, y2, class_name
            try:
                with _open_for_csv(data_file) as file:
                    image_data = _read_annotations(csv.reader(file, delimiter=','), classes, base_dir)
            except ValueError as e:
                raise_from(ValueError('invalid CSV annotations file: {}: {}'.format(data_file, e)), None)
            # An image path seen in several datasets keeps the annotations of the last one.
            self.image_data.update(image_data)

        self.image_names = list(self.image_data.keys())

        super(CSVGenerator, self).__init__(**kwargs)

    def size(self):
        """
        Size of the dataset.
        """
        return len(self.image_names)

    def num_classes(self):
        """
        Number of classes in the dataset.

        Computed as the highest class id over all datasets plus one;
        0 when no class is defined.
        """
        class_ids = [class_id for mapping in self.classes for class_id in mapping.values()]
        if not class_ids:
            return 0
        return int(max(class_ids)) + 1

    def has_label(self, label):
        """
        Return True if label is a known label.
        """
        # Labels are class ids; any id below num_classes() counts as known.
        return label < self.num_classes()

    def has_name(self, name):
        """
        Returns True if name is a known class.
        """
        return any(name in mapping for mapping in self.classes)

    def name_to_label(self, name):
        """
        Map name to label.

        Not supported with multiple datasets: prints a notice and returns
        `name` unchanged.
        """
        print('do not support for multiple dataset')
        return name

    def label_to_name(self, label):
        """
        Map label to name.

        Not supported with multiple datasets: prints a notice and returns
        `label` unchanged.
        """
        print('do not support for multiple dataset')
        return label

    def image_path(self, image_index):
        """
        Returns the image path for image_index.
        """
        return self.image_names[image_index]

    def image_aspect_ratio(self, image_index):
        """
        Compute the aspect ratio for an image with image_index.
        """
        # PIL is fast for metadata; the context manager closes the file handle.
        with Image.open(self.image_path(image_index)) as image:
            return float(image.width) / float(image.height)

    def load_image(self, image_index):
        """
        Load an image at the image_index.
        """
        return read_image_rgb(self.image_path(image_index))

    def load_annotations(self, image_index):
        """
        Load annotations for an image_index.

        Returns a dict with 'labels' (1-D array of class ids) and 'bboxes'
        (array of shape (N, 4) holding x1, y1, x2, y2).
        """
        path = self.image_names[image_index]
        annots = self.image_data[path]

        annotations = {'labels': np.empty((0,), dtype=np.int32), 'bboxes': np.empty((0, 4))}
        # Only concatenate when there is something to add, so that the empty
        # arrays keep their dtype and shape.
        if annots:
            annotations['labels'] = np.concatenate((
                annotations['labels'],
                [annot['class'] for annot in annots],
            ))
            annotations['bboxes'] = np.concatenate((
                annotations['bboxes'],
                [[
                    float(annot['x1']),
                    float(annot['y1']),
                    float(annot['x2']),
                    float(annot['y2']),
                ] for annot in annots],
            ))

        return annotations