forked from masonhuang/KNEO-Academy
- custom_inference_worker: reuse the existing device_group from DeviceController to avoid a double kp.connect_devices() conflict on the same USB port
- custom_inference_worker: add a TYPE_CHECKING guard for kp type annotations to prevent a potential NameError at import time
- utilities_screen: replace the missing back_arrow.png with a text arrow (←)
- utilities_screen: add set_device_controller() so AppController can inject MainWindow's shared DeviceController instance
- main.py: wire UtilitiesScreen to share MainWindow's DeviceController
- video_thread: emit camera_error_signal on failure and on max-retry exhaustion
- media_controller: connect camera_error_signal and display the error on the canvas
- media_panel: fix the pause button using the wrong delete icon; use the video_normal SVG

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
423 lines
18 KiB
Python
423 lines
18 KiB
Python
"""
|
|
inference_controller.py - Inference Controller
|
|
|
|
This module handles AI model inference operations including tool selection,
|
|
model loading, and frame processing for the Kneron AI Playground.
|
|
"""
|
|
|
|
import os
|
|
import queue
|
|
import cv2
|
|
import json
|
|
from PyQt5.QtWidgets import QMessageBox, QApplication
|
|
from PyQt5.QtCore import QTimer, Qt
|
|
|
|
from src.models.inference_worker import InferenceWorkerThread
|
|
from src.models.custom_inference_worker import CustomInferenceWorkerThread
|
|
from src.config import UTILS_DIR, FW_DIR, DongleModelMap
|
|
|
|
|
|
class InferenceController:
    """
    Controller class for managing AI model inference operations.

    Coordinates tool selection, model upload, and frame queueing between the
    UI (main window / media controller) and the inference worker threads.

    Attributes:
        main_window: Reference to the main application window
        device_controller: Reference to the device controller
        inference_worker: Thread worker for running inference
        inference_queue: Queue for frames to be processed
        current_tool_config: Current AI tool configuration
        model_descriptor: Loaded model descriptor
    """

    def __init__(self, main_window, device_controller):
        """
        Initialize the InferenceController.

        Args:
            main_window: Reference to the main application window.
            device_controller: Reference to the device controller.
        """
        self.main_window = main_window
        self.device_controller = device_controller
        self.inference_worker = None
        # Bounded queue so a slow worker drops frames instead of growing memory.
        self.inference_queue = queue.Queue(maxsize=5)
        self.current_tool_config = None
        self.previous_tool_config = None
        # True while the camera feed is detached for image-only tools.
        self._camera_was_active = False
        # Original frame dimensions, used downstream to scale bounding boxes.
        self.original_frame_width = 640   # Default value
        self.original_frame_height = 480  # Default value
        self.model_descriptor = None

    @staticmethod
    def _usb_port_id(device):
        """Return usb_port_id from a device dict or object (0 if absent)."""
        if isinstance(device, dict):
            return device.get("usb_port_id", 0)
        return getattr(device, "usb_port_id", 0)

    @staticmethod
    def _normalize_product_id(product_id):
        """
        Normalize a product id to a lowercase '0x...' hex string for
        DongleModelMap lookup. Unparseable strings are returned unchanged.
        """
        if isinstance(product_id, int):
            return hex(product_id)
        if isinstance(product_id, str) and not product_id.startswith('0x'):
            try:
                # Base 0 accepts plain decimal as well as '0X...' style input.
                return hex(int(product_id, 0))
            except ValueError:
                return product_id
        return product_id

    def _resume_or_start_camera(self):
        """Reconnect a previously paused camera feed, or start the camera."""
        media = self.main_window.media_controller
        if self._camera_was_active and media.video_thread is not None:
            # The capture thread kept running; just re-attach the frame signal.
            media.video_thread.change_pixmap_signal.connect(media.update_image)
            print("Camera reconnected for video processing")
        else:
            media.start_camera()

    def _pause_camera(self):
        """Detach the camera frame signal without stopping the capture thread."""
        media = self.main_window.media_controller
        if media.video_thread is not None:
            # Remember that the camera was running so we can re-attach later.
            self._camera_was_active = True
            try:
                # Disconnect all slots so no frames are processed during image
                # inference. Qt raises TypeError when nothing is connected,
                # which just means the feed is already detached.
                media.video_thread.change_pixmap_signal.disconnect()
            except TypeError:
                pass
            print("Camera paused for image processing")
        else:
            self._camera_was_active = False

    def select_tool(self, tool_config):
        """
        Select an AI tool and configure inference.

        Args:
            tool_config (dict): Configuration dictionary for the AI tool.

        Returns:
            bool: True if tool selection successful, False otherwise.
        """
        try:
            print("Selecting tool:", tool_config.get("display_name"))
            self.current_tool_config = tool_config

            # Mode and model name come from the top-level config entry
            # (read before the detailed per-model config is merged in).
            mode = tool_config.get("mode", "")
            model_name = tool_config.get("model_name", "")

            # Overlay the model's own config.json onto the tool config.
            model_path = os.path.join(UTILS_DIR, mode, model_name)
            model_config_path = os.path.join(model_path, "config.json")
            if os.path.exists(model_config_path):
                try:
                    with open(model_config_path, "r", encoding="utf-8") as f:
                        detailed_config = json.load(f)
                    tool_config = {**tool_config, **detailed_config}
                except Exception as e:
                    print(f"Error reading model config: {e}")

            # Input type decides continuous (video) vs one-shot (image) mode.
            input_info = tool_config.get("input_info", {})
            tool_type = input_info.get("type", "video")
            once_mode = tool_type == "image"

            # Clear inference queue to avoid using old data when switching modes.
            self._clear_inference_queue()

            # Remember this tool for the next switch.
            self.previous_tool_config = tool_config

            # Prepare input parameters.
            input_params = tool_config.get("input_parameters", {}).copy()

            # Reuse the already-connected device group from the device
            # controller so the worker does not reconnect the same USB port.
            device_group = self.device_controller.get_device_group()
            input_params["device_group"] = device_group

            # Configure device-related settings.
            selected_device = self.device_controller.get_selected_device()
            if selected_device:
                input_params["usb_port_id"] = self._usb_port_id(selected_device)
                if isinstance(selected_device, dict):
                    product_id = selected_device.get("product_id", "unknown")
                else:
                    product_id = getattr(selected_device, "product_id", "unknown")

                # Map product_id to dongle type/series.
                product_id = self._normalize_product_id(product_id)
                dongle = DongleModelMap.get(product_id, "unknown")
                print(f"Selected device: product_id={product_id}, mapped to={dongle}")

                # Verify device compatibility before any model upload.
                compatible_devices = tool_config.get("compatible_devices", [])
                if compatible_devices and dongle not in compatible_devices:
                    msgBox = QMessageBox(self.main_window)
                    msgBox.setIcon(QMessageBox.Warning)
                    msgBox.setWindowTitle("Device Incompatible")
                    msgBox.setText(f"The selected model does not support {dongle} device.\nSupported devices: {', '.join(compatible_devices)}")
                    msgBox.setStyleSheet("QLabel { color: white; } QMessageBox { background-color: #2b2b2b; }")
                    msgBox.exec_()
                    return False

                # Add firmware paths as reference only (device already
                # connected during selection).
                input_params["scpu_path"] = os.path.join(FW_DIR, dongle, "fw_scpu.bin")
                input_params["ncpu_path"] = os.path.join(FW_DIR, dongle, "fw_ncpu.bin")
            else:
                # No explicit selection: fall back to the first connected device.
                devices = self.device_controller.connected_devices
                if devices:
                    input_params["usb_port_id"] = self._usb_port_id(devices[0])
                    print("Warning: No device specifically selected, using first available device")
                else:
                    input_params["usb_port_id"] = 0
                    print("Warning: No connected devices, using default usb_port_id 0")

            # Handle file inputs for image/voice modes.
            if tool_type in ["image", "voice"]:
                if hasattr(self.main_window, "destination") and self.main_window.destination:
                    input_params["file_path"] = self.main_window.destination
                    if tool_type == "image":
                        uploaded_img = cv2.imread(self.main_window.destination)
                        if uploaded_img is None:
                            print("Warning: Unable to read uploaded image")
                        elif self.inference_queue.full():
                            print("Warning: inference queue is full")
                        else:
                            self.inference_queue.put(uploaded_img)
                            print("Uploaded image added to inference queue")
                else:
                    input_params["file_path"] = ""
                    print(f"Warning: {tool_type} mode requires a file input, but no file has been uploaded.")

            # Resolve and upload the model file, if one is declared.
            if "model_file" in tool_config:
                model_file_path = os.path.join(model_path, tool_config["model_file"])
                input_params["model"] = model_file_path

                if device_group:
                    try:
                        import kp
                        print('[Uploading model]')
                        self.model_descriptor = kp.core.load_model_from_file(
                            device_group=device_group,
                            file_path=model_file_path
                        )
                        print(' - Upload successful')
                        # Hand the descriptor to the worker so it can skip
                        # re-uploading the model.
                        input_params["model_descriptor"] = self.model_descriptor
                    except Exception as e:
                        print(f"Error uploading model: {e}")
                        self.model_descriptor = None

            print("Input parameters:", input_params)

            # Stop existing inference worker if running.
            if self.inference_worker:
                self.inference_worker.stop()
                self.inference_worker = None

            # Create and start a new inference worker.
            self.inference_worker = InferenceWorkerThread(
                self.inference_queue,
                mode,
                model_name,
                min_interval=2,
                mse_threshold=500,
                once_mode=once_mode
            )
            self.inference_worker.input_params = input_params
            self.inference_worker.inference_result_signal.connect(self.main_window.handle_inference_result)
            self.inference_worker.start()
            print(f"Inference worker started for module: {mode}/{model_name}")

            # Video tools need a live camera feed; image tools pause it so the
            # camera does not have to restart when switching back to video.
            if tool_type == "video":
                self._resume_or_start_camera()
            else:
                self._pause_camera()

            return True
        except Exception as e:
            print(f"Error selecting tool: {e}")
            import traceback
            print(traceback.format_exc())
            return False

    def _clear_inference_queue(self):
        """Drain all pending frames from the inference queue."""
        try:
            while not self.inference_queue.empty():
                try:
                    self.inference_queue.get_nowait()
                except queue.Empty:
                    # Another consumer drained it first; nothing left to do.
                    break
            print("Inference queue cleared")
        except Exception as e:
            print(f"Error clearing inference queue: {e}")

    def add_frame_to_queue(self, frame):
        """
        Add a frame to the inference queue, dropping it if the queue is full.

        Also records the frame's dimensions so downstream consumers can scale
        bounding boxes back to the original resolution.

        Args:
            frame: The image frame to add (numpy array).
        """
        try:
            # Update original frame dimensions for bounding box scaling.
            if frame is not None and hasattr(frame, 'shape'):
                height, width = frame.shape[:2]
                self.original_frame_width = width
                self.original_frame_height = height

            # Drop the frame when the queue is full rather than blocking the UI.
            if not self.inference_queue.full():
                self.inference_queue.put(frame)
        except Exception as e:
            print(f"Error adding frame to queue: {e}")
            import traceback
            print(traceback.format_exc())

    def stop_inference(self):
        """Stop the inference worker thread, if one is running."""
        if self.inference_worker:
            self.inference_worker.stop()
            self.inference_worker = None

    def process_uploaded_image(self, file_path):
        """
        Process an uploaded image and queue it for inference.

        Args:
            file_path (str): Path to the uploaded image file.
        """
        try:
            if not os.path.exists(file_path):
                print(f"Error: File does not exist {file_path}")
                return

            # Only the latest uploaded image should be processed.
            self._clear_inference_queue()

            img = cv2.imread(file_path)
            if img is None:
                print(f"Error: Unable to read image {file_path}")
                return

            if self.inference_worker:
                # Point the worker at the new file, then enqueue the pixels.
                self.inference_worker.input_params["file_path"] = file_path
                if not self.inference_queue.full():
                    self.inference_queue.put(img)
                    print(f"Added image {file_path} to inference queue")
                else:
                    print("Warning: Inference queue is full")
            else:
                print("Error: Inference worker not initialized")
        except Exception as e:
            print(f"Error processing uploaded image: {e}")
            import traceback
            print(traceback.format_exc())

    def select_custom_tool(self, tool_config):
        """
        Select a custom model tool and configure inference.

        Args:
            tool_config (dict): Configuration dictionary containing:
                - custom_model_path: Path to .nef model file
                - custom_scpu_path: Path to SCPU firmware
                - custom_ncpu_path: Path to NCPU firmware
                - custom_labels: Optional list of class labels

        Returns:
            bool: True if custom tool selection successful, False otherwise.
        """
        try:
            print("Selecting custom model:", tool_config.get("display_name"))
            self.current_tool_config = tool_config

            # Clear inference queue.
            self._clear_inference_queue()

            # Remember this tool for the next switch.
            self.previous_tool_config = tool_config

            # Prepare input parameters.
            input_params = tool_config.get("input_parameters", {}).copy()

            # Add custom model and firmware paths.
            input_params["custom_model_path"] = tool_config.get("custom_model_path")
            input_params["custom_scpu_path"] = tool_config.get("custom_scpu_path")
            input_params["custom_ncpu_path"] = tool_config.get("custom_ncpu_path")
            input_params["custom_labels"] = tool_config.get("custom_labels")

            # Pass existing device_group to avoid double connection in the worker.
            input_params["device_group"] = self.device_controller.get_device_group()

            # Get device-related settings.
            selected_device = self.device_controller.get_selected_device()
            if selected_device:
                input_params["usb_port_id"] = self._usb_port_id(selected_device)
            else:
                devices = self.device_controller.connected_devices
                if devices:
                    input_params["usb_port_id"] = self._usb_port_id(devices[0])
                    print("Warning: No device selected, using first available device")
                else:
                    input_params["usb_port_id"] = 0
                    print("Warning: No connected devices, using default usb_port_id 0")

            print("Custom model input parameters:", input_params)

            # Stop existing inference worker.
            if self.inference_worker:
                self.inference_worker.stop()
                self.inference_worker = None

            # Create and start a new custom inference worker.
            self.inference_worker = CustomInferenceWorkerThread(
                self.inference_queue,
                min_interval=0.5,
                mse_threshold=500
            )
            self.inference_worker.input_params = input_params
            self.inference_worker.inference_result_signal.connect(
                self.main_window.handle_inference_result
            )
            self.inference_worker.start()
            print("Custom model inference worker started")

            # Start camera (custom model defaults to video mode).
            tool_type = tool_config.get("input_info", {}).get("type", "video")
            if tool_type == "video":
                self._resume_or_start_camera()

            return True

        except Exception as e:
            print(f"Error selecting custom model: {e}")
            import traceback
            print(traceback.format_exc())
            return False