159 lines
5.1 KiB
Python
159 lines
5.1 KiB
Python
from augmentor import train_aug, val_aug
|
|
|
|
|
|
def create_generators(input_size=512,
                      batch_size=16,
                      data_root=None,
                      data_root_val=None,
                      dataset_type='pascal',
                      data_augmentation_chain=None,
                      annotations_path=None,
                      val_annotations_path=None,
                      classes_path=None,
                      center_sampling_radius=1.,
                      anchor_param=None,
                      val_data_augmentation_chain=None):
    """
    Create generators for training and validation.

    Args
        input_size: Input image size; forwarded to every generator and to the
            default augmentation chains.
        batch_size: Batch size forwarded to every generator.
        data_root: Root of the training data. Passed positionally for
            'pascal', 'widerperson' and 'coco', as ``base_dir`` for 'csv' and
            as ``base_dirs`` for 'csvs'.
        data_root_val: Root of the validation data; defaults to ``data_root``.
        dataset_type: One of 'pascal', 'widerperson', 'csv', 'csvs', 'coco'.
        data_augmentation_chain: Transformations for the training generator;
            defaults to ``train_aug(input_size=input_size)``.
        annotations_path: Training annotations / split. Defaults to
            'trainval' (pascal), 'train' (widerperson), 'train2017' (coco).
            Required for 'csv' and 'csvs'.
        val_annotations_path: Validation annotations / split. Defaults to
            'test' (pascal), 'val' (widerperson), 'val2017' (coco). For
            'csv' / 'csvs' a falsy value means "no validation generator".
        classes_path: Class definition passed to ``CSVGenerator`` for
            'csv' / 'csvs'. Required for 'csvs'. For 'csv' it falls back (with
            a warning) to the identity mapping ``{0: 0, ..., 199: 199}``.
        center_sampling_radius: Forwarded to every generator.
        anchor_param: Forwarded to every generator.
        val_data_augmentation_chain: Transformations for the validation
            generator; defaults to ``val_aug(input_size=input_size)``.

    Returns
        Tuple ``(train_generator, validation_generator)``. For 'csv' / 'csvs'
        the validation generator is ``None`` when ``val_annotations_path`` is
        falsy.

    Raises
        ValueError: If ``dataset_type`` is unknown, or a required 'csv' /
            'csvs' argument is missing.
    """
    # Validate everything up front, before building augmentation pipelines or
    # importing generator modules. Explicit raises (not ``assert``) so the
    # checks survive ``python -O``.
    supported_types = ('pascal', 'widerperson', 'csv', 'csvs', 'coco')
    if dataset_type not in supported_types:
        raise ValueError('Invalid data type received: {}'.format(dataset_type))
    if dataset_type in ('csv', 'csvs') and annotations_path is None:
        raise ValueError(
            "annotations_path is required for dataset_type '{}'".format(dataset_type))
    if dataset_type == 'csvs' and classes_path is None:
        raise ValueError("classes_path is required for dataset_type 'csvs'")

    # Keyword arguments shared by every generator regardless of dataset type.
    common_args = {
        'batch_size': batch_size,
        'input_size': input_size,
        'center_sampling_radius': center_sampling_radius,
        'anchor_param': anchor_param
    }

    if data_augmentation_chain is None:
        data_augmentation_chain = train_aug(input_size=input_size)
    if val_data_augmentation_chain is None:
        val_data_augmentation_chain = val_aug(input_size=input_size)

    if data_root_val is None:
        data_root_val = data_root

    # Generator modules are imported lazily so only the selected backend
    # needs to be importable.
    if dataset_type == 'pascal':
        from generators.pascal import PascalVocGenerator

        if annotations_path is None:
            annotations_path = 'trainval'
        if val_annotations_path is None:
            val_annotations_path = 'test'

        train_generator = PascalVocGenerator(
            data_root,
            annotations_path,
            skip_difficult=True,
            transformations=data_augmentation_chain,
            **common_args
        )
        validation_generator = PascalVocGenerator(
            data_root_val,
            val_annotations_path,
            skip_difficult=True,
            shuffle_groups=False,
            transformations=val_data_augmentation_chain,
            **common_args
        )

    elif dataset_type == 'widerperson':
        from generators.widerperson import PersonGenerator

        if annotations_path is None:
            annotations_path = 'train'
        if val_annotations_path is None:
            val_annotations_path = 'val'

        train_generator = PersonGenerator(
            data_root,
            annotations_path,
            transformations=data_augmentation_chain,
            **common_args
        )
        validation_generator = PersonGenerator(
            data_root_val,
            val_annotations_path,
            shuffle_groups=False,
            transformations=val_data_augmentation_chain,
            **common_args
        )

    elif dataset_type == 'csv':
        # NOTE(review): class file presumably holds `class_name,class_id`
        # rows (per the original comment) -- confirm in generators.csv_.
        from generators.csv_ import CSVGenerator

        if classes_path is None:
            import warnings
            warnings.warn('no class file provided, use {int:int}', stacklevel=2)
            # Fallback: identity mapping for class ids 0..199.
            classes_path = {class_id: class_id for class_id in range(200)}

        train_generator = CSVGenerator(
            annotations_path,
            classes_path,
            base_dir=data_root,
            transformations=data_augmentation_chain,
            **common_args
        )
        if val_annotations_path:
            validation_generator = CSVGenerator(
                val_annotations_path,
                classes_path,
                base_dir=data_root_val,
                shuffle_groups=False,
                transformations=val_data_augmentation_chain,
                **common_args
            )
        else:
            validation_generator = None

    elif dataset_type == 'csvs':
        # Multi-root CSV variant: roots are passed as ``base_dirs``.
        from generators.csv_multi import CSVGenerator

        train_generator = CSVGenerator(
            annotations_path,
            classes_path,
            base_dirs=data_root,
            transformations=data_augmentation_chain,
            **common_args
        )
        if val_annotations_path:
            validation_generator = CSVGenerator(
                val_annotations_path,
                classes_path,
                base_dirs=data_root_val,
                shuffle_groups=False,
                transformations=val_data_augmentation_chain,
                **common_args
            )
        else:
            validation_generator = None

    else:  # dataset_type == 'coco' (validated above)
        from generators.coco import CocoGenerator

        if annotations_path is None:
            annotations_path = 'train2017'
        if val_annotations_path is None:
            val_annotations_path = 'val2017'

        train_generator = CocoGenerator(
            data_root,
            annotations_path,
            transformations=data_augmentation_chain,
            **common_args
        )
        validation_generator = CocoGenerator(
            data_root_val,
            val_annotations_path,
            shuffle_groups=False,
            transformations=val_data_augmentation_chain,
            **common_args
        )

    return train_generator, validation_generator
|