STDC/data/seg2city7class.py
charlie880624 ac3c30a1d2
Some checks failed
build / build_cpu (3.7, 1.5.1, torch1.5, 0.6.1) (push) Has been cancelled
build / build_cpu (3.7, 1.6.0, torch1.6, 0.7.0) (push) Has been cancelled
build / build_cpu (3.7, 1.7.0, torch1.7, 0.8.1) (push) Has been cancelled
build / build_cpu (3.7, 1.8.0, torch1.8, 0.9.0) (push) Has been cancelled
build / build_cpu (3.7, 1.9.0, torch1.9, 0.10.0) (push) Has been cancelled
build / build_cuda101 (3.7, 1.5.1+cu101, torch1.5, 0.6.1+cu101) (push) Has been cancelled
build / build_cuda101 (3.7, 1.6.0+cu101, torch1.6, 0.7.0+cu101) (push) Has been cancelled
build / build_cuda101 (3.7, 1.7.0+cu101, torch1.7, 0.8.1+cu101) (push) Has been cancelled
build / build_cuda101 (3.7, 1.8.0+cu101, torch1.8, 0.9.0+cu101) (push) Has been cancelled
build / build_cuda102 (3.6, 1.9.0+cu102, torch1.9, 0.10.0+cu102) (push) Has been cancelled
build / build_cuda102 (3.7, 1.9.0+cu102, torch1.9, 0.10.0+cu102) (push) Has been cancelled
build / build_cuda102 (3.8, 1.9.0+cu102, torch1.9, 0.10.0+cu102) (push) Has been cancelled
build / build_cuda102 (3.9, 1.9.0+cu102, torch1.9, 0.10.0+cu102) (push) Has been cancelled
build / test_windows (windows-2022, cpu, 3.8) (push) Has been cancelled
build / test_windows (windows-2022, cu111, 3.8) (push) Has been cancelled
deploy / build-n-publish (push) Has been cancelled
lint / lint (push) Has been cancelled
add data folder and update gitignore
2026-03-18 17:57:51 +08:00

100 lines
3.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
# Source and output folders for each split.
# DATA_ROOT can be overridden with the STDC_DATA_ROOT environment variable.
# The default is the original hard-coded Windows path, so on Windows the
# resulting directory strings are identical to the previous literals.
DATA_ROOT = os.environ.get("STDC_DATA_ROOT", r"C:\Users\rd_de\kneronstdc\data")
SOURCE_ROOT = os.path.join(DATA_ROOT, "06-03_danger_object_segmentation")
CITYSCAPES_ROOT = os.path.join(DATA_ROOT, "cityscapes")

# (output split name, source sub-folder): the source export stores the
# validation split in "valid", while the Cityscapes-style output uses "val".
_SPLITS = (("train", "train"), ("val", "valid"), ("test", "test"))

# One dict per split, with the keys consumed by the conversion loop:
#   name            -- split name used in logs and output folder names
#   input_dir       -- folder holding *_mask.png, images and _classes.csv
#   output_img_dir  -- destination for <stem>_leftImg8bit.png
#   output_mask_dir -- destination for <stem>_gtFine_labelIds.png
datasets = [
    {
        "name": split,
        "input_dir": os.path.join(SOURCE_ROOT, source_subdir),
        "output_img_dir": os.path.join(CITYSCAPES_ROOT, "leftImg8bit", split),
        "output_mask_dir": os.path.join(CITYSCAPES_ROOT, "gtFine", split),
    }
    for split, source_subdir in _SPLITS
]
# Class names in training-label order: position i holds the class whose
# training label ID is i (and whose raw mask value is i + 1).
_CLASS_NAMES = ("bunker", "car", "grass", "greenery", "person", "road", "tree")

# Raw mask pixel value -> training label ID (raw value k becomes k - 1).
# Raw 0 is intentionally not a key: the conversion step fills every pixel
# whose value is not listed here with 255 (ignore).
label_mapping = {index + 1: index for index in range(len(_CLASS_NAMES))}
MASK_SUFFIX = "_mask.png"
IGNORE_INDEX = 255


