KNEO-Academy/src/views/components/custom_model_block.py
HuangMason320 09156cce94 Feature: Add custom model upload and inference support
- device_controller: Add connect_device, disconnect_device, get_device_group methods
- inference_controller: Add select_custom_tool method for custom model inference
- custom_inference_worker: New worker thread for custom model inference with YOLO V5 post-processing
- custom_model_block: New UI component for uploading custom .nef model and firmware files

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2025-12-27 02:38:18 +08:00

400 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Custom Model Upload Block Component
允許使用者上傳自己的 .nef 模型和 firmware (.bin) 檔案
"""
import os
from PyQt5.QtWidgets import (
QFrame, QVBoxLayout, QHBoxLayout, QLabel,
QPushButton, QWidget, QFileDialog
)
from PyQt5.QtSvg import QSvgWidget
from PyQt5.QtCore import Qt, pyqtSignal, QObject
from src.config import SECONDARY_COLOR, UXUI_ASSETS, BUTTON_STYLE
class CustomModelSignals(QObject):
    """Qt signal carrier for custom-model selection events.

    Holds a single signal whose payload is a ``dict``.
    """

    # Declared at class level, as PyQt requires for signals.
    model_selected = pyqtSignal(dict)
class FileUploadRow(QWidget):
    """One upload row: a caption, the selected file's name, and a "..." browse button.

    ``file_path`` holds the path picked in the file dialog, or ``None`` while
    nothing is selected.
    """

    # Placeholder text shown while no file is selected.
    _PLACEHOLDER_TEXT = "未選擇"
    # File names longer than _MAX_NAME_LEN are cut to _TRUNCATE_TO characters
    # plus "..." so they fit the fixed-width panel. The full path is put in
    # the tooltip instead.
    _MAX_NAME_LEN = 15
    _TRUNCATE_TO = 12
    # Stylesheets for the file-name label. They are defined once so the
    # initial, selected and reset states cannot drift apart.
    _IDLE_STYLE = (
        "color: #aaa; font-size: 10px; border: 1px solid #555; "
        "border-radius: 3px; padding: 2px 5px; background: rgba(0,0,0,0.2);"
    )
    _SELECTED_STYLE = (
        "color: #4CAF50; font-size: 10px; border: 1px solid #4CAF50; "
        "border-radius: 3px; padding: 2px 5px; background: rgba(76, 175, 80, 0.1);"
    )

    def __init__(self, label_text, file_filter, parent=None):
        """
        Args:
            label_text: Caption shown on the left; also used in the dialog title.
            file_filter: Qt name-filter string for the file dialog,
                e.g. ``"NEF Files (*.nef);;All Files (*)"``.
            parent: Optional parent widget.
        """
        super().__init__(parent)
        self.file_path = None
        self.file_filter = file_filter
        self.label_text = label_text
        self.init_ui()

    def init_ui(self):
        """Create the caption label, the file-name label and the browse button."""
        layout = QHBoxLayout(self)
        layout.setContentsMargins(0, 2, 0, 2)
        layout.setSpacing(5)

        # Caption
        self.label = QLabel(self.label_text)
        self.label.setStyleSheet("color: white; font-size: 11px; border: none; background: transparent;")
        self.label.setFixedWidth(50)
        layout.addWidget(self.label)

        # Selected file name (stretches to fill the row)
        self.file_label = QLabel(self._PLACEHOLDER_TEXT)
        self.file_label.setStyleSheet(self._IDLE_STYLE)
        self.file_label.setFixedHeight(22)
        layout.addWidget(self.file_label, 1)

        # Browse button
        self.upload_btn = QPushButton("...")
        self.upload_btn.setFixedSize(28, 22)
        self.upload_btn.setStyleSheet("""
            QPushButton {
                background: rgba(255,255,255,0.1);
                color: white;
                border: 1px solid white;
                border-radius: 3px;
                font-size: 10px;
            }
            QPushButton:hover {
                background: rgba(255,255,255,0.2);
            }
        """)
        self.upload_btn.clicked.connect(self.select_file)
        layout.addWidget(self.upload_btn)

    def select_file(self):
        """Open a file dialog and, if a file is chosen, store and display it.

        Cancelling the dialog leaves any previous selection untouched. When a
        file is already selected, the dialog starts at it, and Qt preselects
        a file name given as the start path.
        """
        file_path, _ = QFileDialog.getOpenFileName(
            self,
            f"選擇 {self.label_text} 檔案",
            self.file_path or "",
            self.file_filter,
        )
        if not file_path:
            return

        self.file_path = file_path
        # Show only the (possibly truncated) base name; the full path goes in the tooltip.
        file_name = os.path.basename(file_path)
        if len(file_name) > self._MAX_NAME_LEN:
            file_name = file_name[:self._TRUNCATE_TO] + "..."
        self.file_label.setText(file_name)
        self.file_label.setStyleSheet(self._SELECTED_STYLE)
        self.file_label.setToolTip(file_path)

    def get_file_path(self):
        """Return the selected file path, or ``None`` if nothing is selected."""
        return self.file_path

    def reset(self):
        """Clear the selection and restore the placeholder appearance."""
        self.file_path = None
        self.file_label.setText(self._PLACEHOLDER_TEXT)
        self.file_label.setStyleSheet(self._IDLE_STYLE)
        self.file_label.setToolTip("")
# Colours used by the status line of the custom-model panel.
_STATUS_OK_COLOR = "#4CAF50"
_STATUS_ERROR_COLOR = "#ff6b6b"


def _status_style(color):
    """Return the status-label stylesheet with text drawn in ``color``."""
    return (
        f"color: {color}; font-size: 11px; border: none; "
        "background: transparent; padding: 2px 0;"
    )


def _parse_label_line(line):
    """Return the class name found in one stripped, non-empty labels-file line.

    Supported formats:
      1. ``"0, person"`` - index, comma, name
      2. ``"0 person"``  - index, whitespace, name
      3. ``"person"``    - bare name
    The numeric index is discarded, so labels are positional (by line order).
    A line matching neither indexed form is returned unchanged.
    """
    if ',' in line:
        index, name = line.split(',', 1)
        return name.strip() if index.strip().isdigit() else line
    parts = line.split(maxsplit=1)
    if len(parts) == 2 and parts[0].isdigit():
        return parts[1]
    return line


def _load_custom_labels(labels_path):
    """Read a labels file and return its class names, or ``None``.

    ``None`` is returned when no path is given, the file does not exist, or
    it cannot be read or decoded. A partially read file is never returned.
    """
    if not labels_path:
        return None
    if not os.path.exists(labels_path):
        print(f"[Custom Model] 找不到 labels 檔案: {labels_path}")
        return None
    try:
        # 'utf-8-sig' drops a leading BOM (as saved by e.g. Windows Notepad).
        # With plain 'utf-8' the BOM would stick to the first line and defeat
        # the "0, person" index detection.
        with open(labels_path, 'r', encoding='utf-8-sig') as f:
            stripped = (raw.strip() for raw in f)
            labels = [_parse_label_line(s) for s in stripped if s]
    except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
        print(f"[Custom Model] 讀取 labels 檔案失敗: {e}")
        return None
    print(f"[Custom Model] 載入 {len(labels)} 個類別標籤")
    return labels


def create_custom_model_block(parent, inference_controller):
    """
    Build the "Custom Model" panel for uploading a .nef model, SCPU/NCPU
    firmware (.bin) and an optional labels (.txt) file.

    Args:
        parent: Parent widget of the returned frame.
        inference_controller: Controller that runs inference. "Run Inference"
            passes a config dict to ``select_custom_tool`` if the controller
            has it, otherwise to ``select_tool``. "Stop" reads its
            ``inference_worker`` attribute and calls ``stop_inference()``.

    Returns:
        QFrame: The panel. The rows, status label and signal object are also
        exposed as attributes ``model_row``, ``scpu_row``, ``ncpu_row``,
        ``labels_row``, ``status_label`` and ``signals``.
    """
    import traceback  # only used to report errors raised inside the button slots

    signals = CustomModelSignals()

    # Main frame
    block_frame = QFrame(parent)
    block_frame.setStyleSheet(f"""
        QFrame {{
            border: none;
            background: {SECONDARY_COLOR};
            border-radius: 15px;
        }}
    """)
    block_frame.setFixedHeight(270)  # tall enough to fit the Labels row
    block_frame.setFixedWidth(240)

    block_layout = QVBoxLayout(block_frame)
    block_layout.setContentsMargins(15, 10, 15, 10)
    block_layout.setSpacing(5)

    # Title row: toolbox icon (if the asset exists) plus caption.
    title_layout = QHBoxLayout()
    title_layout.setSpacing(8)
    icon_path = os.path.join(UXUI_ASSETS, "Assets_svg/ic_window_toolbox.svg")
    if os.path.exists(icon_path):
        toolbox_icon = QSvgWidget(icon_path)
        toolbox_icon.setFixedSize(28, 28)
        title_layout.addWidget(toolbox_icon)
    title_label = QLabel("Custom Model")
    title_label.setStyleSheet("color: white; font-size: 16px; font-weight: bold; border: none; background: transparent;")
    title_layout.addWidget(title_label)
    title_layout.addStretch()
    block_layout.addLayout(title_layout)

    # Separator line
    separator = QFrame()
    separator.setFrameShape(QFrame.HLine)
    separator.setStyleSheet("background-color: rgba(255,255,255,0.2); border: none;")
    separator.setFixedHeight(1)
    block_layout.addWidget(separator)

    # Upload rows: model (.nef), SCPU/NCPU firmware (.bin), optional labels (.txt).
    model_row = FileUploadRow("Model", "NEF Files (*.nef);;All Files (*)", block_frame)
    scpu_row = FileUploadRow("SCPU", "Binary Files (*.bin);;All Files (*)", block_frame)
    ncpu_row = FileUploadRow("NCPU", "Binary Files (*.bin);;All Files (*)", block_frame)
    labels_row = FileUploadRow("Labels", "Text Files (*.txt);;All Files (*)", block_frame)
    all_rows = (model_row, scpu_row, ncpu_row, labels_row)
    for row in all_rows:
        block_layout.addWidget(row)

    # Status line (validation errors, running model name, stop state).
    status_label = QLabel("")
    status_label.setStyleSheet(_status_style(_STATUS_OK_COLOR))
    status_label.setAlignment(Qt.AlignCenter)
    status_label.setWordWrap(True)
    block_layout.addWidget(status_label)

    def set_status(text, color):
        """Show ``text`` in the status line using ``color``."""
        status_label.setText(text)
        status_label.setStyleSheet(_status_style(color))

    # Button row
    btn_layout = QHBoxLayout()
    btn_layout.setSpacing(8)

    stop_btn = QPushButton("Stop")
    stop_btn.setFixedHeight(32)
    stop_btn.setStyleSheet("""
        QPushButton {
            background: transparent;
            color: #ff6b6b;
            border: 1px solid #ff6b6b;
            border-radius: 8px;
            padding: 3px 8px;
            font-size: 11px;
        }
        QPushButton:hover {
            background-color: rgba(255, 107, 107, 0.2);
        }
    """)
    btn_layout.addWidget(stop_btn)

    run_btn = QPushButton("Run Inference")
    run_btn.setFixedHeight(32)
    run_btn.setStyleSheet("""
        QPushButton {
            background: rgba(76, 175, 80, 0.3);
            color: white;
            border: 1px solid #4CAF50;
            border-radius: 8px;
            padding: 3px 8px;
            font-size: 11px;
        }
        QPushButton:hover {
            background-color: rgba(76, 175, 80, 0.5);
        }
        QPushButton:disabled {
            background: rgba(128, 128, 128, 0.2);
            color: #666;
            border: 1px solid #666;
        }
    """)
    btn_layout.addWidget(run_btn)
    block_layout.addLayout(btn_layout)

    # Expose the child widgets on the frame for external access.
    block_frame.model_row = model_row
    block_frame.scpu_row = scpu_row
    block_frame.ncpu_row = ncpu_row
    block_frame.labels_row = labels_row
    block_frame.status_label = status_label
    block_frame.signals = signals

    def stop_inference_and_disconnect():
        """Stop a running inference; if none is running, clear the form instead.

        The stop is done by ``inference_controller.stop_inference()``. The log
        messages state that this also disconnects the dongle.
        """
        if getattr(inference_controller, 'inference_worker', None) is not None:
            print("[Stop] 停止推論並斷開 dongle...")
            try:
                inference_controller.stop_inference()
            except Exception as exc:
                # Qt slot boundary: since PyQt5 5.5 an exception escaping a slot
                # aborts the application, so report it instead.
                traceback.print_exc()
                set_status(f"Stop failed: {exc}", _STATUS_ERROR_COLOR)
                return
            print("[Stop] 推論已停止dongle 已斷開")
            set_status("Stopped & Disconnected", _STATUS_ERROR_COLOR)
            return

        # Nothing running: only reset the file selections and the status line.
        for upload_row in all_rows:
            upload_row.reset()
        set_status("", _STATUS_OK_COLOR)

    def run_inference():
        """Validate the selections, load optional labels and start custom-model inference."""
        model_path = model_row.get_file_path()
        scpu_path = scpu_row.get_file_path()
        ncpu_path = ncpu_row.get_file_path()
        labels_path = labels_row.get_file_path()

        # The model and both firmware images are mandatory; labels are optional.
        for path, prompt in (
            (model_path, "Please select Model (.nef)"),
            (scpu_path, "Please select SCPU (.bin)"),
            (ncpu_path, "Please select NCPU (.bin)"),
        ):
            if not path:
                set_status(prompt, _STATUS_ERROR_COLOR)
                return

        model_name = os.path.basename(model_path)
        set_status(f"Running: {model_name}", _STATUS_OK_COLOR)

        custom_labels = _load_custom_labels(labels_path)

        custom_config = {
            "mode": "custom",
            "model_name": "user_uploaded",
            "display_name": model_name,
            "description": "User uploaded custom model",
            "is_custom": True,
            "custom_model_path": model_path,
            "custom_scpu_path": scpu_path,
            "custom_ncpu_path": ncpu_path,
            "custom_labels_path": labels_path,
            "custom_labels": custom_labels,  # list of class names, or None
            "input_info": {
                "type": "video"  # video input by default
            },
            "input_parameters": {}
        }

        # NOTE(review): `signals.model_selected` is exposed on the frame but
        # never emitted; the controller is called directly here.
        try:
            if hasattr(inference_controller, 'select_custom_tool'):
                inference_controller.select_custom_tool(custom_config)
            else:
                # Fall back to the generic entry point.
                inference_controller.select_tool(custom_config)
        except Exception as exc:
            # Qt slot boundary: report instead of letting PyQt abort the app,
            # and do not leave "Running: ..." on screen after a failure.
            traceback.print_exc()
            set_status(f"Failed: {exc}", _STATUS_ERROR_COLOR)

    stop_btn.clicked.connect(stop_inference_and_disconnect)
    run_btn.clicked.connect(run_inference)

    return block_frame