
from augmentor import train_aug, val_aug
def create_generators(input_size=512,
                      batch_size=16,
                      data_root=None,
                      data_root_val=None,
                      dataset_type='pascal',
                      data_augmentation_chain=None,
                      annotations_path=None,
                      val_annotations_path=None,
                      classes_path=None,
                      center_sampling_radius=1.,
                      anchor_param=None):
    """
    Create the train and validation data generators for a dataset type.

    Args:
        input_size: Network input size in pixels; also sizes the default
            augmentation pipelines.
        batch_size: Number of images per batch, forwarded to every generator.
        data_root: Root directory of the training data.
        data_root_val: Root directory of the validation data; falls back to
            ``data_root`` when None.
        dataset_type: One of 'pascal', 'widerperson', 'csv', 'csvs', 'coco'.
        data_augmentation_chain: Training-time transformations; defaults to
            ``train_aug(input_size=input_size)`` when None.
        annotations_path: Training split name or annotations file (its exact
            meaning depends on ``dataset_type``). Per-dataset defaults are
            filled in when None; required for 'csv' and 'csvs'.
        val_annotations_path: Validation counterpart of ``annotations_path``.
        classes_path: Class-name -> class-id mapping used by the CSV
            generators; required for 'csvs'.
        center_sampling_radius: Forwarded to every generator.
        anchor_param: Forwarded to every generator.

    Returns:
        A ``(train_generator, validation_generator)`` tuple.
        ``validation_generator`` is None for 'csv'/'csvs' when no
        ``val_annotations_path`` is supplied.

    Raises:
        ValueError: If ``dataset_type`` is unknown, or if an annotations /
            classes argument required by the chosen dataset type is missing.
    """
    # Fail fast on an unknown dataset type, before any augmentation
    # pipelines are built or generator modules are imported.
    if dataset_type not in ('pascal', 'widerperson', 'csv', 'csvs', 'coco'):
        raise ValueError('Invalid data type received: {}'.format(dataset_type))

    # Keyword arguments shared by every generator class.
    common_args = {
        'batch_size': batch_size,
        'input_size': input_size,
        'center_sampling_radius': center_sampling_radius,
        'anchor_param': anchor_param,
    }

    if data_augmentation_chain is None:
        data_augmentation_chain = train_aug(input_size=input_size)
    # Validation always uses the deterministic (no-augmentation) pipeline.
    val_transformations = val_aug(input_size=input_size)
    if data_root_val is None:
        data_root_val = data_root

    if dataset_type == 'pascal':
        if annotations_path is None:
            annotations_path = 'trainval'
        if val_annotations_path is None:
            val_annotations_path = 'test'
        # Imported lazily so other dataset types do not pay for this import.
        from generators.pascal import PascalVocGenerator
        train_generator = PascalVocGenerator(
            data_root,
            annotations_path,
            skip_difficult=True,
            transformations=data_augmentation_chain,
            **common_args
        )
        validation_generator = PascalVocGenerator(
            data_root_val,
            val_annotations_path,
            skip_difficult=True,
            shuffle_groups=False,
            transformations=val_transformations,
            **common_args
        )
    elif dataset_type == 'widerperson':
        if annotations_path is None:
            annotations_path = 'train'
        if val_annotations_path is None:
            val_annotations_path = 'val'
        from generators.widerperson import PersonGenerator
        train_generator = PersonGenerator(
            data_root,
            annotations_path,
            transformations=data_augmentation_chain,
            **common_args
        )
        validation_generator = PersonGenerator(
            data_root_val,
            val_annotations_path,
            shuffle_groups=False,
            transformations=val_transformations,
            **common_args
        )
    elif dataset_type == 'csv':
        from generators.csv_ import CSVGenerator
        # Raise explicitly instead of `assert`: asserts are stripped
        # under `python -O` and would let a None path through.
        if annotations_path is None:
            raise ValueError(
                "annotations_path is required for dataset_type 'csv'")
        if classes_path is None:
            import warnings
            warnings.warn(
                'no class file provided, use {int:int}')
            # Identity fallback mapping {0: 0, ..., 199: 199}.
            # NOTE(review): keys are ints although the generator's class map
            # is documented above as class_name -> class_id; presumably the
            # CSV already uses integer class ids -- confirm against
            # generators.csv_.CSVGenerator.
            classes_path = {i: i for i in range(200)}
        train_generator = CSVGenerator(
            annotations_path,
            classes_path,
            base_dir=data_root,
            transformations=data_augmentation_chain,
            **common_args
        )
        if val_annotations_path:
            validation_generator = CSVGenerator(
                val_annotations_path,
                classes_path,
                base_dir=data_root_val,
                shuffle_groups=False,
                transformations=val_transformations,
                **common_args
            )
        else:
            validation_generator = None
    elif dataset_type == 'csvs':
        # Multi-root variant of the CSV generator (base_dirs, plural).
        from generators.csv_multi import CSVGenerator
        if annotations_path is None:
            raise ValueError(
                "annotations_path is required for dataset_type 'csvs'")
        if classes_path is None:
            raise ValueError(
                "classes_path is required for dataset_type 'csvs'")
        train_generator = CSVGenerator(
            annotations_path,
            classes_path,
            base_dirs=data_root,
            transformations=data_augmentation_chain,
            **common_args
        )
        if val_annotations_path:
            validation_generator = CSVGenerator(
                val_annotations_path,
                classes_path,
                base_dirs=data_root_val,
                shuffle_groups=False,
                transformations=val_transformations,
                **common_args
            )
        else:
            validation_generator = None
    else:  # 'coco' -- guaranteed by the guard at the top.
        from generators.coco import CocoGenerator
        if annotations_path is None:
            annotations_path = 'train2017'
        if val_annotations_path is None:
            val_annotations_path = 'val2017'
        train_generator = CocoGenerator(
            data_root,
            annotations_path,
            transformations=data_augmentation_chain,
            **common_args
        )
        validation_generator = CocoGenerator(
            data_root_val,
            val_annotations_path,
            shuffle_groups=False,
            transformations=val_transformations,
            **common_args
        )
    return train_generator, validation_generator