def mask_stem(filename, suffix=MASK_SUFFIX):
    """Return *filename* without its trailing mask suffix, or None if absent.

    Only the trailing suffix is stripped. A base name that itself contains
    "_mask" is kept intact, so the source image, the output image and the
    output mask all share one stem.
    """
    if not suffix or not filename.endswith(suffix):
        return None
    return filename[: len(filename) - len(suffix)]


def remap_mask(mask, mapping, ignore_index=IGNORE_INDEX):
    """Return a uint8 copy of *mask* with raw values translated via *mapping*.

    Pixels whose value is not a key of *mapping* (e.g. background 0) become
    *ignore_index*. Assumes *mask* is an integer array with values in
    0..255, as produced by the grayscale read in convert_split; mapping keys
    and values must also fit in 0..255.
    """
    # A 256-entry lookup table remaps the whole mask in one indexing pass.
    lut = np.full(256, ignore_index, dtype=np.uint8)
    for old_id, new_id in mapping.items():
        lut[old_id] = new_id
    return lut[mask]


def convert_split(dataset, mapping, image_ext=".jpg"):
    """Convert one split into Cityscapes-style image / label-ID files.

    Args:
        dataset: dict with keys "name", "input_dir", "output_img_dir" and
            "output_mask_dir".
        mapping: raw mask pixel value -> training label ID.
        image_ext: extension of the source image paired with each mask.

    Returns:
        Number of image/mask pairs written. Returns 0 when the split is
        skipped because its _classes.csv is missing.
    """
    name = dataset["name"]
    input_dir = dataset["input_dir"]
    output_img_dir = dataset["output_img_dir"]
    output_mask_dir = dataset["output_mask_dir"]
    os.makedirs(output_img_dir, exist_ok=True)
    os.makedirs(output_mask_dir, exist_ok=True)

    # The whole split is skipped when _classes.csv is absent.
    csv_path = os.path.join(input_dir, "_classes.csv")
    if not os.path.exists(csv_path):
        print(f"❌ 缺少 _classes.csv: {csv_path}")
        return 0
    # NOTE: the first CSV column is only reported for inspection; mask
    # validation below is driven by *mapping*, not by this set.
    df = pd.read_csv(csv_path)
    valid_labels = set(df.iloc[:, 0].values)
    print(f"\n📂 資料集: {name}")
    print(f"✅ 合法類別 ID: {valid_labels}")

    # Raw values accepted in a mask: every mapped ID plus 0 (background).
    valid_ids = set(mapping) | {0}
    written = 0
    # Sorted so the processing order (and log order) is deterministic.
    for file in tqdm(sorted(os.listdir(input_dir)), desc=f"轉換 {name}"):
        stem = mask_stem(file)
        if stem is None:
            continue
        mask_path = os.path.join(input_dir, file)
        image_name = stem + image_ext
        image_path = os.path.join(input_dir, image_name)
        if not os.path.exists(image_path):
            print(f"⚠️ 找不到對應圖片: {image_name}")
            continue

        img = cv2.imread(image_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if img is None or mask is None:
            print(f"❌ 無法讀取圖像或 mask: {file}")
            continue

        labels = set(np.unique(mask))
        if not labels.issubset(valid_ids):
            print(f"❌ 非法類別出現在: {file}, labels={labels}")
            continue

        # Raw 0 -> IGNORE_INDEX (255); other raw values follow *mapping*.
        remapped_mask = remap_mask(mask, mapping)

        # Cityscapes naming: <stem>_leftImg8bit.png / <stem>_gtFine_labelIds.png
        out_img_path = os.path.join(output_img_dir, f"{stem}_leftImg8bit.png")
        out_mask_path = os.path.join(output_mask_dir, f"{stem}_gtFine_labelIds.png")
        # cv2.imwrite returns a success flag; surface failures instead of
        # silently continuing.
        if cv2.imwrite(out_img_path, img) and cv2.imwrite(out_mask_path, remapped_mask):
            written += 1
        else:
            print(f"❌ 寫入失敗: {file}")
    return written


def main():
    """Convert every split in ``datasets`` using ``label_mapping``."""
    for dataset in datasets:
        convert_split(dataset, label_mapping)
    print("\n🎉 全部資料集轉換完成Cityscapes 格式 OK")


if __name__ == "__main__":
    main()