117 lines
3.9 KiB
Python
117 lines
3.9 KiB
Python
from keras.callbacks import *
|
|
import keras
|
|
import os
|
|
|
|
class RedirectModel(keras.callbacks.Callback):
    """
    Callback which wraps another callback, but executed on a different model.

    Every training event received by this wrapper is forwarded to the wrapped
    callback. At the start of training the wrapped callback's model is
    replaced by `model`, so it operates on that model instead of the one
    being fitted (e.g. saving the template model behind a multi-GPU model).

    ```python
    model = keras.models.load_model('model.h5')
    model_checkpoint = ModelCheckpoint(filepath='snapshot.h5')
    parallel_model = multi_gpu_model(model, gpus=2)
    parallel_model.fit(X_train, Y_train, callbacks=[RedirectModel(model_checkpoint, model)])
    ```

    Args
        callback : callback to wrap.
        model : model to use when executing callbacks.
    """

    def __init__(self,
                 callback,
                 model):
        super(RedirectModel, self).__init__()

        self.callback = callback
        self.redirect_model = model

    def set_params(self, params):
        # Keras only hands the training parameters to the callbacks it knows
        # about (this wrapper); forward them so the wrapped callback can read
        # `self.params` just as it would if it were registered directly.
        super(RedirectModel, self).set_params(params)
        self.callback.set_params(params)

    def on_epoch_begin(self, epoch, logs=None):
        self.callback.on_epoch_begin(epoch, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        self.callback.on_epoch_end(epoch, logs=logs)

    def on_batch_begin(self, batch, logs=None):
        self.callback.on_batch_begin(batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        self.callback.on_batch_end(batch, logs=logs)

    def on_train_begin(self, logs=None):
        # overwrite the model with our custom model
        self.callback.set_model(self.redirect_model)

        self.callback.on_train_begin(logs=logs)

    def on_train_end(self, logs=None):
        self.callback.on_train_end(logs=logs)
|
|
|
|
|
|
def scheduler(epoch, lr):
    """
    Two-step learning-rate schedule.

    Args
        epoch: Epoch index (presumably supplied by keras' LearningRateScheduler).
        lr: Current learning rate; ignored, the schedule is absolute.

    Returns
        1e-5 while epoch < 40, otherwise 1e-6.
    """
    return 1e-5 if epoch < 40 else 1e-6
|
|
|
|
|
|
def create_callbacks(training_model, prediction_model, validation_generator, snapshot_path, backbone, fpn, n_stage,
                     evaluation=True, dataset_type='voc', snapshots=True):
    """
    Creates the callbacks to use during training.

    Args
        training_model: The model that is used for training; snapshots are saved from this model.
        prediction_model: The model that should be used for validation.
        validation_generator: The generator for creating validation data. Evaluation is skipped when it is falsy.
        snapshot_path: Directory where snapshots are written; created if it does not exist.
        backbone: Backbone name, embedded in the snapshot filename.
        fpn: FPN identifier, embedded in the snapshot filename.
        n_stage: Stage count, embedded in the snapshot filename.
        evaluation: Whether to add an evaluation callback (only when validation_generator is given).
        dataset_type: 'coco' selects eval.coco.CocoEval; any other value selects eval.pascal.Evaluate.
            Also embedded in the snapshot filename.
        snapshots: Whether to add a model-checkpoint callback.

    Returns:
        A list of callbacks used for training.
    """
    callbacks = []

    # Keras runs callbacks in list order. The evaluation callback therefore goes
    # first, so that the checkpoint / early-stopping callbacks below can monitor
    # the 'mAP' value it is expected to put into the epoch logs.
    has_evaluation = bool(evaluation and validation_generator)

    if has_evaluation:
        if dataset_type == 'coco':
            from eval.coco import CocoEval
            # use prediction model for evaluation
            evaluator = CocoEval(validation_generator, prediction_model)
        else:
            from eval.pascal import Evaluate
            evaluator = Evaluate(validation_generator, prediction_model)
        callbacks.append(evaluator)

    # save the model
    if snapshots:
        # ensure directory created first; otherwise h5py will error after epoch.
        if not os.path.exists(snapshot_path):
            os.makedirs(snapshot_path)
        # The doubled braces survive .format() so Keras can fill in the epoch number.
        filepath = os.path.join(
            snapshot_path,
            '{dataset_type}_{backbone}_{fpn}_{n_stage}_{{epoch:02d}}.h5'.format(dataset_type=dataset_type,
                                                                               fpn=fpn,
                                                                               n_stage=n_stage,
                                                                               backbone=backbone)
        )
        if has_evaluation:
            # Keep only snapshots that improve the evaluation mAP.
            checkpoint = ModelCheckpoint(
                filepath,
                verbose=1,
                save_best_only=True,
                monitor="mAP",
                mode='max'
            )
        else:
            # Without an evaluation callback 'mAP' never appears in the logs, and a
            # save_best_only checkpoint would skip every epoch; save every epoch instead.
            checkpoint = ModelCheckpoint(filepath, verbose=1)
        callbacks.append(RedirectModel(checkpoint, training_model))

    # Early stopping on mAP is only meaningful when something reports mAP.
    if has_evaluation:
        early_stopping = EarlyStopping(
            monitor='mAP',
            min_delta=0.0,
            mode='max',
            patience=5,
            verbose=1,
            restore_best_weights=True)
        callbacks.append(early_stopping)

    return callbacks
|