import torch def check_pth_num_classes(pth_path): checkpoint = torch.load(pth_path, map_location='cpu') if 'state_dict' not in checkpoint: print("❌ 找不到 state_dict,這可能不是 MMSegmentation 的模型檔") return state_dict = checkpoint['state_dict'] # 找出 decode head 最後一層分類器的 weight tensor num_classes = None for k in state_dict.keys(): if 'decode_head' in k and 'weight' in k and 'decode_head.classifier' in k: weight_tensor = state_dict[k] num_classes = weight_tensor.shape[0] print(f"✅ 檢查到類別數: {num_classes}") break if num_classes is None: print("⚠️ 無法判斷類別數,可能模型架構非標準格式") else: if num_classes == 19: print("⚠️ 這是 Cityscapes 預設模型 (19 類)") elif num_classes == 4: print("✅ 這是 GolfDataset 自訂模型 (4 類)") else: print("❓ 類別數異常,請確認訓練資料與 config 設定是否一致") if __name__ == '__main__': pth_path = r'C:\Users\rd_de\kneronstdc\work_dirs\meconfig\latest.pth' check_pth_num_classes(pth_path)