# Copyright (c) OpenMMLab. All rights reserved.
#
# MMSegmentation training config: STDC1 (STDCContextPathNet backbone) with
# FCN decode head + OHEM sampling, adapted to a 4-class segmentation task.

# ---------------- Model settings ---------------- #
norm_cfg = dict(type='BN', requires_grad=True)

model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='STDCContextPathNet',
        backbone_cfg=dict(
            type='STDCNet',
            stdc_type='STDCNet1',
            in_channels=3,
            channels=(32, 64, 256, 512, 1024),
            bottleneck_type='cat',
            num_convs=4,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'),
            with_final_conv=False,
            # ImageNet-pretrained STDC1 weights from the OpenMMLab model zoo.
            init_cfg=dict(
                type='Pretrained',
                checkpoint='https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/stdc/stdc1_20220308-5368626c.pth')),
        last_in_channels=(1024, 512),
        out_channels=128,
        ffm_cfg=dict(in_channels=384, out_channels=256, scale_factor=4)),
    decode_head=dict(
        type='FCNHead',
        in_channels=256,
        channels=256,
        num_convs=1,
        num_classes=4,  # 4 classes
        in_index=3,
        concat_input=False,
        dropout_ratio=0.1,
        norm_cfg=norm_cfg,
        align_corners=True,
        # OHEM: keep at least 10k pixels, mine those below 0.7 confidence.
        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    # Three auxiliary heads: two FCN heads on intermediate features plus the
    # STDC detail/boundary head on the stem features.
    auxiliary_head=[
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=4,
            in_index=2,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='FCNHead',
            in_channels=128,
            channels=64,
            num_convs=1,
            num_classes=4,
            in_index=1,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=False,
            sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000),
            loss_decode=dict(
                type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
        dict(
            type='STDCHead',
            in_channels=256,
            channels=64,
            num_convs=1,
            num_classes=4,  # most important head for the class count
            boundary_threshold=0.1,
            in_index=0,
            norm_cfg=norm_cfg,
            concat_input=False,
            align_corners=True,
            # Boundary supervision: binary CE (sigmoid) + Dice.
            loss_decode=[
                dict(
                    type='CrossEntropyLoss',
                    loss_name='loss_ce',
                    use_sigmoid=True,
                    loss_weight=1.0),
                dict(
                    type='DiceLoss',
                    loss_name='loss_dice',
                    loss_weight=1.0)])],
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

# ---------------- Dataset settings ---------------- #
# NOTE(review): dataset type is 'GolfDataset' but data_root and the
# img_dir/ann_dir layout below follow the Cityscapes convention — confirm
# the custom dataset really lives under data/cityscapes/.
dataset_type = 'GolfDataset'
data_root = 'data/cityscapes/'
# NOTE(review): mean 128 / std 256 is nonstandard (typical ImageNet stats are
# mean≈[123.675, 116.28, 103.53], std≈[58.395, 57.12, 57.375]) — presumably
# deliberate; verify against the checkpoint's preprocessing.
img_norm_cfg = dict(
    mean=[128., 128., 128.], std=[256., 256., 256.], to_rgb=True)
crop_size = (512, 1024)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    # 255 is the ignore index for padded label pixels.
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(1024, 512),
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]

data = dict(
    samples_per_gpu=2,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/train',
        ann_dir='gtFine/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/val',
        ann_dir='gtFine/val',
        pipeline=test_pipeline),
    test=dict(
        type=dataset_type,
        data_root=data_root,
        img_dir='leftImg8bit/test',
        ann_dir='gtFine/test',
        pipeline=test_pipeline))

# ---------------- Extra settings ---------------- #
log_config = dict(
    interval=50,
    hooks=[dict(type='TextLoggerHook', by_epoch=False)])
checkpoint_config = dict(by_epoch=False, interval=1000)
evaluation = dict(interval=2000, metric='mIoU', pre_eval=True)

optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optimizer_config = dict()
# Poly LR decay with a 1000-iter linear warmup.
lr_config = dict(
    policy='poly',
    power=0.9,
    min_lr=0.0001,
    by_epoch=False,
    warmup='linear',
    warmup_iters=1000)

runner = dict(type='IterBasedRunner', max_iters=20000)
cudnn_benchmark = True
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
work_dir = './work_dirs/kn_stdc1_golf4class'
gpu_ids = [0]

# Optional: for visualization / post-processing only; not passed to the dataset.
classes = ('car', 'grass', 'people', 'road')
palette = [
    [246, 14, 135],   # car
    [233, 81, 78],    # grass
    [220, 148, 21],   # people
    [207, 215, 220],  # road
]