234 lines
9.2 KiB
Python
234 lines
9.2 KiB
Python
#!/usr/bin/env python
|
|
|
|
"""
|
|
Copyright 2017-2018 Fizyr (https://fizyr.com)
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import argparse
|
|
import os
|
|
import sys
|
|
import warnings
|
|
import yaml
|
|
|
|
import keras
|
|
import keras.preprocessing.image
|
|
import tensorflow as tf
|
|
from datetime import date
|
|
|
|
import losses
|
|
from models_bd.fcos import fcos
|
|
from callbacks import create_callbacks
|
|
|
|
from generators import create_generators
|
|
from utils.anchors import make_shapes_callback
|
|
from utils.anchors import AnchorParameters
|
|
from utils.config import read_config_file, parse_anchor_parameters
|
|
from utils.keras_version import check_keras_version
|
|
from utils.model import freeze as freeze_model
|
|
from utils.transform import random_transform_generator
|
|
from utils.image import random_visual_effect_generator
|
|
from csv_preprocess import prepare_txt
|
|
|
|
def get_session():
    """
    Build the TensorFlow session used by the training script.

    Returns
        A tf.Session created from a tf.ConfigProto whose
        gpu_options.allow_growth flag is set to True.
    """
    session_config = tf.ConfigProto()
    # Enable allow_growth on the GPU options before the session is created.
    session_config.gpu_options.allow_growth = True
    session = tf.Session(config=session_config)
    return session
|
|
|
|
def makedirs(path):
    """
    Create the directory `path`, including missing parents.

    An already existing directory is accepted silently; any other failure
    is propagated. Written without `exist_ok` so that the same code runs on
    Python 2.7 and Python 3.

    Args
        path: directory path to create.

    Raises
        OSError: if creation fails and `path` is not an existing directory.
    """
    try:
        os.makedirs(path)
    except OSError:
        directory_exists = os.path.isdir(path)
        if directory_exists:
            # The directory was already there: nothing left to do.
            return
        # Genuine failure (e.g. a regular file is in the way): re-raise.
        raise
|
|
|
|
def check_args(parsed_args):
    """ Hook for rejecting contradictory command-line arguments.

    Meant to raise errors before the backend is initialised. No checks are
    implemented at the moment, so the namespace is handed back untouched.

    Args
        parsed_args: parser.parse_args()

    Returns
        parsed_args
    """
    validated_args = parsed_args
    return validated_args
|
|
|
|
|
|
def parse_args(args):
    """
    Parse the command-line arguments of the FCOS training script.

    The resulting namespace is printed (as a dict) and then passed through
    check_args before being returned.

    Args
        args: list of argument strings, e.g. sys.argv[1:].

    Returns
        The argparse namespace returned by check_args.
    """
    # The default snapshot directory is suffixed with the current date.
    today = str(date.today())
    parser = argparse.ArgumentParser(description='Simple training script for training a FCOS network.')

    parser.add_argument('--data', help='Path to the data yaml file. (located under ./data/).')

    parser.add_argument('--snapshot', help='Resume training from a snapshot.')
    parser.add_argument('--backbone', help='Backbone model used by FCOS.', default='resnet50', type=str)
    parser.add_argument('--fpn', help='fpn model', default='simple', type=str)
    parser.add_argument('--reg-func', help='regression func', default='linear', type=str)
    parser.add_argument('--stage', help='num of stage', default=3, type=int)
    parser.add_argument('--head-type', help='head type', default='simple', type=str)
    parser.add_argument('--centerness-pos', help='centerness branch pos', default='reg', type=str)
    parser.add_argument('--batch-size', help='Size of the batches.', default=4, type=int)
    parser.add_argument('--gpu', help='Id of the GPU to use (as reported by nvidia-smi).')
    parser.add_argument('--epochs', help='Number of epochs to train.', type=int, default=100)
    parser.add_argument('--steps', help='Number of steps per epoch.', type=int, default=5000)
    parser.add_argument('--lr', help='Learning rate.', type=float, default=1e-4)
    parser.add_argument('--snapshot-path',
                        help='Path to store snapshots of models during training '
                             '(defaults to \'snapshots/<current date>\')',
                        default='snapshots/{}'.format(today))

    parser.add_argument('--freeze-backbone', help='Freeze training of backbone layers.', action='store_true')

    parser.add_argument('--input-size', help='Rescale the image if the largest side is larger than max_side.',
                        type=int, default=512)
    parser.add_argument('--compute-val-loss', help='Compute validation loss during training', dest='compute_val_loss',
                        action='store_true')

    # Parse once and reuse the namespace for both the echo and the return value.
    parsed_args = parser.parse_args(args)
    print(vars(parsed_args))
    return check_args(parsed_args)
|
|
|
|
|
|
def _anchor_parameters_for_stage(n_stage):
    """Return the AnchorParameters for the requested number of stages (--stage).

    Args
        n_stage: number of stages; only 3 and 5 are supported.

    Returns
        AnchorParameters.default for 5 stages, or an AnchorParameters built
        with strides [8, 16, 32] and three interest-size ranges for 3 stages.

    Raises
        ValueError: if n_stage is neither 3 nor 5.
    """
    if n_stage == 5:
        return AnchorParameters.default
    if n_stage == 3:
        return AnchorParameters(strides=[8, 16, 32],
                                interest_sizes=[
                                    [-1, 64],
                                    [64, 128],
                                    [128, 1e8],
                                ])
    # An explicit exception (instead of `assert 0`) still fires under `python -O`.
    raise ValueError('Unsupported number of stages: {} (expected 3 or 5).'.format(n_stage))


def main(args=None):
    """Train an FCOS model as configured by the command-line arguments.

    Steps: parse the arguments, read the dataset yaml file given by --data,
    set up the Keras/TensorFlow session, create the data generators, build
    and compile the model, create the callbacks and run fit_generator.

    Args
        args: list of argument strings; sys.argv[1:] is used when None.

    Returns
        None. If the yaml file names an unsupported dataset type, a message
        is printed and the function returns before any training starts.

    Raises
        ValueError: if --stage is neither 3 nor 5.
    """
    # parse arguments
    if args is None:
        args = sys.argv[1:]
    args = parse_args(args)

    # get dataset information
    # NOTE(review): yaml.FullLoader can construct more than plain data types;
    # consider yaml.safe_load if the data file may come from an untrusted source.
    with open(args.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # data dict

    args.dataset_type = data_dict.get("dataset_type", 'csv')

    # Default the annotation paths so that dataset types which do not define
    # them (coco) no longer fail with AttributeError when the generators are
    # created. Presumably create_generators ignores them for coco - TODO confirm.
    args.annotations_path = None
    args.val_annotations_path = None

    if args.dataset_type == 'csv':
        args.data_root = data_dict["train"]
        args.data_root_val = data_dict["val"]
        args.classes_path = data_dict["names"]
        args.annotations_path = prepare_txt(data_dict["train"], data_dict["names"], trainset=True)
        args.val_annotations_path = prepare_txt(data_dict["val"], data_dict["names"], trainset=False)
    elif args.dataset_type == 'coco':
        args.data_root = data_dict["data_root"]
    elif args.dataset_type == 'pascal':
        args.data_root = data_dict["data_root"]
        args.annotations_path = data_dict["train"]
        args.val_annotations_path = data_dict["val"]
    else:
        print('Unsupported dataset type.')
        return

    # make sure keras is the minimum required version
    check_keras_version()

    # optionally choose specific GPU
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    keras.backend.tensorflow_backend.set_session(get_session())

    n_stage = args.stage
    anchor_param = _anchor_parameters_for_stage(n_stage)

    # create the generators
    generator_kwargs = dict(
        input_size=args.input_size,
        data_root=args.data_root,
        annotations_path=args.annotations_path,
        val_annotations_path=args.val_annotations_path,
        dataset_type=args.dataset_type,
        batch_size=args.batch_size,
        anchor_param=anchor_param,
    )
    if args.dataset_type == 'csv':
        # Only the csv dataset passes a separate validation root and a classes file.
        generator_kwargs.update(data_root_val=args.data_root_val,
                                classes_path=args.classes_path)
    train_generator, validation_generator = create_generators(**generator_kwargs)

    model, prediction_model, debug_model = fcos(backbone=args.backbone,
                                                num_classes=train_generator.num_classes(),
                                                input_size=args.input_size,
                                                weights=args.snapshot,
                                                freeze_backbone=args.freeze_backbone,
                                                fpn_type=args.fpn,
                                                n_stage=n_stage,
                                                mapping_func=args.reg_func,
                                                head_type=args.head_type,
                                                centerness_pos=args.centerness_pos)

    training_model = model
    training_model.compile(
        loss={
            'regression': losses.giou,
            'classification': losses.focal(),
            'centerness': losses.bce(),
        },
        optimizer=keras.optimizers.Adam(lr=args.lr),
        loss_weights={'regression': 2, 'classification': 1, 'centerness': 0.7}
    )

    # create the callbacks
    # The callbacks receive the validation generator before it is optionally
    # dropped from fit_generator below.
    callbacks = create_callbacks(
        debug_model,
        prediction_model,
        validation_generator,
        args.snapshot_path,
        dataset_type=args.dataset_type,
        backbone=args.backbone,
        fpn=args.fpn,
        n_stage=n_stage
    )

    if not args.compute_val_loss:
        validation_generator = None

    # start training
    training_model.fit_generator(
        generator=train_generator,
        initial_epoch=0,
        steps_per_epoch=args.steps,
        epochs=args.epochs,
        verbose=1,
        callbacks=callbacks,
        validation_data=validation_generator
    )
|
|
|
|
|
|
# Script entry point: run the training with the process's command-line arguments.
if __name__ == '__main__':
    main()
|