# src/controllers/inference_controller.py
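"""Inference controller.

Bridges the main window, the device controller, and the InferenceWorkerThread:
it selects AI tools, checks device compatibility, feeds camera frames and
uploaded images into the inference queue, and starts/stops the worker.
"""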
import os, queue, cv2, json
from PyQt5.QtWidgets import QMessageBox, QApplication
from PyQt5.QtCore import QTimer, Qt

from src.models.inference_worker import InferenceWorkerThread
from src.config import UTILS_DIR, FW_DIR, DongleModelMap


class InferenceController:
    def __init__(self, main_window, device_controller):
        self.main_window = main_window
        self.device_controller = device_controller
        self.inference_worker = None
        self.inference_queue = queue.Queue(maxsize=10)
        self.current_tool_config = None
        self.previous_tool_config = None
        self._camera_was_active = False
        # Store the original frame size, used when scaling bounding boxes
        self.original_frame_width = 640    # default
        self.original_frame_height = 480   # default

    def select_tool(self, tool_config):
        """Select an AI tool and configure inference for it."""
        try:
            print("Selected tool:", tool_config.get("display_name"))
            self.current_tool_config = tool_config

            # Get the mode and model name
            mode = tool_config.get("mode", "")
            model_name = tool_config.get("model_name", "")

            # Load the detailed model configuration
            model_path = os.path.join(UTILS_DIR, mode, model_name)
            model_config_path = os.path.join(model_path, "config.json")

            if os.path.exists(model_config_path):
                try:
                    with open(model_config_path, "r", encoding="utf-8") as f:
                        detailed_config = json.load(f)
                    tool_config = {**tool_config, **detailed_config}
                except Exception as e:
                    print(f"Error while reading the model configuration: {e}")

            # Get the tool input type
            input_info = tool_config.get("input_info", {})
            tool_type = input_info.get("type", "video")
            once_mode = tool_type == "image"

            # Check whether we are switching from video mode to image mode,
            # or from image mode to video mode
            previous_tool_type = "video"
            if hasattr(self, 'previous_tool_config') and self.previous_tool_config:
                previous_input_info = self.previous_tool_config.get("input_info", {})
                previous_tool_type = previous_input_info.get("type", "video")

            # Clear the inference queue so stale data is not used after a mode switch
            self._clear_inference_queue()

            # Save the current tool config for the next comparison
            self.previous_tool_config = tool_config

            # Prepare the input parameters
            input_params = tool_config.get("input_parameters", {}).copy()

            # Configure device-related settings
            selected_device = self.device_controller.get_selected_device()
            if selected_device:
                # Get usb_port_id (check whether it's a dictionary or an object)
                if isinstance(selected_device, dict):
                    input_params["usb_port_id"] = selected_device.get("usb_port_id", 0)
                    product_id = selected_device.get("product_id", "unknown")
                else:
                    input_params["usb_port_id"] = getattr(selected_device, "usb_port_id", 0)
                    product_id = getattr(selected_device, "product_id", "unknown")

                # Ensure product_id is in the right format for lookup:
                # convert to a lowercase hex string if it's a number
                if isinstance(product_id, int):
                    product_id = hex(product_id).lower()
                # If it's a string that doesn't start with '0x', normalize it
                elif isinstance(product_id, str) and not product_id.startswith('0x'):
                    try:
                        # Try to convert to int first, then to hex format
                        product_id = hex(int(product_id, 0)).lower()
                    except ValueError:
                        # If conversion fails, keep it as is
                        pass

                # Map product_id to a dongle type/series
                dongle = DongleModelMap.get(product_id, "unknown")
                print(f"Selected device: product_id={product_id}, mapped to={dongle}")

                # Verify device compatibility
                compatible_devices = tool_config.get("compatible_devices", [])
                if compatible_devices and dongle not in compatible_devices:
                    msgBox = QMessageBox(self.main_window)
                    msgBox.setIcon(QMessageBox.Warning)
                    msgBox.setWindowTitle("Device Incompatible")
                    msgBox.setText(
                        f"The selected model does not support the {dongle} device.\n"
                        f"Supported devices: {', '.join(compatible_devices)}"
                    )
                    msgBox.setStyleSheet("QLabel { color: white; } QMessageBox { background-color: #2b2b2b; }")
                    msgBox.exec_()
                    return False

                # Configure firmware paths
                scpu_path = os.path.join(FW_DIR, dongle, "fw_scpu.bin")
                ncpu_path = os.path.join(FW_DIR, dongle, "fw_ncpu.bin")
                input_params["scpu_path"] = scpu_path
                input_params["ncpu_path"] = ncpu_path
            else:
                # Default device handling
                devices = self.device_controller.connected_devices
                if devices:
                    input_params["usb_port_id"] = devices[0].get("usb_port_id", 0)
                    print("Warning: no device explicitly selected, using the first available device")
                else:
                    input_params["usb_port_id"] = 0
                    print("Warning: no connected devices, using default usb_port_id 0")

            # Handle file inputs for image/voice modes
            if tool_type in ["image", "voice"]:
                if hasattr(self.main_window, "destination") and self.main_window.destination:
                    input_params["file_path"] = self.main_window.destination
                    if tool_type == "image":
                        uploaded_img = cv2.imread(self.main_window.destination)
                        if uploaded_img is not None:
                            if not self.inference_queue.full():
                                self.inference_queue.put(uploaded_img)
                                print("Uploaded image added to inference queue")
                            else:
                                print("Warning: inference queue is full")
                        else:
                            print("Warning: Unable to read uploaded image")
                else:
                    input_params["file_path"] = ""
                    print(f"Warning: {tool_type} mode requires a file input, but no file has been uploaded.")

            # Add model file path
            if "model_file" in tool_config:
                model_file = tool_config["model_file"]
                model_file_path = os.path.join(model_path, model_file)
                input_params["model"] = model_file_path

            print("Input parameters:", input_params)

            # Stop existing inference worker if running
            if self.inference_worker:
                self.inference_worker.stop()
                self.inference_worker = None

            # Create new inference worker
            self.inference_worker = InferenceWorkerThread(
                self.inference_queue,
                mode,
                model_name,
                min_interval=0.5,
                mse_threshold=500,
                once_mode=once_mode
            )
            self.inference_worker.input_params = input_params
            self.inference_worker.inference_result_signal.connect(self.main_window.handle_inference_result)
            self.inference_worker.start()
            print(f"Inference worker started for module: {mode}/{model_name}")

            # Start camera if needed
            if tool_type == "video":
                # If camera was previously active but disconnected for image processing
                if self._camera_was_active and self.main_window.media_controller.video_thread is not None:
                    # Reconnect the signal
                    self.main_window.media_controller.video_thread.change_pixmap_signal.connect(
                        self.main_window.media_controller.update_image
                    )
                    print("Camera reconnected for video processing")
                else:
                    # Start camera normally
                    self.main_window.media_controller.start_camera()
            else:
                # For image tools, temporarily pause the camera but don't stop it completely.
                # This allows switching back to video tools without restarting the camera.
                if self.main_window.media_controller.video_thread is not None:
                    # Save current state to indicate camera was running
                    self._camera_was_active = True
                    # Disconnect signal to prevent processing frames during image inference
                    self.main_window.media_controller.video_thread.change_pixmap_signal.disconnect()
                    print("Camera paused for image processing")
                else:
                    self._camera_was_active = False

            return True

        except Exception as e:
            print(f"Error while selecting tool: {e}")
            import traceback
            print(traceback.format_exc())
            return False

    def _clear_inference_queue(self):
        """Clear all pending items from the inference queue."""
        try:
            # Drain the existing queue
            while not self.inference_queue.empty():
                try:
                    self.inference_queue.get_nowait()
                except queue.Empty:
                    break
            print("Inference queue cleared")
        except Exception as e:
            print(f"Error while clearing the inference queue: {e}")

    def add_frame_to_queue(self, frame):
        """Add a video frame to the inference queue."""
        try:
            # Update the original frame size (used for bounding-box scaling)
            if frame is not None and hasattr(frame, 'shape'):
                height, width = frame.shape[:2]
                self.original_frame_width = width
                self.original_frame_height = height

            # Add the frame to the queue
            if not self.inference_queue.full():
                self.inference_queue.put(frame)
        except Exception as e:
            print(f"Error while adding a frame to the queue: {e}")
            import traceback
            print(traceback.format_exc())

    def stop_inference(self):
        """Stop the inference worker."""
        if self.inference_worker:
            self.inference_worker.stop()
            self.inference_worker = None

    def process_uploaded_image(self, file_path):
        """Process an uploaded image and run inference on it."""
        try:
            if not os.path.exists(file_path):
                print(f"Error: file does not exist: {file_path}")
                return

            # Clear the inference queue so only the latest image is processed
            self._clear_inference_queue()

            # Read the image
            img = cv2.imread(file_path)
            if img is None:
                print(f"Error: unable to read image: {file_path}")
                return

            # Update the inference worker parameters
            if self.inference_worker:
                self.inference_worker.input_params["file_path"] = file_path

                # Add the image to the inference queue
                if not self.inference_queue.full():
                    self.inference_queue.put(img)
                    print(f"Added image {file_path} to the inference queue")
                else:
                    print("Warning: inference queue is full")
            else:
                print("Error: inference worker is not initialized")
        except Exception as e:
            print(f"Error while processing the uploaded image: {e}")
            import traceback
            print(traceback.format_exc())
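
# Illustrative note (inferred from the calls above, not a documented contract):
# the collaborators passed to InferenceController are expected to provide:
#   main_window        -- handle_inference_result (slot for inference_result_signal),
#                         an optional `destination` attribute with an uploaded file path,
#                         and `media_controller` with `video_thread`, `update_image`,
#                         and `start_camera()`.
#   device_controller  -- get_selected_device() and a `connected_devices` list.
# A typical (hypothetical) wiring might look like:
#   controller = InferenceController(main_window, device_controller)
#   controller.select_tool(tool_config)        # from the tool-selection UI
#   controller.add_frame_to_queue(frame)       # from the camera capture loop
#   controller.stop_inference()                # on shutdown or when switching tools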