"""Inspect a segmentation checkpoint and report how many classes its head predicts."""
import sys
from collections.abc import Mapping
from typing import Optional

import torch

# Parameter holding the final classifier conv of the decode head.
DEFAULT_WEIGHT_KEY = 'decode_head.conv_seg.weight'


def _find_classifier_weight(state_dict, weight_key):
    """Return the value stored under *weight_key* in *state_dict*, or None.

    An exact key match is tried first. Failing that, any key ending in
    ``"." + weight_key`` is accepted, so parameter names carrying a wrapper
    prefix (e.g. ``module.`` added by ``DataParallel``) are still found.
    """
    if weight_key in state_dict:
        return state_dict[weight_key]
    suffix = '.' + weight_key
    for key, value in state_dict.items():
        if isinstance(key, str) and key.endswith(suffix):
            return value
    return None


def check_num_classes_from_pth(pth_path, weight_key=DEFAULT_WEIGHT_KEY) -> Optional[int]:
    """Print and return the number of output classes stored in a checkpoint.

    Args:
        pth_path: Path to a checkpoint file readable by ``torch.load``.
        weight_key: Name of the classifier weight whose first dimension is
            taken as the class count.

    Returns:
        ``weight.shape[0]`` of the classifier weight. Returns ``None`` when
        the checkpoint has no usable state dict or the weight is absent.
    """
    # NOTE: torch.load unpickles the file; only run this on trusted checkpoints.
    checkpoint = torch.load(pth_path, map_location='cpu')

    if not isinstance(checkpoint, Mapping):
        # e.g. a whole pickled nn.Module: there is no state dict to inspect.
        print("❌ 找不到 state_dict")
        return None

    if 'state_dict' in checkpoint:
        weight = _find_classifier_weight(checkpoint['state_dict'], weight_key)
    else:
        # Fall back to a bare state dict saved via torch.save(model.state_dict()).
        weight = _find_classifier_weight(checkpoint, weight_key)
        if weight is None:
            print("❌ 找不到 state_dict")
            return None

    if weight is None:
        print(f"❌ 找不到分類層: {weight_key}")
        return None

    # The classifier conv has one output channel per predicted class.
    num_classes = int(weight.shape[0])
    print(f"✅ 類別數: {num_classes}")
    if num_classes == 19:
        print("⚠️ 這是 Cityscapes 模型 (19 類)")
    elif num_classes == 4:
        print("✅ 這是 GolfDataset 模型 (4 類)")
    else:
        print("❓ 非常規類別數,請自行確認資料與 config")
    return num_classes


if __name__ == '__main__':
    # Optional CLI override: python <script>.py path/to/checkpoint.pth
    pth_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r'C:\Users\rd_de\kneronstdc\work_dirs\meconfig\latest.pth'
    )
    check_num_classes_from_pth(pth_path)