# KNEO-Academy/src/controllers/inference_controller.py
#
# 260 lines
# 12 KiB
# Python

# src/controllers/inference_controller.py
import os, queue, cv2, json
from PyQt5.QtWidgets import QMessageBox, QApplication
from PyQt5.QtCore import QTimer, Qt
from src.models.inference_worker import InferenceWorkerThread
from src.config import UTILS_DIR, FW_DIR, DongleModelMap
class InferenceController:
    """Coordinate tool selection, the inference worker and its input queue.

    The controller owns a bounded ``queue.Queue`` that is handed to an
    ``InferenceWorkerThread``. Frames and uploaded images are pushed onto
    that queue; results come back through the worker's
    ``inference_result_signal``, which is connected to
    ``main_window.handle_inference_result``.
    """

    def __init__(self, main_window, device_controller):
        """Store collaborators and create an empty, bounded inference queue.

        Args:
            main_window: Object providing ``media_controller``,
                ``handle_inference_result`` and (optionally) ``destination``.
            device_controller: Object providing ``get_selected_device()`` and
                ``connected_devices``.
        """
        self.main_window = main_window
        self.device_controller = device_controller
        self.inference_worker = None
        self.inference_queue = queue.Queue(maxsize=10)
        self.current_tool_config = None
        self.previous_tool_config = None
        self._camera_was_active = False
        # Size of the most recent frame passed to add_frame_to_queue; the
        # original comment says this is used for bounding-box scaling.
        self.original_frame_width = 640   # default
        self.original_frame_height = 480  # default

    def select_tool(self, tool_config):
        """Select an AI tool and (re)start the inference worker for it.

        Args:
            tool_config: Mapping describing the tool. Keys read here are
                ``display_name``, ``mode``, ``model_name``, ``input_info``,
                ``input_parameters``, ``compatible_devices`` and
                ``model_file``. Values from the model's ``config.json``
                (if present) override the ones passed in.

        Returns:
            bool: ``True`` when the worker was started, ``False`` when the
            selected device is incompatible or any error occurred.
        """
        try:
            print("選擇工具:", tool_config.get("display_name"))
            self.current_tool_config = tool_config

            mode = tool_config.get("mode", "")
            model_name = tool_config.get("model_name", "")
            model_path = os.path.join(UTILS_DIR, mode, model_name)
            tool_config = self._merge_model_config(tool_config, model_path)

            tool_type = tool_config.get("input_info", {}).get("type", "video")

            # Drop anything queued for the previous tool so the new worker
            # never sees stale data.
            self._clear_inference_queue()
            self.previous_tool_config = tool_config

            input_params = tool_config.get("input_parameters", {}).copy()
            if not self._configure_device(tool_config, input_params):
                return False

            pending_image = self._prepare_file_input(tool_type, input_params)

            if "model_file" in tool_config:
                input_params["model"] = os.path.join(
                    model_path, tool_config["model_file"]
                )
            print("Input parameters:", input_params)

            # Stop the previous worker *before* queueing the uploaded image,
            # otherwise the old worker could consume it from the shared queue.
            # NOTE(review): assumes stop() halts consumption - confirm.
            self.stop_inference()

            if pending_image is not None:
                if self._enqueue(pending_image):
                    print("Uploaded image added to inference queue")
                else:
                    print("Warning: inference queue is full")

            self.inference_worker = InferenceWorkerThread(
                self.inference_queue,
                mode,
                model_name,
                min_interval=0.5,
                mse_threshold=500,
                once_mode=(tool_type == "image"),
            )
            self.inference_worker.input_params = input_params
            self.inference_worker.inference_result_signal.connect(
                self.main_window.handle_inference_result
            )
            self.inference_worker.start()
            print(f"Inference worker started for module: {mode}/{model_name}")

            self._sync_camera(tool_type)
            return True
        except Exception as e:
            # UI boundary: report the failure instead of propagating it.
            print(f"選擇工具時發生錯誤: {e}")
            import traceback
            print(traceback.format_exc())
            return False

    @staticmethod
    def _merge_model_config(tool_config, model_path):
        """Return ``tool_config`` overlaid with ``<model_path>/config.json``.

        A missing, unreadable or malformed file leaves ``tool_config`` as is.
        """
        model_config_path = os.path.join(model_path, "config.json")
        if not os.path.exists(model_config_path):
            return tool_config
        try:
            with open(model_config_path, "r", encoding="utf-8") as f:
                detailed_config = json.load(f)
            return {**tool_config, **detailed_config}
        except (OSError, ValueError, TypeError) as e:
            # ValueError covers JSON/decoding errors; TypeError covers a
            # top-level JSON value that is not an object.
            print(f"讀取模型配置時發生錯誤: {e}")
            return tool_config

    @staticmethod
    def _device_attr(device, name, default):
        """Read ``name`` from a device given as either a dict or an object."""
        if isinstance(device, dict):
            return device.get(name, default)
        return getattr(device, name, default)

    @staticmethod
    def _normalize_product_id(product_id):
        """Convert a product id into the string form used for map lookup.

        Integers become ``hex()`` strings. Strings that do not already start
        with ``'0x'`` are parsed with ``int(value, 0)`` and re-formatted; if
        parsing fails, or for any other input, the value is returned as is.
        """
        if isinstance(product_id, int):
            return hex(product_id).lower()
        if isinstance(product_id, str) and not product_id.startswith('0x'):
            try:
                return hex(int(product_id, 0)).lower()
            except ValueError:
                return product_id
        return product_id

    def _configure_device(self, tool_config, input_params):
        """Fill device-related entries of ``input_params`` in place.

        Returns:
            bool: ``False`` if a device is selected but is not listed in the
            tool's ``compatible_devices`` (a warning dialog is shown first);
            ``True`` otherwise.
        """
        selected_device = self.device_controller.get_selected_device()
        if not selected_device:
            devices = self.device_controller.connected_devices
            if devices:
                input_params["usb_port_id"] = self._device_attr(
                    devices[0], "usb_port_id", 0
                )
                print("Warning: No device specifically selected, using first available device")
            else:
                input_params["usb_port_id"] = 0
                print("Warning: No connected devices, using default usb_port_id 0")
            return True

        input_params["usb_port_id"] = self._device_attr(
            selected_device, "usb_port_id", 0
        )
        product_id = self._normalize_product_id(
            self._device_attr(selected_device, "product_id", "unknown")
        )

        # Map product_id to dongle type/series.
        dongle = DongleModelMap.get(product_id, "unknown")
        print(f"Selected device: product_id={product_id}, mapped to={dongle}")

        compatible_devices = tool_config.get("compatible_devices", [])
        if compatible_devices and dongle not in compatible_devices:
            msgBox = QMessageBox(self.main_window)
            msgBox.setIcon(QMessageBox.Warning)
            msgBox.setWindowTitle("Device Incompatible")
            msgBox.setText(f"The selected model does not support {dongle} device.\nSupported devices: {', '.join(compatible_devices)}")
            msgBox.setStyleSheet("QLabel { color: white; } QMessageBox { background-color: #2b2b2b; }")
            msgBox.exec_()
            return False

        input_params["scpu_path"] = os.path.join(FW_DIR, dongle, "fw_scpu.bin")
        input_params["ncpu_path"] = os.path.join(FW_DIR, dongle, "fw_ncpu.bin")
        return True

    def _prepare_file_input(self, tool_type, input_params):
        """Set ``input_params['file_path']`` for image/voice tools.

        Returns:
            The decoded image (result of ``cv2.imread``) that should be
            queued for an image tool, or ``None`` when there is nothing to
            queue.
        """
        if tool_type not in ("image", "voice"):
            return None

        destination = getattr(self.main_window, "destination", None)
        if not destination:
            input_params["file_path"] = ""
            print(f"Warning: {tool_type} mode requires a file input, but no file has been uploaded.")
            return None

        input_params["file_path"] = destination
        if tool_type != "image":
            return None

        uploaded_img = cv2.imread(destination)
        if uploaded_img is None:
            print("Warning: Unable to read uploaded image")
        return uploaded_img

    def _sync_camera(self, tool_type):
        """Start, reconnect or pause the camera feed to match ``tool_type``."""
        media = self.main_window.media_controller

        if tool_type == "video":
            if self._camera_was_active and media.video_thread is not None:
                # The feed was only disconnected for an image tool: hook the
                # display slot back up instead of restarting the camera.
                media.video_thread.change_pixmap_signal.connect(media.update_image)
                print("Camera reconnected for video processing")
            else:
                media.start_camera()
            # Clear the flag so a later video selection does not connect the
            # slot a second time.
            self._camera_was_active = False
            return

        # Non-video tools: keep the camera thread alive but stop delivering
        # frames, so switching back to a video tool needs no restart.
        if media.video_thread is None:
            self._camera_was_active = False
            return

        self._camera_was_active = True
        try:
            media.video_thread.change_pixmap_signal.disconnect()
        except TypeError:
            # PyQt raises TypeError when the signal has no connections,
            # e.g. when two image tools are selected back to back.
            pass
        print("Camera paused for image processing")

    def _enqueue(self, item):
        """Put ``item`` on the inference queue without blocking.

        Returns:
            bool: ``True`` if queued, ``False`` if the queue was full.
        """
        try:
            self.inference_queue.put_nowait(item)
        except queue.Full:
            return False
        return True

    def _clear_inference_queue(self):
        """Discard every item currently waiting in the inference queue."""
        while True:
            try:
                self.inference_queue.get_nowait()
            except queue.Empty:
                break
        print("推論佇列已清空")

    def add_frame_to_queue(self, frame):
        """Queue a frame for inference and remember its dimensions.

        ``None`` frames are ignored. When the queue is full the frame is
        dropped silently.
        """
        try:
            if frame is None:
                return
            if hasattr(frame, 'shape'):
                height, width = frame.shape[:2]
                self.original_frame_width = width
                self.original_frame_height = height
            self._enqueue(frame)
        except Exception as e:
            print(f"添加影格到佇列時發生錯誤: {e}")
            import traceback
            print(traceback.format_exc())

    def stop_inference(self):
        """Stop the inference worker, if any, and forget it."""
        if self.inference_worker:
            self.inference_worker.stop()
            self.inference_worker = None

    def process_uploaded_image(self, file_path):
        """Load an uploaded image and queue it for the current worker.

        The queue is cleared first so that only the newest image is
        processed. Nothing is queued if the file is missing, cannot be
        decoded, or no worker has been created yet.
        """
        try:
            if not os.path.exists(file_path):
                print(f"錯誤: 檔案不存在 {file_path}")
                return

            self._clear_inference_queue()

            img = cv2.imread(file_path)
            if img is None:
                print(f"錯誤: 無法讀取圖片 {file_path}")
                return

            if not self.inference_worker:
                print("錯誤: 推論工作器未初始化")
                return

            self.inference_worker.input_params["file_path"] = file_path
            if self._enqueue(img):
                print(f"已將圖片 {file_path} 添加到推論佇列")
            else:
                print("警告: 推論佇列已滿")
        except Exception as e:
            print(f"處理上傳圖片時發生錯誤: {e}")
            import traceback
            print(traceback.format_exc())