Add runner type (#118)
* Add runner_type option * pre-commit * Fix max_iters * Add by_epoch to EvalHook * Add test_eval_hook for epoch runner * Remove runner-type arg from tools/train * Add missing every_n_iters check for epoch mode * Bump mmcv min version * Use build_runner * Use interval in tests * Update test_eval_hook.py * Use every_n_epochs instead of every_n_iters. Update DistEvalHook * Add test_dist_eval_hook_epoch * Fix tests * Add DeprecationWarning * Update docs * Replace DeprecationWarning with UserWarning
This commit is contained in:
parent
3bdc276888
commit
e384ef578a
@ -4,6 +4,6 @@ optimizer_config = dict()
|
|||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
||||||
# runtime settings
|
# runtime settings
|
||||||
total_iters = 160000
|
runner = dict(type='IterBasedRunner', max_iters=160000)
|
||||||
checkpoint_config = dict(by_epoch=False, interval=16000)
|
checkpoint_config = dict(by_epoch=False, interval=16000)
|
||||||
evaluation = dict(interval=16000, metric='mIoU')
|
evaluation = dict(interval=16000, metric='mIoU')
|
||||||
|
|||||||
@ -4,6 +4,6 @@ optimizer_config = dict()
|
|||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
||||||
# runtime settings
|
# runtime settings
|
||||||
total_iters = 20000
|
runner = dict(type='IterBasedRunner', max_iters=20000)
|
||||||
checkpoint_config = dict(by_epoch=False, interval=2000)
|
checkpoint_config = dict(by_epoch=False, interval=2000)
|
||||||
evaluation = dict(interval=2000, metric='mIoU')
|
evaluation = dict(interval=2000, metric='mIoU')
|
||||||
|
|||||||
@ -4,6 +4,6 @@ optimizer_config = dict()
|
|||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
||||||
# runtime settings
|
# runtime settings
|
||||||
total_iters = 40000
|
runner = dict(type='IterBasedRunner', max_iters=40000)
|
||||||
checkpoint_config = dict(by_epoch=False, interval=4000)
|
checkpoint_config = dict(by_epoch=False, interval=4000)
|
||||||
evaluation = dict(interval=4000, metric='mIoU')
|
evaluation = dict(interval=4000, metric='mIoU')
|
||||||
|
|||||||
@ -4,6 +4,6 @@ optimizer_config = dict()
|
|||||||
# learning policy
|
# learning policy
|
||||||
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
|
||||||
# runtime settings
|
# runtime settings
|
||||||
total_iters = 80000
|
runner = dict(type='IterBasedRunner', max_iters=80000)
|
||||||
checkpoint_config = dict(by_epoch=False, interval=8000)
|
checkpoint_config = dict(by_epoch=False, interval=8000)
|
||||||
evaluation = dict(interval=8000, metric='mIoU')
|
evaluation = dict(interval=8000, metric='mIoU')
|
||||||
|
|||||||
@ -226,7 +226,7 @@ dist_params = dict(backend='nccl') # Parameters to setup distributed training,
|
|||||||
log_level = 'INFO' # The level of logging.
|
log_level = 'INFO' # The level of logging.
|
||||||
load_from = None # load models as a pre-trained model from a given path. This will not resume training.
|
load_from = None # load models as a pre-trained model from a given path. This will not resume training.
|
||||||
resume_from = None # Resume checkpoints from a given path, the training will be resumed from the iteration at which the checkpoint was saved.
|
resume_from = None # Resume checkpoints from a given path, the training will be resumed from the iteration at which the checkpoint was saved.
|
||||||
workflow = [('train', 1)] # Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. The workflow trains the model by 40000 iterations according to the total_iters.
|
workflow = [('train', 1)] # Workflow for runner. [('train', 1)] means there is only one workflow and the workflow named 'train' is executed once. The workflow trains the model by 40000 iterations according to the `runner.max_iters`.
|
||||||
cudnn_benchmark = True # Whether use cudnn_benchmark to speed up, which is fast for fixed input size.
|
cudnn_benchmark = True # Whether use cudnn_benchmark to speed up, which is fast for fixed input size.
|
||||||
optimizer = dict( # Config used to build optimizer, support all the optimizers in PyTorch whose arguments are also the same as those in PyTorch
|
optimizer = dict( # Config used to build optimizer, support all the optimizers in PyTorch whose arguments are also the same as those in PyTorch
|
||||||
type='SGD', # Type of optimizers, refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/optimizer/default_constructor.py#L13 for more details
|
type='SGD', # Type of optimizers, refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/optimizer/default_constructor.py#L13 for more details
|
||||||
@ -239,7 +239,9 @@ lr_config = dict(
|
|||||||
power=0.9, # The power of polynomial decay.
|
power=0.9, # The power of polynomial decay.
|
||||||
min_lr=0.0001, # The minimum learning rate to stabilize the training.
|
min_lr=0.0001, # The minimum learning rate to stabilize the training.
|
||||||
by_epoch=False) # Whether to count by epoch or not.
|
by_epoch=False) # Whether to count by epoch or not.
|
||||||
total_iters = 40000 # Total number of iterations.
|
runner = dict(
|
||||||
|
type='IterBasedRunner', # Type of runner to use (i.e. IterBasedRunner or EpochBasedRunner)
|
||||||
|
max_iters=40000) # Total number of iterations. For EpochBasedRunner use `max_epochs`
|
||||||
checkpoint_config = dict( # Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
|
checkpoint_config = dict( # Config to set the checkpoint hook, Refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/checkpoint.py for implementation.
|
||||||
by_epoch=False, # Whether to count by epoch or not.
|
by_epoch=False, # Whether to count by epoch or not.
|
||||||
interval=4000) # The save interval.
|
interval=4000) # The save interval.
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import mmcv
|
|||||||
|
|
||||||
from .version import __version__, version_info
|
from .version import __version__, version_info
|
||||||
|
|
||||||
MMCV_MIN = '1.1.2'
|
MMCV_MIN = '1.1.4'
|
||||||
MMCV_MAX = '1.2.0'
|
MMCV_MAX = '1.2.0'
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
import random
|
import random
|
||||||
|
import warnings
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
|
||||||
from mmcv.runner import IterBasedRunner, build_optimizer
|
from mmcv.runner import build_optimizer, build_runner
|
||||||
|
|
||||||
from mmseg.core import DistEvalHook, EvalHook
|
from mmseg.core import DistEvalHook, EvalHook
|
||||||
from mmseg.datasets import build_dataloader, build_dataset
|
from mmseg.datasets import build_dataloader, build_dataset
|
||||||
@ -70,13 +71,21 @@ def train_segmentor(model,
|
|||||||
# build runner
|
# build runner
|
||||||
optimizer = build_optimizer(model, cfg.optimizer)
|
optimizer = build_optimizer(model, cfg.optimizer)
|
||||||
|
|
||||||
runner = IterBasedRunner(
|
if cfg.get('runner') is None:
|
||||||
model=model,
|
cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
|
||||||
batch_processor=None,
|
warnings.warn(
|
||||||
optimizer=optimizer,
|
'config is now expected to have a `runner` section, '
|
||||||
work_dir=cfg.work_dir,
|
'please set `runner` in your config.', UserWarning)
|
||||||
logger=logger,
|
|
||||||
meta=meta)
|
runner = build_runner(
|
||||||
|
cfg.runner,
|
||||||
|
default_args=dict(
|
||||||
|
model=model,
|
||||||
|
batch_processor=None,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=cfg.work_dir,
|
||||||
|
logger=logger,
|
||||||
|
meta=meta))
|
||||||
|
|
||||||
# register hooks
|
# register hooks
|
||||||
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
|
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
|
||||||
@ -96,6 +105,7 @@ def train_segmentor(model,
|
|||||||
dist=distributed,
|
dist=distributed,
|
||||||
shuffle=False)
|
shuffle=False)
|
||||||
eval_cfg = cfg.get('evaluation', {})
|
eval_cfg = cfg.get('evaluation', {})
|
||||||
|
eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
|
||||||
eval_hook = DistEvalHook if distributed else EvalHook
|
eval_hook = DistEvalHook if distributed else EvalHook
|
||||||
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
|
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
|
||||||
|
|
||||||
@ -103,4 +113,4 @@ def train_segmentor(model,
|
|||||||
runner.resume(cfg.resume_from)
|
runner.resume(cfg.resume_from)
|
||||||
elif cfg.load_from:
|
elif cfg.load_from:
|
||||||
runner.load_checkpoint(cfg.load_from)
|
runner.load_checkpoint(cfg.load_from)
|
||||||
runner.run(data_loaders, cfg.workflow, cfg.total_iters)
|
runner.run(data_loaders, cfg.workflow)
|
||||||
|
|||||||
@ -12,17 +12,27 @@ class EvalHook(Hook):
|
|||||||
interval (int): Evaluation interval (by epochs). Default: 1.
|
interval (int): Evaluation interval (by epochs). Default: 1.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dataloader, interval=1, **eval_kwargs):
|
def __init__(self, dataloader, interval=1, by_epoch=False, **eval_kwargs):
|
||||||
if not isinstance(dataloader, DataLoader):
|
if not isinstance(dataloader, DataLoader):
|
||||||
raise TypeError('dataloader must be a pytorch DataLoader, but got '
|
raise TypeError('dataloader must be a pytorch DataLoader, but got '
|
||||||
f'{type(dataloader)}')
|
f'{type(dataloader)}')
|
||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
|
self.by_epoch = by_epoch
|
||||||
self.eval_kwargs = eval_kwargs
|
self.eval_kwargs = eval_kwargs
|
||||||
|
|
||||||
def after_train_iter(self, runner):
|
def after_train_iter(self, runner):
|
||||||
"""After train iter hook."""
|
"""After train iter hook."""
|
||||||
if not self.every_n_iters(runner, self.interval):
|
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||||
|
return
|
||||||
|
from mmseg.apis import single_gpu_test
|
||||||
|
runner.log_buffer.clear()
|
||||||
|
results = single_gpu_test(runner.model, self.dataloader, show=False)
|
||||||
|
self.evaluate(runner, results)
|
||||||
|
|
||||||
|
def after_train_epoch(self, runner):
|
||||||
|
"""After train epoch hook."""
|
||||||
|
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||||
return
|
return
|
||||||
from mmseg.apis import single_gpu_test
|
from mmseg.apis import single_gpu_test
|
||||||
runner.log_buffer.clear()
|
runner.log_buffer.clear()
|
||||||
@ -54,6 +64,7 @@ class DistEvalHook(EvalHook):
|
|||||||
dataloader,
|
dataloader,
|
||||||
interval=1,
|
interval=1,
|
||||||
gpu_collect=False,
|
gpu_collect=False,
|
||||||
|
by_epoch=False,
|
||||||
**eval_kwargs):
|
**eval_kwargs):
|
||||||
if not isinstance(dataloader, DataLoader):
|
if not isinstance(dataloader, DataLoader):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -62,11 +73,27 @@ class DistEvalHook(EvalHook):
|
|||||||
self.dataloader = dataloader
|
self.dataloader = dataloader
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.gpu_collect = gpu_collect
|
self.gpu_collect = gpu_collect
|
||||||
|
self.by_epoch = by_epoch
|
||||||
self.eval_kwargs = eval_kwargs
|
self.eval_kwargs = eval_kwargs
|
||||||
|
|
||||||
def after_train_iter(self, runner):
|
def after_train_iter(self, runner):
|
||||||
"""After train iter hook."""
|
"""After train iter hook."""
|
||||||
if not self.every_n_iters(runner, self.interval):
|
if self.by_epoch or not self.every_n_iters(runner, self.interval):
|
||||||
|
return
|
||||||
|
from mmseg.apis import multi_gpu_test
|
||||||
|
runner.log_buffer.clear()
|
||||||
|
results = multi_gpu_test(
|
||||||
|
runner.model,
|
||||||
|
self.dataloader,
|
||||||
|
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
|
||||||
|
gpu_collect=self.gpu_collect)
|
||||||
|
if runner.rank == 0:
|
||||||
|
print('\n')
|
||||||
|
self.evaluate(runner, results)
|
||||||
|
|
||||||
|
def after_train_epoch(self, runner):
|
||||||
|
"""After train epoch hook."""
|
||||||
|
if not self.by_epoch or not self.every_n_epochs(runner, self.interval):
|
||||||
return
|
return
|
||||||
from mmseg.apis import multi_gpu_test
|
from mmseg.apis import multi_gpu_test
|
||||||
runner.log_buffer.clear()
|
runner.log_buffer.clear()
|
||||||
|
|||||||
@ -38,7 +38,7 @@ class ExampleModel(nn.Module):
|
|||||||
return dict(loss=loss)
|
return dict(loss=loss)
|
||||||
|
|
||||||
|
|
||||||
def test_eval_hook():
|
def test_iter_eval_hook():
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
test_dataset = ExampleModel()
|
test_dataset = ExampleModel()
|
||||||
data_loader = [
|
data_loader = [
|
||||||
@ -75,6 +75,43 @@ def test_eval_hook():
|
|||||||
logger=runner.logger)
|
logger=runner.logger)
|
||||||
|
|
||||||
|
|
||||||
|
def test_epoch_eval_hook():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_dataset = ExampleModel()
|
||||||
|
data_loader = [
|
||||||
|
DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
sampler=None,
|
||||||
|
num_worker=0,
|
||||||
|
shuffle=False)
|
||||||
|
]
|
||||||
|
EvalHook(data_loader, by_epoch=True)
|
||||||
|
|
||||||
|
test_dataset = ExampleDataset()
|
||||||
|
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||||
|
loader = DataLoader(test_dataset, batch_size=1)
|
||||||
|
model = ExampleModel()
|
||||||
|
data_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||||
|
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||||
|
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||||
|
dict(params=model.parameters()))
|
||||||
|
|
||||||
|
# test EvalHook with interval
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
eval_hook = EvalHook(data_loader, by_epoch=True, interval=2)
|
||||||
|
runner = mmcv.runner.EpochBasedRunner(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=tmpdir,
|
||||||
|
logger=logging.getLogger())
|
||||||
|
runner.register_hook(eval_hook)
|
||||||
|
runner.run([loader], [('train', 1)], 2)
|
||||||
|
test_dataset.evaluate.assert_called_once_with([torch.tensor([1])],
|
||||||
|
logger=runner.logger)
|
||||||
|
|
||||||
|
|
||||||
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
|
def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
|
||||||
results = single_gpu_test(model, data_loader)
|
results = single_gpu_test(model, data_loader)
|
||||||
return results
|
return results
|
||||||
@ -116,3 +153,41 @@ def test_dist_eval_hook():
|
|||||||
runner.run([loader], [('train', 1)], 1)
|
runner.run([loader], [('train', 1)], 1)
|
||||||
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
|
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
|
||||||
logger=runner.logger)
|
logger=runner.logger)
|
||||||
|
|
||||||
|
|
||||||
|
@patch('mmseg.apis.multi_gpu_test', multi_gpu_test)
|
||||||
|
def test_dist_eval_hook_epoch():
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
test_dataset = ExampleModel()
|
||||||
|
data_loader = [
|
||||||
|
DataLoader(
|
||||||
|
test_dataset,
|
||||||
|
batch_size=1,
|
||||||
|
sampler=None,
|
||||||
|
num_worker=0,
|
||||||
|
shuffle=False)
|
||||||
|
]
|
||||||
|
DistEvalHook(data_loader)
|
||||||
|
|
||||||
|
test_dataset = ExampleDataset()
|
||||||
|
test_dataset.evaluate = MagicMock(return_value=dict(test='success'))
|
||||||
|
loader = DataLoader(test_dataset, batch_size=1)
|
||||||
|
model = ExampleModel()
|
||||||
|
data_loader = DataLoader(
|
||||||
|
test_dataset, batch_size=1, sampler=None, num_workers=0, shuffle=False)
|
||||||
|
optim_cfg = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||||
|
optimizer = obj_from_dict(optim_cfg, torch.optim,
|
||||||
|
dict(params=model.parameters()))
|
||||||
|
|
||||||
|
# test DistEvalHook
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
eval_hook = DistEvalHook(data_loader, by_epoch=True, interval=2)
|
||||||
|
runner = mmcv.runner.EpochBasedRunner(
|
||||||
|
model=model,
|
||||||
|
optimizer=optimizer,
|
||||||
|
work_dir=tmpdir,
|
||||||
|
logger=logging.getLogger())
|
||||||
|
runner.register_hook(eval_hook)
|
||||||
|
runner.run([loader], [('train', 1)], 2)
|
||||||
|
test_dataset.evaluate.assert_called_with([torch.tensor([1])],
|
||||||
|
logger=runner.logger)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user