#!/usr/bin/env python3
"""
Improved YOLO postprocessing with better error handling and filtering.
"""

import traceback
from collections import Counter, defaultdict
from typing import List, Tuple

import numpy as np

# These classes are assumed to be defined in the original module.
from core.functions.Multidongle import BoundingBox, ObjectDetectionResult


class _RejectCounter:
    """Count rejected detections and print only the first few reasons.

    The counter is shared across all output nodes of one
    ``process_yolo_output`` call so that the debug log is not flooded when a
    node yields thousands of invalid rows.
    """

    def __init__(self, max_reports: int = 5):
        self.count = 0
        self.max_reports = max_reports

    def reject(self, message: str = None) -> None:
        """Record one rejection; print *message* if still under the report cap."""
        self.count += 1
        if message is not None and self.count <= self.max_reports:
            print(message)


class ImprovedYOLOPostProcessor:
    """YOLO post-processor with anomaly detection and result filtering.

    ``options`` must expose ``class_names`` (a sequence of names, or a falsy
    value when unknown), ``threshold`` and ``nms_threshold``.
    """

    # Upper bound on the number of rows decoded from one output node.
    # NOTE(review): rows beyond this cap are dropped in storage order, not by
    # score - confirm this is acceptable for models that emit more rows.
    _MAX_ROWS_PER_NODE = 1000
    # Values per box in the flat fallback layout: x1, y1, x2, y2, conf, cls.
    _FLAT_STEP = 6

    def __init__(self, options):
        self.options = options
        self.max_detections_total = 500  # cap on the total number of detections
        # NOTE(review): this fixed cap is used even if ``options`` carries its
        # own ``max_detections_per_class`` - confirm whether it should win.
        self.max_detections_per_class = 50  # cap on detections per class
        self.min_box_area = 4  # minimum bounding-box area
        self.max_score = 10.0  # scores above this are treated as anomalous

    def _is_valid_box(self, x1, y1, x2, y2, score, class_id) -> Tuple[bool, str]:
        """Check one candidate box.

        Returns:
            ``(True, "Valid")`` or ``(False, reason)``.
        """
        # Basic coordinate check: non-negative origin, positive width/height.
        if x1 < 0 or y1 < 0 or x1 >= x2 or y1 >= y2:
            return False, "Invalid coordinates"

        area = (x2 - x1) * (y2 - y1)
        if area < self.min_box_area:
            return False, f"Box too small (area={area})"

        if score <= 0 or score > self.max_score:
            return False, f"Invalid score ({score})"

        # The upper bound on class_id is only enforced when names are known.
        class_names = self.options.class_names
        if class_id < 0 or (class_names and class_id >= len(class_names)):
            return False, f"Invalid class_id ({class_id})"

        return True, "Valid"

    def _class_name_for(self, class_id: int) -> str:
        """Return the configured name for *class_id*, or ``Class_<id>``."""
        class_names = self.options.class_names
        if class_names and class_id < len(class_names):
            return class_names[class_id]
        return f"Class_{class_id}"

    def _make_box(self, x1, y1, x2, y2, score, class_id) -> BoundingBox:
        """Build a ``BoundingBox`` with the top-left corner clamped to >= 0."""
        return BoundingBox(
            x1=max(0, x1),
            y1=max(0, y1),
            x2=x2,
            y2=y2,
            score=score,
            class_num=class_id,
            class_name=self._class_name_for(class_id),
        )

    @staticmethod
    def _center_to_corners(x_center, y_center, width, height):
        """Convert centre/size floats to integer ``(x1, y1, x2, y2)``."""
        return (
            int(x_center - width / 2),
            int(y_center - height / 2),
            int(x_center + width / 2),
            int(y_center + height / 2),
        )

    def _filter_excessive_detections(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """Keep at most ``max_detections_total`` boxes, highest scores first."""
        if len(boxes) <= self.max_detections_total:
            return boxes

        print(f"WARNING: Too many detections ({len(boxes)}), filtering to {self.max_detections_total}")

        # sorted() rather than list.sort(): do not reorder the caller's list.
        ranked = sorted(boxes, key=lambda b: b.score, reverse=True)
        return ranked[:self.max_detections_total]

    def _filter_by_class_count(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """Keep at most ``max_detections_per_class`` boxes for each class."""
        by_class = defaultdict(list)
        for box in boxes:
            by_class[box.class_num].append(box)

        limit = self.max_detections_per_class
        filtered_boxes = []
        for class_boxes in by_class.values():
            # Highest score first so the slice below keeps the best ones.
            class_boxes.sort(key=lambda b: b.score, reverse=True)
            if len(class_boxes) > limit:
                class_name = class_boxes[0].class_name
                print(f"WARNING: Too many {class_name} detections ({len(class_boxes)}), keeping top {limit}")
            filtered_boxes.extend(class_boxes[:limit])

        return filtered_boxes

    def _detect_anomalous_pattern(self, boxes: List[BoundingBox]) -> bool:
        """Return True if the detections look like decoder garbage.

        Two heuristics are applied: more than 10 boxes sharing identical
        coordinates, or a mean score above 2.0.
        """
        if not boxes:
            return False

        coord_counts = Counter((b.x1, b.y1, b.x2, b.y2) for b in boxes)
        max_coord_count = max(coord_counts.values())
        if max_coord_count > 10:
            print(f"WARNING: Anomalous pattern detected - {max_coord_count} boxes with same coordinates")
            return True

        avg_score = np.mean([b.score for b in boxes])
        if avg_score > 2.0:  # scores this high may indicate log-space output
            print(f"WARNING: Unusually high average score: {avg_score:.3f}")
            return True

        return False

    def _decode_3d(self, arr, boxes: List[BoundingBox], rejects: _RejectCounter) -> None:
        """Decode a ``[batch, num_detections, features]`` node into *boxes*.

        Each row is read as ``x_center, y_center, width, height, obj_conf``
        followed by optional per-class values. Only the first batch entry is
        processed.
        """
        batch_size, num_detections, num_features = arr.shape
        print(f"DEBUG: YOLOv5 format: {batch_size}x{num_detections}x{num_features}")

        if num_detections > 10000:
            print(f"WARNING: Extremely high detection count: {num_detections}, limiting to 1000")

        threshold = self.options.threshold
        detections = arr[0]  # first batch entry only

        for det_idx in range(min(num_detections, self._MAX_ROWS_PER_NODE)):
            detection = detections[det_idx]
            try:
                x_center = float(detection[0])
                y_center = float(detection[1])
                width = float(detection[2])
                height = float(detection[3])
                obj_conf = float(detection[4])

                if not all(np.isfinite([x_center, y_center, width, height, obj_conf])):
                    rejects.reject()
                    continue

                # Skip low-confidence rows early.
                if obj_conf < threshold:
                    continue

                if num_features > 5:
                    class_scores = detection[5:] * obj_conf
                    best_class = int(np.argmax(class_scores))
                    best_score = float(class_scores[best_class])
                    # np.argmax picks a NaN if one is present, and NaN passes
                    # every ordered comparison below, so reject it explicitly.
                    if not np.isfinite(best_score):
                        rejects.reject(f"DEBUG: Invalid box rejected: Invalid score ({best_score})")
                        continue
                    if best_score < threshold:
                        continue
                else:
                    best_class = 0
                    best_score = obj_conf

                x1, y1, x2, y2 = self._center_to_corners(x_center, y_center, width, height)

                is_valid, reason = self._is_valid_box(x1, y1, x2, y2, best_score, best_class)
                if not is_valid:
                    rejects.reject(f"DEBUG: Invalid box rejected: {reason}")
                    continue

                boxes.append(self._make_box(x1, y1, x2, y2, best_score, best_class))

            except Exception as e:
                # Best effort: one malformed row must not abort the node.
                rejects.reject(f"DEBUG: Error processing detection {det_idx}: {e}")

    def _decode_2d(self, arr, boxes: List[BoundingBox], rejects: _RejectCounter) -> None:
        """Decode a ``[num_detections, features]`` node into *boxes*.

        Each row is read as ``x_center, y_center, width, height, confidence,
        class_id``; rows need at least 6 values.
        """
        print(f"DEBUG: 2D YOLO output: {arr.shape}")
        num_detections, num_features = arr.shape

        if num_detections > 1000:
            print(f"WARNING: Too many 2D detections: {num_detections}, limiting to 1000")

        if num_features < 6:
            # No row can be decoded; say so once instead of skipping silently.
            print(f"WARNING: 2D output has {num_features} features per row, expected at least 6; skipping")
            return

        threshold = self.options.threshold

        for det_idx in range(min(num_detections, self._MAX_ROWS_PER_NODE)):
            detection = arr[det_idx]
            try:
                x_center = float(detection[0])
                y_center = float(detection[1])
                width = float(detection[2])
                height = float(detection[3])
                confidence = float(detection[4])
                class_id = int(detection[5])  # raises on NaN/inf -> counted below

                if not all(np.isfinite([x_center, y_center, width, height, confidence])):
                    rejects.reject()
                    continue

                if confidence <= threshold:
                    continue

                x1, y1, x2, y2 = self._center_to_corners(x_center, y_center, width, height)

                is_valid, _reason = self._is_valid_box(x1, y1, x2, y2, confidence, class_id)
                if not is_valid:
                    rejects.reject()
                    continue

                boxes.append(self._make_box(x1, y1, x2, y2, confidence, class_id))

            except Exception:
                # Best effort: one malformed row must not abort the node.
                rejects.reject()

    def _decode_flat(self, arr, boxes: List[BoundingBox], rejects: _RejectCounter) -> None:
        """Fallback decoder: flatten *arr* and read groups of six values.

        Each group is read as ``x1, y1, x2, y2, conf, cls`` (corner
        coordinates, not centre/size).
        """
        flat = arr.flatten()
        print(f"DEBUG: Fallback processing for flat array size: {len(flat)}")

        step = self._FLAT_STEP
        max_values = self._MAX_ROWS_PER_NODE * step  # 1000 boxes * 6 values
        if len(flat) > max_values:
            print(f"WARNING: Large flat array ({len(flat)}), limiting processing")
            flat = flat[:max_values]

        threshold = self.options.threshold

        for j in range(0, len(flat) - step + 1, step):
            try:
                x1, y1, x2, y2, conf, cls = flat[j:j + step]

                if not all(np.isfinite([x1, y1, x2, y2, conf])):
                    rejects.reject()
                    continue

                if conf <= threshold:
                    continue

                class_id = int(cls)
                # Validate the integer coordinates that are actually stored,
                # so truncation cannot produce a zero-width/height box.
                ix1, iy1, ix2, iy2 = int(x1), int(y1), int(x2), int(y2)
                score = float(conf)

                is_valid, _reason = self._is_valid_box(ix1, iy1, ix2, iy2, score, class_id)
                if not is_valid:
                    rejects.reject()
                    continue

                boxes.append(self._make_box(ix1, iy1, ix2, iy2, score, class_id))

            except Exception:
                # Best effort: one malformed group must not abort the node.
                rejects.reject()

    def process_yolo_output(self, inference_output_list: List, hardware_preproc_info=None, version="v3") -> ObjectDetectionResult:
        """Decode raw YOLO output nodes into a filtered detection result.

        Args:
            inference_output_list: Output nodes. Each node is either an object
                with an ``ndarray`` attribute or an array-like with
                ``flatten``/``shape``. 3-D nodes, 2-D nodes and anything else
                are decoded by ``_decode_3d``, ``_decode_2d`` and
                ``_decode_flat`` respectively.
            hardware_preproc_info: Accepted for interface compatibility; unused.
            version: Accepted for interface compatibility; unused.

        Returns:
            An ``ObjectDetectionResult``. Any unexpected error is logged and
            yields a result with an empty box list rather than raising.
        """
        class_names = self.options.class_names
        boxes = []
        rejects = _RejectCounter()

        try:
            # ``not ndarray`` is ambiguous for multi-element arrays, so test
            # None and length explicitly.
            if inference_output_list is None or len(inference_output_list) == 0:
                # NOTE(review): class_count falls back to 0 here but to 1 in
                # the final return below - confirm which one callers expect.
                return ObjectDetectionResult(
                    class_count=len(class_names) if class_names else 0,
                    box_count=0,
                    box_list=[],
                )

            print(f"DEBUG: Processing {len(inference_output_list)} YOLO output nodes")

            for i, output in enumerate(inference_output_list):
                try:
                    # Extract the array payload.
                    if hasattr(output, 'ndarray'):
                        arr = output.ndarray
                    elif isinstance(output, np.ndarray) or hasattr(output, 'flatten'):
                        arr = output
                    else:
                        print(f"WARNING: Unknown output type for node {i}: {type(output)}")
                        continue

                    if not hasattr(arr, 'shape'):
                        print(f"WARNING: Output node {i} has no shape attribute")
                        continue

                    print(f"DEBUG: Output node {i} shape: {arr.shape}")

                    rank = len(arr.shape)
                    if rank == 3:
                        self._decode_3d(arr, boxes, rejects)
                    elif rank == 2:
                        self._decode_2d(arr, boxes, rejects)
                    else:
                        self._decode_flat(arr, boxes, rejects)

                except Exception as e:
                    # A broken node is skipped; the remaining nodes still run.
                    print(f"ERROR: Error processing output node {i}: {e}")
                    continue

            # Report statistics.
            if rejects.count > 0:
                print(f"INFO: Rejected {rejects.count} invalid detections")

            print(f"DEBUG: Raw detection count: {len(boxes)}")

            # Stricter filtering when the raw detections look anomalous.
            if self._detect_anomalous_pattern(boxes):
                print("WARNING: Anomalous detection pattern detected, applying aggressive filtering")
                boxes = [
                    box for box in boxes
                    if box.score < 2.0 and box.x1 != box.x2 and box.y1 != box.y2
                ]

            boxes = self._filter_excessive_detections(boxes)
            boxes = self._filter_by_class_count(boxes)

            if len(boxes) > 1:
                boxes = self._apply_nms(boxes)

            print(f"INFO: Final detection count: {len(boxes)}")

            # Per-class summary.
            if boxes:
                class_stats = Counter(box.class_name for box in boxes)
                print("Detection summary:")
                for class_name, count in sorted(class_stats.items()):
                    print(f" {class_name}: {count}")

        except Exception as e:
            # Top-level boundary: never propagate decoder failures to callers.
            print(f"ERROR: Critical error in YOLO postprocessing: {e}")
            traceback.print_exc()
            boxes = []

        return ObjectDetectionResult(
            class_count=len(class_names) if class_names else 1,
            box_count=len(boxes),
            box_list=boxes,
        )

    def _apply_nms(self, boxes: List[BoundingBox]) -> List[BoundingBox]:
        """Greedy per-class non-maximum suppression.

        Within each class the highest-scoring box is kept and every remaining
        box whose IoU with it exceeds ``options.nms_threshold`` is discarded;
        this repeats until no candidates remain or
        ``max_detections_per_class`` boxes have been kept. On any error the
        input is returned truncated to ``max_detections_total``.
        """
        if not boxes or len(boxes) <= 1:
            return boxes

        try:
            by_class = defaultdict(list)
            for box in boxes:
                by_class[box.class_num].append(box)

            nms_threshold = self.options.nms_threshold
            final_boxes = []

            for candidates in by_class.values():
                if len(candidates) <= 1:
                    final_boxes.extend(candidates)
                    continue

                candidates = sorted(candidates, key=lambda b: b.score, reverse=True)

                keep = []
                while candidates and len(keep) < self.max_detections_per_class:
                    current_box = candidates[0]
                    keep.append(current_box)
                    # Drop everything that overlaps the kept box too much.
                    candidates = [
                        box for box in candidates[1:]
                        if self._calculate_iou(current_box, box) <= nms_threshold
                    ]

                final_boxes.extend(keep)

            print(f"DEBUG: NMS reduced {len(boxes)} to {len(final_boxes)} boxes")
            return final_boxes

        except Exception as e:
            print(f"ERROR: NMS failed: {e}")
            return boxes[:self.max_detections_total]  # fall back to a simple cap

    def _calculate_iou(self, box1: BoundingBox, box2: BoundingBox) -> float:
        """Return the intersection-over-union of two boxes.

        Returns 0.0 when the boxes do not overlap, when the union is not
        positive, or when a box is missing coordinates / has non-numeric ones.
        """
        try:
            # Intersection rectangle.
            x1 = max(box1.x1, box2.x1)
            y1 = max(box1.y1, box2.y1)
            x2 = min(box1.x2, box2.x2)
            y2 = min(box1.y2, box2.y2)

            if x2 <= x1 or y2 <= y1:
                return 0.0

            intersection = (x2 - x1) * (y2 - y1)

            # Union area.
            area1 = (box1.x2 - box1.x1) * (box1.y2 - box1.y1)
            area2 = (box2.x2 - box2.x1) * (box2.y2 - box2.y1)
            union = area1 + area2 - intersection

            if union <= 0:
                return 0.0

            return intersection / union

        except (AttributeError, TypeError, ValueError):
            # Malformed box: treat as non-overlapping so it is not suppressed.
            return 0.0


# Smoke test
if __name__ == "__main__":
    from core.functions.Multidongle import PostProcessorOptions, PostProcessType

    # Build test options.
    options = PostProcessorOptions(
        postprocess_type=PostProcessType.YOLO_V5,
        threshold=0.3,
        class_names=["person", "bicycle", "car", "motorbike", "aeroplane"],
        nms_threshold=0.45,
        max_detections_per_class=20
    )

    processor = ImprovedYOLOPostProcessor(options)
    print("ImprovedYOLOPostProcessor initialized successfully!")