Feature: Add custom model upload and inference support
- device_controller: Add connect_device, disconnect_device, get_device_group methods - inference_controller: Add select_custom_tool method for custom model inference - custom_inference_worker: New worker thread for custom model inference with YOLO V5 post-processing - custom_model_block: New UI component for uploading custom .nef model and firmware files 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
c8be1db25e
commit
09156cce94
@ -3,33 +3,49 @@ from PyQt5.QtWidgets import QWidget, QListWidgetItem
|
||||
from PyQt5.QtGui import QPixmap, QIcon
|
||||
from PyQt5.QtCore import Qt
|
||||
import os
|
||||
import kp # 新增 kp 模組的引入
|
||||
|
||||
from src.services.device_service import check_available_device
|
||||
from src.config import UXUI_ASSETS, DongleModelMap, DongleIconMap
|
||||
from src.config import UXUI_ASSETS, DongleModelMap, DongleIconMap, FW_DIR
|
||||
|
||||
class DeviceController:
|
||||
def __init__(self, main_window):
|
||||
self.main_window = main_window
|
||||
self.selected_device = None
|
||||
self.connected_devices = []
|
||||
self.device_group = None # 新增儲存連接的 device_group
|
||||
|
||||
def refresh_devices(self):
|
||||
"""Refresh the list of connected devices"""
|
||||
try:
|
||||
print("Refreshing devices...")
|
||||
print("[CTRL] Refreshing devices...")
|
||||
device_descriptors = check_available_device()
|
||||
print("[CTRL] check_available_device 已返回")
|
||||
print(f"[CTRL] device_descriptors 類型: {type(device_descriptors)}")
|
||||
|
||||
# 分開訪問屬性以便調試
|
||||
print("[CTRL] 嘗試訪問 device_descriptor_number...")
|
||||
desc_num = device_descriptors.device_descriptor_number
|
||||
print(f"[CTRL] device_descriptor_number: {desc_num}")
|
||||
|
||||
self.connected_devices = []
|
||||
# print(self.connected_devices)
|
||||
|
||||
if device_descriptors.device_descriptor_number > 0:
|
||||
print("[DEBUG] 開始 parse_and_store_devices...")
|
||||
self.parse_and_store_devices(device_descriptors.device_descriptor_list)
|
||||
print("[DEBUG] parse_and_store_devices 完成")
|
||||
print("[DEBUG] 開始 display_devices...")
|
||||
self.display_devices(device_descriptors.device_descriptor_list)
|
||||
print("[DEBUG] display_devices 完成")
|
||||
return True
|
||||
else:
|
||||
print("[DEBUG] 沒有檢測到設備")
|
||||
self.main_window.show_no_device_gif()
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Error in refresh_devices: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def parse_and_store_devices(self, devices):
|
||||
@ -114,17 +130,104 @@ class DeviceController:
|
||||
return self.selected_device
|
||||
|
||||
def select_device(self, device, list_item, list_widget):
|
||||
"""Select a device and update UI"""
|
||||
"""選擇設備(不自動連接和載入 firmware)"""
|
||||
self.selected_device = device
|
||||
print("Selected dongle:", device)
|
||||
print("選擇 dongle:", device)
|
||||
|
||||
# Update list item visual selection
|
||||
# 更新列表項目的視覺選擇
|
||||
for index in range(list_widget.count()):
|
||||
item = list_widget.item(index)
|
||||
widget = list_widget.itemWidget(item)
|
||||
if widget: # Check if widget exists before setting style
|
||||
if widget: # 檢查 widget 是否存在再設定樣式
|
||||
widget.setStyleSheet("background: none;")
|
||||
|
||||
list_item_widget = list_widget.itemWidget(list_item)
|
||||
if list_item_widget: # Check if widget exists before setting style
|
||||
if list_item_widget: # 檢查 widget 是否存在再設定樣式
|
||||
list_item_widget.setStyleSheet("background-color: lightblue;")
|
||||
|
||||
def connect_device(self):
|
||||
"""連接選定的設備並上傳固件"""
|
||||
if not self.selected_device:
|
||||
print("未選擇設備,無法連接")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 取得 USB port ID
|
||||
if isinstance(self.selected_device, dict):
|
||||
usb_port_id = self.selected_device.get("usb_port_id", 0)
|
||||
product_id = self.selected_device.get("product_id", 0)
|
||||
else:
|
||||
usb_port_id = getattr(self.selected_device, "usb_port_id", 0)
|
||||
product_id = getattr(self.selected_device, "product_id", 0)
|
||||
|
||||
# 將 product_id 轉換為小寫十六進制字串
|
||||
if isinstance(product_id, int):
|
||||
product_id = hex(product_id).lower()
|
||||
elif isinstance(product_id, str) and not product_id.startswith('0x'):
|
||||
try:
|
||||
product_id = hex(int(product_id, 0)).lower()
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# 對應 product_id 到 dongle 類型
|
||||
dongle = DongleModelMap.get(product_id, "unknown")
|
||||
print(f"連接設備: product_id={product_id}, mapped to={dongle}")
|
||||
|
||||
# 設置固件路徑
|
||||
scpu_path = os.path.join(FW_DIR, dongle, "fw_scpu.bin")
|
||||
ncpu_path = os.path.join(FW_DIR, dongle, "fw_ncpu.bin")
|
||||
|
||||
# 確認固件文件是否存在
|
||||
if not os.path.exists(scpu_path) or not os.path.exists(ncpu_path):
|
||||
print(f"固件文件不存在: {scpu_path} 或 {ncpu_path}")
|
||||
return False
|
||||
|
||||
# 連接設備
|
||||
print('[連接設備]')
|
||||
self.device_group = kp.core.connect_devices(usb_port_ids=[usb_port_id])
|
||||
print(' - 連接成功')
|
||||
|
||||
# # 設置超時
|
||||
# print('[設置超時]')
|
||||
# kp.core.set_timeout(device_group=self.device_group, milliseconds=10000)
|
||||
# print(' - 設置成功')
|
||||
|
||||
# 上傳固件
|
||||
print('[上傳固件]')
|
||||
kp.core.load_firmware_from_file(
|
||||
device_group=self.device_group,
|
||||
scpu_fw_path=scpu_path,
|
||||
ncpu_fw_path=ncpu_path
|
||||
)
|
||||
print(' - 上傳成功')
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"連接設備時發生錯誤: {e}")
|
||||
# 發生錯誤時嘗試清理
|
||||
if self.device_group:
|
||||
try:
|
||||
kp.core.disconnect_devices(device_group=self.device_group)
|
||||
except Exception:
|
||||
pass
|
||||
self.device_group = None
|
||||
return False
|
||||
|
||||
def disconnect_device(self):
|
||||
"""中斷與設備的連接"""
|
||||
if self.device_group:
|
||||
try:
|
||||
print('[中斷設備連接]')
|
||||
kp.core.disconnect_devices(device_group=self.device_group)
|
||||
print(' - 已中斷連接')
|
||||
self.device_group = None
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"中斷設備連接時發生錯誤: {e}")
|
||||
self.device_group = None
|
||||
return False
|
||||
return True # 如果沒有連接的設備,視為成功
|
||||
|
||||
def get_device_group(self):
|
||||
"""獲取已連接的設備群組"""
|
||||
return self.device_group
|
||||
@ -2,8 +2,10 @@
|
||||
import os, queue, cv2, json
|
||||
from PyQt5.QtWidgets import QMessageBox, QApplication
|
||||
from PyQt5.QtCore import QTimer, Qt
|
||||
import kp # 新增 kp 模組的引入
|
||||
|
||||
from src.models.inference_worker import InferenceWorkerThread
|
||||
from src.models.custom_inference_worker import CustomInferenceWorkerThread
|
||||
from src.config import UTILS_DIR, FW_DIR, DongleModelMap
|
||||
|
||||
class InferenceController:
|
||||
@ -11,13 +13,14 @@ class InferenceController:
|
||||
self.main_window = main_window
|
||||
self.device_controller = device_controller
|
||||
self.inference_worker = None
|
||||
self.inference_queue = queue.Queue(maxsize=10)
|
||||
self.inference_queue = queue.Queue(maxsize=5)
|
||||
self.current_tool_config = None
|
||||
self.previous_tool_config = None
|
||||
self._camera_was_active = False
|
||||
# 儲存原始影格尺寸,用於邊界框縮放計算
|
||||
self.original_frame_width = 640 # 預設值
|
||||
self.original_frame_height = 480 # 預設值
|
||||
self.model_descriptor = None
|
||||
|
||||
def select_tool(self, tool_config):
|
||||
"""選擇AI工具並配置推論"""
|
||||
@ -61,6 +64,11 @@ class InferenceController:
|
||||
# 準備輸入參數
|
||||
input_params = tool_config.get("input_parameters", {}).copy()
|
||||
|
||||
# 取得連接的設備群組
|
||||
device_group = self.device_controller.get_device_group()
|
||||
# 新增設備群組到輸入參數
|
||||
input_params["device_group"] = device_group
|
||||
|
||||
# Configure device-related settings
|
||||
selected_device = self.device_controller.get_selected_device()
|
||||
if selected_device:
|
||||
@ -100,7 +108,7 @@ class InferenceController:
|
||||
msgBox.exec_()
|
||||
return False
|
||||
|
||||
# Configure firmware paths
|
||||
# 僅添加固件路徑作為參考,實際不再使用 (因為裝置已經在選擇時連接)
|
||||
scpu_path = os.path.join(FW_DIR, dongle, "fw_scpu.bin")
|
||||
ncpu_path = os.path.join(FW_DIR, dongle, "fw_ncpu.bin")
|
||||
input_params["scpu_path"] = scpu_path
|
||||
@ -139,6 +147,22 @@ class InferenceController:
|
||||
model_file_path = os.path.join(model_path, model_file)
|
||||
input_params["model"] = model_file_path
|
||||
|
||||
# 上傳模型 (新增)
|
||||
if device_group:
|
||||
try:
|
||||
print('[上傳模型]')
|
||||
self.model_descriptor = kp.core.load_model_from_file(
|
||||
device_group=device_group,
|
||||
file_path=model_file_path
|
||||
)
|
||||
print(' - 上傳成功')
|
||||
|
||||
# 將模型描述符添加到輸入參數
|
||||
input_params["model_descriptor"] = self.model_descriptor
|
||||
except Exception as e:
|
||||
print(f"上傳模型時發生錯誤: {e}")
|
||||
self.model_descriptor = None
|
||||
|
||||
print("Input parameters:", input_params)
|
||||
|
||||
# Stop existing inference worker if running
|
||||
@ -151,7 +175,7 @@ class InferenceController:
|
||||
self.inference_queue,
|
||||
mode,
|
||||
model_name,
|
||||
min_interval=0.5,
|
||||
min_interval=2,
|
||||
mse_threshold=500,
|
||||
once_mode=once_mode
|
||||
)
|
||||
@ -189,6 +213,7 @@ class InferenceController:
|
||||
print(f"選擇工具時發生錯誤: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def _clear_inference_queue(self):
|
||||
"""清空推論佇列中的所有數據"""
|
||||
@ -258,3 +283,87 @@ class InferenceController:
|
||||
print(f"處理上傳圖片時發生錯誤: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
|
||||
def select_custom_tool(self, tool_config):
|
||||
"""
|
||||
選擇自訂模型工具並配置推論
|
||||
|
||||
Args:
|
||||
tool_config: 包含自訂模型配置的字典,必須包含:
|
||||
- custom_model_path: .nef 模型檔案路徑
|
||||
- custom_scpu_path: SCPU firmware 路徑
|
||||
- custom_ncpu_path: NCPU firmware 路徑
|
||||
"""
|
||||
try:
|
||||
print("選擇自訂模型:", tool_config.get("display_name"))
|
||||
self.current_tool_config = tool_config
|
||||
|
||||
# 清空推論佇列
|
||||
self._clear_inference_queue()
|
||||
|
||||
# 儲存當前工具類型以供下次比較
|
||||
self.previous_tool_config = tool_config
|
||||
|
||||
# 準備輸入參數
|
||||
input_params = tool_config.get("input_parameters", {}).copy()
|
||||
|
||||
# 添加自訂模型路徑
|
||||
input_params["custom_model_path"] = tool_config.get("custom_model_path")
|
||||
input_params["custom_scpu_path"] = tool_config.get("custom_scpu_path")
|
||||
input_params["custom_ncpu_path"] = tool_config.get("custom_ncpu_path")
|
||||
input_params["custom_labels"] = tool_config.get("custom_labels")
|
||||
|
||||
# 取得設備相關設定
|
||||
selected_device = self.device_controller.get_selected_device()
|
||||
if selected_device:
|
||||
if isinstance(selected_device, dict):
|
||||
input_params["usb_port_id"] = selected_device.get("usb_port_id", 0)
|
||||
else:
|
||||
input_params["usb_port_id"] = getattr(selected_device, "usb_port_id", 0)
|
||||
else:
|
||||
devices = self.device_controller.connected_devices
|
||||
if devices and len(devices) > 0:
|
||||
input_params["usb_port_id"] = devices[0].get("usb_port_id", 0)
|
||||
print("警告: 未選擇特定設備,使用第一個可用設備")
|
||||
else:
|
||||
input_params["usb_port_id"] = 0
|
||||
print("警告: 沒有已連接的設備,使用預設 usb_port_id 0")
|
||||
|
||||
print("自訂模型輸入參數:", input_params)
|
||||
|
||||
# 停止現有的推論工作器
|
||||
if self.inference_worker:
|
||||
self.inference_worker.stop()
|
||||
self.inference_worker = None
|
||||
|
||||
# 創建新的自訂推論工作器
|
||||
self.inference_worker = CustomInferenceWorkerThread(
|
||||
self.inference_queue,
|
||||
min_interval=0.5,
|
||||
mse_threshold=500
|
||||
)
|
||||
self.inference_worker.input_params = input_params
|
||||
self.inference_worker.inference_result_signal.connect(
|
||||
self.main_window.handle_inference_result
|
||||
)
|
||||
self.inference_worker.start()
|
||||
print("自訂模型推論工作器已啟動")
|
||||
|
||||
# 啟動相機 (自訂模型預設使用視訊模式)
|
||||
tool_type = tool_config.get("input_info", {}).get("type", "video")
|
||||
if tool_type == "video":
|
||||
if self._camera_was_active and self.main_window.media_controller.video_thread is not None:
|
||||
self.main_window.media_controller.video_thread.change_pixmap_signal.connect(
|
||||
self.main_window.media_controller.update_image
|
||||
)
|
||||
print("相機已重新連接用於視訊處理")
|
||||
else:
|
||||
self.main_window.media_controller.start_camera()
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"選擇自訂模型時發生錯誤: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
511
src/models/custom_inference_worker.py
Normal file
511
src/models/custom_inference_worker.py
Normal file
@ -0,0 +1,511 @@
|
||||
"""
|
||||
Custom Inference Worker
|
||||
使用使用者上傳的自訂模型進行推論
|
||||
前後處理使用 script.py 中定義的 YOLO V5 處理邏輯
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import queue
|
||||
import cv2
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from PyQt5.QtCore import QThread, pyqtSignal
|
||||
|
||||
import kp
|
||||
from kp.KPBaseClass.ValueBase import ValueRepresentBase
|
||||
|
||||
|
||||
# COCO 數據集的類別名稱
|
||||
COCO_CLASSES = [
|
||||
'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',
|
||||
'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat',
|
||||
'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack',
|
||||
'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
|
||||
'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
|
||||
'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
|
||||
'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
|
||||
'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
|
||||
]
|
||||
|
||||
|
||||
class ExampleBoundingBox(ValueRepresentBase):
|
||||
"""Bounding box descriptor."""
|
||||
|
||||
def __init__(self,
|
||||
x1: int = 0,
|
||||
y1: int = 0,
|
||||
x2: int = 0,
|
||||
y2: int = 0,
|
||||
score: float = 0,
|
||||
class_num: int = 0):
|
||||
self.x1 = x1
|
||||
self.y1 = y1
|
||||
self.x2 = x2
|
||||
self.y2 = y2
|
||||
self.score = score
|
||||
self.class_num = class_num
|
||||
|
||||
def get_member_variable_dict(self) -> dict:
|
||||
return {
|
||||
'x1': self.x1,
|
||||
'y1': self.y1,
|
||||
'x2': self.x2,
|
||||
'y2': self.y2,
|
||||
'score': self.score,
|
||||
'class_num': self.class_num
|
||||
}
|
||||
|
||||
|
||||
class ExampleYoloResult(ValueRepresentBase):
|
||||
"""YOLO output result descriptor."""
|
||||
|
||||
def __init__(self,
|
||||
class_count: int = 0,
|
||||
box_count: int = 0,
|
||||
box_list: List[ExampleBoundingBox] = None):
|
||||
self.class_count = class_count
|
||||
self.box_count = box_count
|
||||
self.box_list = box_list if box_list is not None else []
|
||||
|
||||
def get_member_variable_dict(self) -> dict:
|
||||
member_variable_dict = {
|
||||
'class_count': self.class_count,
|
||||
'box_count': self.box_count,
|
||||
'box_list': {}
|
||||
}
|
||||
for idx, box_element in enumerate(self.box_list):
|
||||
member_variable_dict['box_list'][idx] = box_element.get_member_variable_dict()
|
||||
return member_variable_dict
|
||||
|
||||
|
||||
# YOLO 常數
|
||||
YOLO_V3_CELL_BOX_NUM = 3
|
||||
NMS_THRESH_YOLOV5 = 0.5
|
||||
YOLO_MAX_DETECTION_PER_CLASS = 100
|
||||
|
||||
YOLO_V5_ANCHERS = np.array([
|
||||
[[10, 13], [16, 30], [33, 23]],
|
||||
[[30, 61], [62, 45], [59, 119]],
|
||||
[[116, 90], [156, 198], [373, 326]]
|
||||
])
|
||||
|
||||
|
||||
def _sigmoid(x):
|
||||
return 1. / (1. + np.exp(-x))
|
||||
|
||||
|
||||
def _iou(box_src, boxes_dst):
|
||||
max_x1 = np.maximum(box_src[0], boxes_dst[:, 0])
|
||||
max_y1 = np.maximum(box_src[1], boxes_dst[:, 1])
|
||||
min_x2 = np.minimum(box_src[2], boxes_dst[:, 2])
|
||||
min_y2 = np.minimum(box_src[3], boxes_dst[:, 3])
|
||||
|
||||
area_intersection = np.maximum(0, (min_x2 - max_x1)) * np.maximum(0, (min_y2 - max_y1))
|
||||
area_src = (box_src[2] - box_src[0]) * (box_src[3] - box_src[1])
|
||||
area_dst = (boxes_dst[:, 2] - boxes_dst[:, 0]) * (boxes_dst[:, 3] - boxes_dst[:, 1])
|
||||
area_union = area_src + area_dst - area_intersection
|
||||
|
||||
iou = area_intersection / area_union
|
||||
return iou
|
||||
|
||||
|
||||
def _boxes_scale(boxes, hardware_preproc_info: kp.HwPreProcInfo):
|
||||
"""Scale boxes based on hardware preprocessing info."""
|
||||
ratio_w = hardware_preproc_info.img_width / hardware_preproc_info.resized_img_width
|
||||
ratio_h = hardware_preproc_info.img_height / hardware_preproc_info.resized_img_height
|
||||
|
||||
boxes[..., :4] = boxes[..., :4] - np.array([
|
||||
hardware_preproc_info.pad_left, hardware_preproc_info.pad_top,
|
||||
hardware_preproc_info.pad_left, hardware_preproc_info.pad_top
|
||||
])
|
||||
boxes[..., :4] = boxes[..., :4] * np.array([ratio_w, ratio_h, ratio_w, ratio_h])
|
||||
|
||||
return boxes
|
||||
|
||||
|
||||
def post_process_yolo_v5(inference_float_node_output_list: List[kp.InferenceFloatNodeOutput],
|
||||
hardware_preproc_info: kp.HwPreProcInfo,
|
||||
thresh_value: float,
|
||||
with_sigmoid: bool = True) -> ExampleYoloResult:
|
||||
"""YOLO V5 post-processing function."""
|
||||
feature_map_list = []
|
||||
candidate_boxes_list = []
|
||||
|
||||
for i in range(len(inference_float_node_output_list)):
|
||||
anchor_offset = int(inference_float_node_output_list[i].shape[1] / YOLO_V3_CELL_BOX_NUM)
|
||||
feature_map = inference_float_node_output_list[i].ndarray.transpose((0, 2, 3, 1))
|
||||
feature_map = _sigmoid(feature_map) if with_sigmoid else feature_map
|
||||
feature_map = feature_map.reshape((
|
||||
feature_map.shape[0],
|
||||
feature_map.shape[1],
|
||||
feature_map.shape[2],
|
||||
YOLO_V3_CELL_BOX_NUM,
|
||||
anchor_offset
|
||||
))
|
||||
|
||||
ratio_w = hardware_preproc_info.model_input_width / inference_float_node_output_list[i].shape[3]
|
||||
ratio_h = hardware_preproc_info.model_input_height / inference_float_node_output_list[i].shape[2]
|
||||
nrows = inference_float_node_output_list[i].shape[2]
|
||||
ncols = inference_float_node_output_list[i].shape[3]
|
||||
grids = np.expand_dims(np.stack(np.meshgrid(np.arange(ncols), np.arange(nrows)), 2), axis=0)
|
||||
|
||||
for anchor_idx in range(YOLO_V3_CELL_BOX_NUM):
|
||||
feature_map[..., anchor_idx, 0:2] = (
|
||||
feature_map[..., anchor_idx, 0:2] * 2. - 0.5 + grids
|
||||
) * np.array([ratio_h, ratio_w])
|
||||
feature_map[..., anchor_idx, 2:4] = (
|
||||
feature_map[..., anchor_idx, 2:4] * 2
|
||||
) ** 2 * YOLO_V5_ANCHERS[i][anchor_idx]
|
||||
|
||||
feature_map[..., anchor_idx, 0:2] = (
|
||||
feature_map[..., anchor_idx, 0:2] - (feature_map[..., anchor_idx, 2:4] / 2.)
|
||||
)
|
||||
feature_map[..., anchor_idx, 2:4] = (
|
||||
feature_map[..., anchor_idx, 0:2] + feature_map[..., anchor_idx, 2:4]
|
||||
)
|
||||
|
||||
feature_map = _boxes_scale(boxes=feature_map, hardware_preproc_info=hardware_preproc_info)
|
||||
feature_map_list.append(feature_map)
|
||||
|
||||
predict_bboxes = np.concatenate([
|
||||
np.reshape(feature_map, (-1, feature_map.shape[-1])) for feature_map in feature_map_list
|
||||
], axis=0)
|
||||
predict_bboxes[..., 5:] = np.repeat(
|
||||
predict_bboxes[..., 4][..., np.newaxis],
|
||||
predict_bboxes[..., 5:].shape[1],
|
||||
axis=1
|
||||
) * predict_bboxes[..., 5:]
|
||||
predict_bboxes_mask = (predict_bboxes[..., 5:] > thresh_value).sum(axis=1)
|
||||
predict_bboxes = predict_bboxes[predict_bboxes_mask >= 1]
|
||||
|
||||
# NMS
|
||||
for class_idx in range(5, predict_bboxes.shape[1]):
|
||||
candidate_boxes_mask = predict_bboxes[..., class_idx] > thresh_value
|
||||
class_good_box_count = candidate_boxes_mask.sum()
|
||||
if class_good_box_count == 1:
|
||||
candidate_boxes_list.append(
|
||||
ExampleBoundingBox(
|
||||
x1=round(float(predict_bboxes[candidate_boxes_mask, 0][0]), 4),
|
||||
y1=round(float(predict_bboxes[candidate_boxes_mask, 1][0]), 4),
|
||||
x2=round(float(predict_bboxes[candidate_boxes_mask, 2][0]), 4),
|
||||
y2=round(float(predict_bboxes[candidate_boxes_mask, 3][0]), 4),
|
||||
score=round(float(predict_bboxes[candidate_boxes_mask, class_idx][0]), 4),
|
||||
class_num=class_idx - 5
|
||||
)
|
||||
)
|
||||
elif class_good_box_count > 1:
|
||||
candidate_boxes = predict_bboxes[candidate_boxes_mask].copy()
|
||||
candidate_boxes = candidate_boxes[candidate_boxes[:, class_idx].argsort()][::-1]
|
||||
|
||||
for candidate_box_idx in range(candidate_boxes.shape[0] - 1):
|
||||
if 0 != candidate_boxes[candidate_box_idx][class_idx]:
|
||||
remove_mask = _iou(
|
||||
box_src=candidate_boxes[candidate_box_idx],
|
||||
boxes_dst=candidate_boxes[candidate_box_idx + 1:]
|
||||
) > NMS_THRESH_YOLOV5
|
||||
candidate_boxes[candidate_box_idx + 1:][remove_mask, class_idx] = 0
|
||||
|
||||
good_count = 0
|
||||
for candidate_box_idx in range(candidate_boxes.shape[0]):
|
||||
if candidate_boxes[candidate_box_idx, class_idx] > 0:
|
||||
candidate_boxes_list.append(
|
||||
ExampleBoundingBox(
|
||||
x1=round(float(candidate_boxes[candidate_box_idx, 0]), 4),
|
||||
y1=round(float(candidate_boxes[candidate_box_idx, 1]), 4),
|
||||
x2=round(float(candidate_boxes[candidate_box_idx, 2]), 4),
|
||||
y2=round(float(candidate_boxes[candidate_box_idx, 3]), 4),
|
||||
score=round(float(candidate_boxes[candidate_box_idx, class_idx]), 4),
|
||||
class_num=class_idx - 5
|
||||
)
|
||||
)
|
||||
good_count += 1
|
||||
|
||||
if YOLO_MAX_DETECTION_PER_CLASS == good_count:
|
||||
break
|
||||
|
||||
for idx, candidate_boxes in enumerate(candidate_boxes_list):
|
||||
candidate_boxes_list[idx].x1 = 0 if (candidate_boxes_list[idx].x1 + 0.5 < 0) else int(
|
||||
candidate_boxes_list[idx].x1 + 0.5)
|
||||
candidate_boxes_list[idx].y1 = 0 if (candidate_boxes_list[idx].y1 + 0.5 < 0) else int(
|
||||
candidate_boxes_list[idx].y1 + 0.5)
|
||||
candidate_boxes_list[idx].x2 = int(hardware_preproc_info.img_width - 1) if (
|
||||
candidate_boxes_list[idx].x2 + 0.5 > hardware_preproc_info.img_width - 1
|
||||
) else int(candidate_boxes_list[idx].x2 + 0.5)
|
||||
candidate_boxes_list[idx].y2 = int(hardware_preproc_info.img_height - 1) if (
|
||||
candidate_boxes_list[idx].y2 + 0.5 > hardware_preproc_info.img_height - 1
|
||||
) else int(candidate_boxes_list[idx].y2 + 0.5)
|
||||
|
||||
return ExampleYoloResult(
|
||||
class_count=predict_bboxes.shape[1] - 5 if len(predict_bboxes) > 0 else 0,
|
||||
box_count=len(candidate_boxes_list),
|
||||
box_list=candidate_boxes_list
|
||||
)
|
||||
|
||||
|
||||
def preprocess_frame(frame, target_size=640):
|
||||
"""
|
||||
預處理影像
|
||||
|
||||
Args:
|
||||
frame: 原始 BGR 影像
|
||||
target_size: 目標大小 (default 640 for YOLO)
|
||||
|
||||
Returns:
|
||||
processed_frame: 處理後的影像 (BGR565 格式)
|
||||
original_width: 原始寬度
|
||||
original_height: 原始高度
|
||||
"""
|
||||
if frame is None:
|
||||
raise Exception("輸入的 frame 為 None")
|
||||
|
||||
original_height, original_width = frame.shape[:2]
|
||||
|
||||
# 調整大小
|
||||
resized_frame = cv2.resize(frame, (target_size, target_size))
|
||||
|
||||
# 轉換為 BGR565 格式
|
||||
frame_bgr565 = cv2.cvtColor(resized_frame, cv2.COLOR_BGR2BGR565)
|
||||
|
||||
return frame_bgr565, original_width, original_height
|
||||
|
||||
|
||||
def postprocess(output_list, hw_preproc_info, original_width, original_height, target_size=640, thresh=0.2):
|
||||
"""
|
||||
後處理 YOLO 輸出
|
||||
|
||||
Args:
|
||||
output_list: 模型輸出節點列表
|
||||
hw_preproc_info: 硬體預處理資訊
|
||||
original_width: 原始影像寬度
|
||||
original_height: 原始影像高度
|
||||
target_size: 縮放目標大小
|
||||
thresh: 閾值
|
||||
|
||||
Returns:
|
||||
yolo_result: YOLO 偵測結果
|
||||
"""
|
||||
yolo_result = post_process_yolo_v5(
|
||||
inference_float_node_output_list=output_list,
|
||||
hardware_preproc_info=hw_preproc_info,
|
||||
thresh_value=thresh
|
||||
)
|
||||
|
||||
# 調整邊界框座標以符合原始尺寸
|
||||
width_ratio = original_width / target_size
|
||||
height_ratio = original_height / target_size
|
||||
|
||||
for box in yolo_result.box_list:
|
||||
box.x1 = int(box.x1 * width_ratio)
|
||||
box.y1 = int(box.y1 * height_ratio)
|
||||
box.x2 = int(box.x2 * width_ratio)
|
||||
box.y2 = int(box.y2 * height_ratio)
|
||||
|
||||
return yolo_result
|
||||
|
||||
|
||||
class CustomInferenceWorkerThread(QThread):
|
||||
"""
|
||||
自訂模型推論工作線程
|
||||
使用使用者上傳的模型和韌體進行推論
|
||||
"""
|
||||
inference_result_signal = pyqtSignal(object)
|
||||
|
||||
def __init__(self, frame_queue, min_interval=0.5, mse_threshold=500):
|
||||
super().__init__()
|
||||
self.frame_queue = frame_queue
|
||||
self.min_interval = min_interval
|
||||
self.mse_threshold = mse_threshold
|
||||
self._running = True
|
||||
self.last_inference_time = 0
|
||||
self.last_frame = None
|
||||
self.cached_result = None
|
||||
self.input_params = {}
|
||||
|
||||
# 設備和模型相關
|
||||
self.device_group = None
|
||||
self.model_descriptor = None
|
||||
self.is_initialized = False
|
||||
|
||||
# 自訂標籤
|
||||
self.custom_labels = None
|
||||
|
||||
def initialize_device(self):
|
||||
"""初始化設備、上傳韌體和模型"""
|
||||
try:
|
||||
model_path = self.input_params.get("custom_model_path")
|
||||
scpu_path = self.input_params.get("custom_scpu_path")
|
||||
ncpu_path = self.input_params.get("custom_ncpu_path")
|
||||
port_id = self.input_params.get("usb_port_id", 0)
|
||||
|
||||
# 載入自訂標籤
|
||||
self.custom_labels = self.input_params.get("custom_labels")
|
||||
if self.custom_labels:
|
||||
print(f'[自訂標籤] 已載入 {len(self.custom_labels)} 個類別')
|
||||
else:
|
||||
print('[自訂標籤] 未提供,使用預設 COCO 類別')
|
||||
|
||||
if not all([model_path, scpu_path, ncpu_path]):
|
||||
print("缺少必要的檔案路徑")
|
||||
return False
|
||||
|
||||
# 連接設備
|
||||
print('[連接裝置]')
|
||||
self.device_group = kp.core.connect_devices(usb_port_ids=[port_id])
|
||||
kp.core.set_timeout(device_group=self.device_group, milliseconds=5000)
|
||||
print(' - 連接成功')
|
||||
|
||||
# 上傳韌體
|
||||
print('[上傳韌體]')
|
||||
kp.core.load_firmware_from_file(self.device_group, scpu_path, ncpu_path)
|
||||
print(' - 韌體上傳成功')
|
||||
|
||||
# 上傳模型
|
||||
print('[上傳模型]')
|
||||
self.model_descriptor = kp.core.load_model_from_file(
|
||||
self.device_group,
|
||||
file_path=model_path
|
||||
)
|
||||
print(' - 模型上傳成功')
|
||||
|
||||
self.is_initialized = True
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"初始化設備時發生錯誤: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def run_single_inference(self, frame):
|
||||
"""執行單次推論"""
|
||||
try:
|
||||
if not self.is_initialized:
|
||||
if not self.initialize_device():
|
||||
return None
|
||||
|
||||
# 預處理
|
||||
img_processed, original_width, original_height = preprocess_frame(frame)
|
||||
|
||||
# 建立推論描述符
|
||||
descriptor = kp.GenericImageInferenceDescriptor(
|
||||
model_id=self.model_descriptor.models[0].id,
|
||||
inference_number=0,
|
||||
input_node_image_list=[
|
||||
kp.GenericInputNodeImage(
|
||||
image=img_processed,
|
||||
image_format=kp.ImageFormat.KP_IMAGE_FORMAT_RGB565,
|
||||
resize_mode=kp.ResizeMode.KP_RESIZE_ENABLE,
|
||||
padding_mode=kp.PaddingMode.KP_PADDING_CORNER,
|
||||
normalize_mode=kp.NormalizeMode.KP_NORMALIZE_KNERON
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# 執行推論
|
||||
kp.inference.generic_image_inference_send(self.device_group, descriptor)
|
||||
result = kp.inference.generic_image_inference_receive(self.device_group)
|
||||
|
||||
# 取得輸出節點
|
||||
output_list = []
|
||||
for node_idx in range(result.header.num_output_node):
|
||||
node_output = kp.inference.generic_inference_retrieve_float_node(
|
||||
node_idx=node_idx,
|
||||
generic_raw_result=result,
|
||||
channels_ordering=kp.ChannelOrdering.KP_CHANNEL_ORDERING_CHW
|
||||
)
|
||||
output_list.append(node_output)
|
||||
|
||||
# 後處理
|
||||
yolo_result = postprocess(
|
||||
output_list,
|
||||
result.header.hw_pre_proc_info_list[0],
|
||||
original_width,
|
||||
original_height
|
||||
)
|
||||
|
||||
# 轉換為標準格式
|
||||
bounding_boxes = [
|
||||
[box.x1, box.y1, box.x2, box.y2] for box in yolo_result.box_list
|
||||
]
|
||||
|
||||
# 使用自訂標籤或預設 COCO 類別
|
||||
labels_to_use = self.custom_labels if self.custom_labels else COCO_CLASSES
|
||||
|
||||
results = []
|
||||
for box in yolo_result.box_list:
|
||||
if 0 <= box.class_num < len(labels_to_use):
|
||||
results.append(labels_to_use[box.class_num])
|
||||
else:
|
||||
results.append(f"class_{box.class_num}")
|
||||
|
||||
return {
|
||||
"num_boxes": len(yolo_result.box_list),
|
||||
"bounding boxes": bounding_boxes,
|
||||
"results": results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"推論時發生錯誤: {e}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return None
|
||||
|
||||
def run(self):
|
||||
"""主執行循環"""
|
||||
while self._running:
|
||||
try:
|
||||
frame = self.frame_queue.get(timeout=0.1)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
current_time = time.time()
|
||||
if current_time - self.last_inference_time < self.min_interval:
|
||||
continue
|
||||
|
||||
# MSE 檢測以優化效能
|
||||
if self.last_frame is not None:
|
||||
if frame.shape != self.last_frame.shape:
|
||||
self.last_frame = None
|
||||
self.cached_result = None
|
||||
else:
|
||||
try:
|
||||
mse = np.mean((frame.astype(np.float32) - self.last_frame.astype(np.float32)) ** 2)
|
||||
if mse < self.mse_threshold and self.cached_result is not None:
|
||||
self.inference_result_signal.emit(self.cached_result)
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"計算 MSE 時發生錯誤: {e}")
|
||||
self.last_frame = None
|
||||
self.cached_result = None
|
||||
|
||||
# 執行推論
|
||||
result = self.run_single_inference(frame)
|
||||
|
||||
self.last_inference_time = current_time
|
||||
self.last_frame = frame.copy()
|
||||
self.cached_result = result
|
||||
|
||||
if result is not None:
|
||||
self.inference_result_signal.emit(result)
|
||||
|
||||
# 斷開設備連接
|
||||
self.cleanup()
|
||||
self.quit()
|
||||
|
||||
def cleanup(self):
|
||||
"""清理資源"""
|
||||
try:
|
||||
if self.device_group is not None:
|
||||
kp.core.disconnect_devices(self.device_group)
|
||||
print('[已斷開裝置]')
|
||||
self.device_group = None
|
||||
except Exception as e:
|
||||
print(f"清理資源時發生錯誤: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""停止工作線程"""
|
||||
self._running = False
|
||||
self.wait()
|
||||
self.cleanup()
|
||||
399
src/views/components/custom_model_block.py
Normal file
399
src/views/components/custom_model_block.py
Normal file
@ -0,0 +1,399 @@
|
||||
"""
|
||||
Custom Model Upload Block Component
|
||||
允許使用者上傳自己的 .nef 模型和 firmware (.bin) 檔案
|
||||
"""
|
||||
import os
|
||||
from PyQt5.QtWidgets import (
|
||||
QFrame, QVBoxLayout, QHBoxLayout, QLabel,
|
||||
QPushButton, QWidget, QFileDialog
|
||||
)
|
||||
from PyQt5.QtSvg import QSvgWidget
|
||||
from PyQt5.QtCore import Qt, pyqtSignal, QObject
|
||||
|
||||
from src.config import SECONDARY_COLOR, UXUI_ASSETS, BUTTON_STYLE
|
||||
|
||||
|
||||
class CustomModelSignals(QObject):
|
||||
"""用於發送自訂模型選擇信號"""
|
||||
model_selected = pyqtSignal(dict)
|
||||
|
||||
|
||||
class FileUploadRow(QWidget):
|
||||
"""檔案上傳行元件,顯示標籤和上傳按鈕"""
|
||||
|
||||
def __init__(self, label_text, file_filter, parent=None):
|
||||
super().__init__(parent)
|
||||
self.file_path = None
|
||||
self.file_filter = file_filter
|
||||
self.label_text = label_text
|
||||
self.init_ui()
|
||||
|
||||
def init_ui(self):
|
||||
layout = QHBoxLayout(self)
|
||||
layout.setContentsMargins(0, 2, 0, 2)
|
||||
layout.setSpacing(5)
|
||||
|
||||
# 標籤
|
||||
self.label = QLabel(self.label_text)
|
||||
self.label.setStyleSheet("color: white; font-size: 11px; border: none; background: transparent;")
|
||||
self.label.setFixedWidth(50)
|
||||
layout.addWidget(self.label)
|
||||
|
||||
# 檔案名稱顯示
|
||||
self.file_label = QLabel("未選擇")
|
||||
self.file_label.setStyleSheet("""
|
||||
color: #aaa;
|
||||
font-size: 10px;
|
||||
border: 1px solid #555;
|
||||
border-radius: 3px;
|
||||
padding: 2px 5px;
|
||||
background: rgba(0,0,0,0.2);
|
||||
""")
|
||||
self.file_label.setFixedHeight(22)
|
||||
layout.addWidget(self.file_label, 1)
|
||||
|
||||
# 上傳按鈕
|
||||
self.upload_btn = QPushButton("...")
|
||||
self.upload_btn.setFixedSize(28, 22)
|
||||
self.upload_btn.setStyleSheet("""
|
||||
QPushButton {
|
||||
background: rgba(255,255,255,0.1);
|
||||
color: white;
|
||||
border: 1px solid white;
|
||||
border-radius: 3px;
|
||||
font-size: 10px;
|
||||
}
|
||||
QPushButton:hover {
|
||||
background: rgba(255,255,255,0.2);
|
||||
}
|
||||
""")
|
||||
self.upload_btn.clicked.connect(self.select_file)
|
||||
layout.addWidget(self.upload_btn)
|
||||
|
||||
def select_file(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self,
|
||||
f"選擇 {self.label_text} 檔案",
|
||||
"",
|
||||
self.file_filter
|
||||
)
|
||||
if file_path:
|
||||
self.file_path = file_path
|
||||
# 只顯示檔案名稱,避免過長
|
||||
file_name = os.path.basename(file_path)
|
||||
# 截斷過長的檔案名稱
|
||||
if len(file_name) > 15:
|
||||
file_name = file_name[:12] + "..."
|
||||
self.file_label.setText(file_name)
|
||||
self.file_label.setStyleSheet("""
|
||||
color: #4CAF50;
|
||||
font-size: 10px;
|
||||
border: 1px solid #4CAF50;
|
||||
border-radius: 3px;
|
||||
padding: 2px 5px;
|
||||
background: rgba(76, 175, 80, 0.1);
|
||||
""")
|
||||
self.file_label.setToolTip(file_path)
|
||||
|
||||
def get_file_path(self):
|
||||
return self.file_path
|
||||
|
||||
def reset(self):
|
||||
self.file_path = None
|
||||
self.file_label.setText("未選擇")
|
||||
self.file_label.setStyleSheet("""
|
||||
color: #aaa;
|
||||
font-size: 10px;
|
||||
border: 1px solid #555;
|
||||
border-radius: 3px;
|
||||
padding: 2px 5px;
|
||||
background: rgba(0,0,0,0.2);
|
||||
""")
|
||||
self.file_label.setToolTip("")
|
||||
|
||||
|
||||
def create_custom_model_block(parent, inference_controller):
|
||||
"""
|
||||
創建自訂模型上傳區塊
|
||||
|
||||
Args:
|
||||
parent: 父元件
|
||||
inference_controller: 推論控制器實例
|
||||
|
||||
Returns:
|
||||
QFrame: 包含上傳功能的區塊
|
||||
"""
|
||||
# 創建信號物件
|
||||
signals = CustomModelSignals()
|
||||
|
||||
# 主框架
|
||||
block_frame = QFrame(parent)
|
||||
block_frame.setStyleSheet(f"""
|
||||
QFrame {{
|
||||
border: none;
|
||||
background: {SECONDARY_COLOR};
|
||||
border-radius: 15px;
|
||||
}}
|
||||
""")
|
||||
block_frame.setFixedHeight(270) # 增加高度以容納 Labels 欄位
|
||||
block_frame.setFixedWidth(240)
|
||||
|
||||
block_layout = QVBoxLayout(block_frame)
|
||||
block_layout.setContentsMargins(15, 10, 15, 10)
|
||||
block_layout.setSpacing(5)
|
||||
|
||||
# 標題行
|
||||
title_layout = QHBoxLayout()
|
||||
title_layout.setSpacing(8)
|
||||
|
||||
# 圖示
|
||||
icon_path = os.path.join(UXUI_ASSETS, "Assets_svg/ic_window_toolbox.svg")
|
||||
if os.path.exists(icon_path):
|
||||
toolbox_icon = QSvgWidget(icon_path)
|
||||
toolbox_icon.setFixedSize(28, 28)
|
||||
title_layout.addWidget(toolbox_icon)
|
||||
|
||||
title_label = QLabel("Custom Model")
|
||||
title_label.setStyleSheet("color: white; font-size: 16px; font-weight: bold; border: none; background: transparent;")
|
||||
title_layout.addWidget(title_label)
|
||||
title_layout.addStretch()
|
||||
|
||||
block_layout.addLayout(title_layout)
|
||||
|
||||
# 分隔線
|
||||
separator = QFrame()
|
||||
separator.setFrameShape(QFrame.HLine)
|
||||
separator.setStyleSheet("background-color: rgba(255,255,255,0.2); border: none;")
|
||||
separator.setFixedHeight(1)
|
||||
block_layout.addWidget(separator)
|
||||
|
||||
# Model 檔案上傳 (.nef)
|
||||
model_row = FileUploadRow("Model", "NEF Files (*.nef);;All Files (*)", block_frame)
|
||||
block_layout.addWidget(model_row)
|
||||
|
||||
# SCPU Firmware 上傳 (.bin)
|
||||
scpu_row = FileUploadRow("SCPU", "Binary Files (*.bin);;All Files (*)", block_frame)
|
||||
block_layout.addWidget(scpu_row)
|
||||
|
||||
# NCPU Firmware 上傳 (.bin)
|
||||
ncpu_row = FileUploadRow("NCPU", "Binary Files (*.bin);;All Files (*)", block_frame)
|
||||
block_layout.addWidget(ncpu_row)
|
||||
|
||||
# Labels 上傳 (.txt) - 用於自訂輸出類別
|
||||
labels_row = FileUploadRow("Labels", "Text Files (*.txt);;All Files (*)", block_frame)
|
||||
block_layout.addWidget(labels_row)
|
||||
|
||||
# 狀態標籤 - 顯示當前選擇的模型名稱
|
||||
status_label = QLabel("")
|
||||
status_label.setStyleSheet("""
|
||||
color: #4CAF50;
|
||||
font-size: 11px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
padding: 2px 0;
|
||||
""")
|
||||
status_label.setAlignment(Qt.AlignCenter)
|
||||
status_label.setWordWrap(True)
|
||||
block_layout.addWidget(status_label)
|
||||
|
||||
# 按鈕區域
|
||||
btn_layout = QHBoxLayout()
|
||||
btn_layout.setSpacing(8)
|
||||
|
||||
# 停止按鈕
|
||||
stop_btn = QPushButton("Stop")
|
||||
stop_btn.setFixedHeight(32)
|
||||
stop_btn.setStyleSheet("""
|
||||
QPushButton {
|
||||
background: transparent;
|
||||
color: #ff6b6b;
|
||||
border: 1px solid #ff6b6b;
|
||||
border-radius: 8px;
|
||||
padding: 3px 8px;
|
||||
font-size: 11px;
|
||||
}
|
||||
QPushButton:hover {
|
||||
background-color: rgba(255, 107, 107, 0.2);
|
||||
}
|
||||
""")
|
||||
btn_layout.addWidget(stop_btn)
|
||||
|
||||
# 執行按鈕
|
||||
run_btn = QPushButton("Run Inference")
|
||||
run_btn.setFixedHeight(32)
|
||||
run_btn.setStyleSheet("""
|
||||
QPushButton {
|
||||
background: rgba(76, 175, 80, 0.3);
|
||||
color: white;
|
||||
border: 1px solid #4CAF50;
|
||||
border-radius: 8px;
|
||||
padding: 3px 8px;
|
||||
font-size: 11px;
|
||||
}
|
||||
QPushButton:hover {
|
||||
background-color: rgba(76, 175, 80, 0.5);
|
||||
}
|
||||
QPushButton:disabled {
|
||||
background: rgba(128, 128, 128, 0.2);
|
||||
color: #666;
|
||||
border: 1px solid #666;
|
||||
}
|
||||
""")
|
||||
btn_layout.addWidget(run_btn)
|
||||
|
||||
block_layout.addLayout(btn_layout)
|
||||
|
||||
# 儲存元件參考到 frame
|
||||
block_frame.model_row = model_row
|
||||
block_frame.scpu_row = scpu_row
|
||||
block_frame.ncpu_row = ncpu_row
|
||||
block_frame.labels_row = labels_row
|
||||
block_frame.status_label = status_label
|
||||
block_frame.signals = signals
|
||||
|
||||
def stop_inference_and_disconnect():
|
||||
"""停止推論並斷開 dongle 連線"""
|
||||
# 停止推論工作器並斷開 dongle
|
||||
if hasattr(inference_controller, 'inference_worker') and inference_controller.inference_worker is not None:
|
||||
print("[Stop] 停止推論並斷開 dongle...")
|
||||
inference_controller.stop_inference()
|
||||
print("[Stop] 推論已停止,dongle 已斷開")
|
||||
|
||||
status_label.setText("Stopped & Disconnected")
|
||||
status_label.setStyleSheet("""
|
||||
color: #ff6b6b;
|
||||
font-size: 11px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
padding: 2px 0;
|
||||
""")
|
||||
else:
|
||||
# 如果沒有正在運行的推論,只重置檔案選擇
|
||||
model_row.reset()
|
||||
scpu_row.reset()
|
||||
ncpu_row.reset()
|
||||
labels_row.reset()
|
||||
status_label.setText("")
|
||||
status_label.setStyleSheet("""
|
||||
color: #4CAF50;
|
||||
font-size: 11px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
padding: 2px 0;
|
||||
""")
|
||||
|
||||
def run_inference():
|
||||
"""執行自訂模型推論"""
|
||||
model_path = model_row.get_file_path()
|
||||
scpu_path = scpu_row.get_file_path()
|
||||
ncpu_path = ncpu_row.get_file_path()
|
||||
labels_path = labels_row.get_file_path()
|
||||
|
||||
# 檢查必要檔案
|
||||
if not model_path:
|
||||
status_label.setText("Please select Model (.nef)")
|
||||
status_label.setStyleSheet("""
|
||||
color: #ff6b6b;
|
||||
font-size: 11px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
padding: 2px 0;
|
||||
""")
|
||||
return
|
||||
|
||||
if not scpu_path:
|
||||
status_label.setText("Please select SCPU (.bin)")
|
||||
status_label.setStyleSheet("""
|
||||
color: #ff6b6b;
|
||||
font-size: 11px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
padding: 2px 0;
|
||||
""")
|
||||
return
|
||||
|
||||
if not ncpu_path:
|
||||
status_label.setText("Please select NCPU (.bin)")
|
||||
status_label.setStyleSheet("""
|
||||
color: #ff6b6b;
|
||||
font-size: 11px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
padding: 2px 0;
|
||||
""")
|
||||
return
|
||||
|
||||
# 顯示模型名稱
|
||||
model_name = os.path.basename(model_path)
|
||||
status_label.setText(f"Running: {model_name}")
|
||||
status_label.setStyleSheet("""
|
||||
color: #4CAF50;
|
||||
font-size: 11px;
|
||||
border: none;
|
||||
background: transparent;
|
||||
padding: 2px 0;
|
||||
""")
|
||||
|
||||
# 讀取 labels 檔案(如果有提供)
|
||||
custom_labels = None
|
||||
if labels_path and os.path.exists(labels_path):
|
||||
try:
|
||||
with open(labels_path, 'r', encoding='utf-8') as f:
|
||||
custom_labels = []
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# 支援三種格式:
|
||||
# 1. "0, person" (數字 + 逗號 + 空格 + 標籤)
|
||||
# 2. "0 person" (數字 + 空格 + 標籤)
|
||||
# 3. "person" (純標籤)
|
||||
if ',' in line:
|
||||
# 格式: "0, person"
|
||||
parts = line.split(',', 1)
|
||||
if len(parts) == 2 and parts[0].strip().isdigit():
|
||||
custom_labels.append(parts[1].strip())
|
||||
else:
|
||||
custom_labels.append(line)
|
||||
else:
|
||||
parts = line.split(maxsplit=1)
|
||||
if len(parts) == 2 and parts[0].isdigit():
|
||||
# 格式: "0 person"
|
||||
custom_labels.append(parts[1])
|
||||
else:
|
||||
# 格式: "person"
|
||||
custom_labels.append(line)
|
||||
print(f"[Custom Model] 載入 {len(custom_labels)} 個類別標籤")
|
||||
except Exception as e:
|
||||
print(f"[Custom Model] 讀取 labels 檔案失敗: {e}")
|
||||
|
||||
# 創建自訂模型配置
|
||||
custom_config = {
|
||||
"mode": "custom",
|
||||
"model_name": "user_uploaded",
|
||||
"display_name": model_name,
|
||||
"description": "User uploaded custom model",
|
||||
"is_custom": True,
|
||||
"custom_model_path": model_path,
|
||||
"custom_scpu_path": scpu_path,
|
||||
"custom_ncpu_path": ncpu_path,
|
||||
"custom_labels_path": labels_path,
|
||||
"custom_labels": custom_labels, # 類別標籤列表
|
||||
"input_info": {
|
||||
"type": "video" # 預設使用視訊模式
|
||||
},
|
||||
"input_parameters": {}
|
||||
}
|
||||
|
||||
# 發送信號或直接調用控制器
|
||||
if hasattr(inference_controller, 'select_custom_tool'):
|
||||
inference_controller.select_custom_tool(custom_config)
|
||||
else:
|
||||
# 如果沒有專用方法,使用通用的 select_tool
|
||||
inference_controller.select_tool(custom_config)
|
||||
|
||||
# 連接按鈕事件
|
||||
stop_btn.clicked.connect(stop_inference_and_disconnect)
|
||||
run_btn.clicked.connect(run_inference)
|
||||
|
||||
return block_frame
|
||||
Loading…
x
Reference in New Issue
Block a user