KNEO-Academy/src/controllers/inference_controller.py
abin 7e323cf3e1 Fix: resolve 5 bugs found during project onboarding health check
- custom_inference_worker: reuse existing device_group from DeviceController
  to avoid double kp.connect_devices() conflict on same USB port
- custom_inference_worker: add TYPE_CHECKING guard for kp type annotations
  to prevent potential NameError at import time
- utilities_screen: replace missing back_arrow.png with text arrow (←)
- utilities_screen: add set_device_controller() so AppController can inject
  MainWindow's shared DeviceController instance
- main.py: wire UtilitiesScreen to share MainWindow's DeviceController
- video_thread: emit camera_error_signal on failure and max-retry exhaustion
- media_controller: connect camera_error_signal and display error on canvas
- media_panel: fix pause button using wrong delete icon; use video_normal SVG

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:33:37 +08:00

423 lines
18 KiB
Python

"""
inference_controller.py - Inference Controller
This module handles AI model inference operations including tool selection,
model loading, and frame processing for the Kneron AI Playground.
"""
import os
import queue
import cv2
import json
from PyQt5.QtWidgets import QMessageBox, QApplication
from PyQt5.QtCore import QTimer, Qt
from src.models.inference_worker import InferenceWorkerThread
from src.models.custom_inference_worker import CustomInferenceWorkerThread
from src.config import UTILS_DIR, FW_DIR, DongleModelMap
class InferenceController:
    """
    Controller class for managing AI model inference operations.

    Handles tool selection, model loading/upload to the Kneron dongle, and
    frame routing between the camera feed and the inference worker thread.

    Attributes:
        main_window: Reference to the main application window
        device_controller: Reference to the device controller
        inference_worker: Thread worker for running inference
        inference_queue: Bounded queue of frames awaiting inference
        current_tool_config: Current AI tool configuration
        previous_tool_config: Tool configuration selected before the current one
        model_descriptor: Loaded model descriptor (kp.ModelNefDescriptor) or None
    """

    def __init__(self, main_window, device_controller):
        """
        Initialize the InferenceController.

        Args:
            main_window: Reference to the main application window.
            device_controller: Reference to the device controller.
        """
        self.main_window = main_window
        self.device_controller = device_controller
        self.inference_worker = None
        # Bounded so a slow worker never causes unbounded frame buildup;
        # excess frames are simply dropped.
        self.inference_queue = queue.Queue(maxsize=5)
        self.current_tool_config = None
        self.previous_tool_config = None
        # True while the camera thread is alive but its display slot is
        # detached (camera "paused" for an image/voice tool).
        self._camera_was_active = False
        # Original frame dimensions, used for bounding-box scaling.
        self.original_frame_width = 640   # Default value
        self.original_frame_height = 480  # Default value
        self.model_descriptor = None

    @staticmethod
    def _normalize_product_id(product_id):
        """Normalize a product id to a '0x…' hex string for DongleModelMap lookup.

        Accepts an int, a decimal/hex string, or an already-normalized value;
        returns the input unchanged if it cannot be converted.
        """
        if isinstance(product_id, int):
            return hex(product_id)  # hex() output is already lowercase
        if isinstance(product_id, str) and not product_id.startswith('0x'):
            try:
                # base 0 accepts "256", "0o400", "0x100", etc.
                return hex(int(product_id, 0))
            except ValueError:
                pass  # keep as-is; lookup will fall back to "unknown"
        return product_id

    def _fallback_usb_port_id(self):
        """Return the usb_port_id of the first connected device, or 0 if none."""
        devices = self.device_controller.connected_devices
        if devices:
            print("Warning: No device selected, using first available device")
            return devices[0].get("usb_port_id", 0)
        print("Warning: No connected devices, using default usb_port_id 0")
        return 0

    def _manage_camera_for_tool(self, tool_type):
        """Start, resume, or pause the camera feed based on the tool's input type."""
        media = self.main_window.media_controller
        if tool_type == "video":
            if self._camera_was_active and media.video_thread is not None:
                # Resume a paused camera by re-attaching the display slot.
                media.video_thread.change_pixmap_signal.connect(media.update_image)
                # Reset the flag so repeated video-tool selections do not
                # stack duplicate signal connections.
                self._camera_was_active = False
                print("Camera reconnected for video processing")
            else:
                media.start_camera()
        else:
            # Image/voice tools: pause (not stop) the camera so switching back
            # to a video tool does not require a full camera restart.
            if media.video_thread is not None:
                self._camera_was_active = True
                try:
                    # Detach only the display slot; disconnect raises
                    # TypeError if that slot is not currently connected.
                    media.video_thread.change_pixmap_signal.disconnect(media.update_image)
                except TypeError:
                    pass
                print("Camera paused for image processing")
            else:
                self._camera_was_active = False

    def select_tool(self, tool_config):
        """
        Select an AI tool and configure inference.

        Args:
            tool_config (dict): Configuration dictionary for the AI tool.

        Returns:
            bool: True if tool selection successful, False otherwise.
        """
        try:
            print("Selecting tool:", tool_config.get("display_name"))
            self.current_tool_config = tool_config

            mode = tool_config.get("mode", "")
            model_name = tool_config.get("model_name", "")

            # Merge the model's detailed config.json (if present) over the
            # caller-supplied tool config; detailed values win.
            model_path = os.path.join(UTILS_DIR, mode, model_name)
            model_config_path = os.path.join(model_path, "config.json")
            if os.path.exists(model_config_path):
                try:
                    with open(model_config_path, "r", encoding="utf-8") as f:
                        detailed_config = json.load(f)
                    tool_config = {**tool_config, **detailed_config}
                except Exception as e:
                    print(f"Error reading model config: {e}")

            # "image" tools run inference once per upload; "video" tools run
            # continuously on camera frames.
            input_info = tool_config.get("input_info", {})
            tool_type = input_info.get("type", "video")
            once_mode = tool_type == "image"

            # Drop any stale frames queued by the previous tool.
            self._clear_inference_queue()
            self.previous_tool_config = tool_config

            input_params = tool_config.get("input_parameters", {}).copy()

            # Device group was established when the device was selected.
            device_group = self.device_controller.get_device_group()
            input_params["device_group"] = device_group

            selected_device = self.device_controller.get_selected_device()
            if selected_device:
                # selected_device may be a dict or an object.
                if isinstance(selected_device, dict):
                    input_params["usb_port_id"] = selected_device.get("usb_port_id", 0)
                    product_id = selected_device.get("product_id", "unknown")
                else:
                    input_params["usb_port_id"] = getattr(selected_device, "usb_port_id", 0)
                    product_id = getattr(selected_device, "product_id", "unknown")

                product_id = self._normalize_product_id(product_id)
                dongle = DongleModelMap.get(product_id, "unknown")
                print(f"Selected device: product_id={product_id}, mapped to={dongle}")

                # Refuse tools whose model is not built for this dongle series.
                compatible_devices = tool_config.get("compatible_devices", [])
                if compatible_devices and dongle not in compatible_devices:
                    msgBox = QMessageBox(self.main_window)
                    msgBox.setIcon(QMessageBox.Warning)
                    msgBox.setWindowTitle("Device Incompatible")
                    msgBox.setText(f"The selected model does not support {dongle} device.\nSupported devices: {', '.join(compatible_devices)}")
                    msgBox.setStyleSheet("QLabel { color: white; } QMessageBox { background-color: #2b2b2b; }")
                    msgBox.exec_()
                    return False

                # Firmware paths are passed for reference only; the device is
                # already connected at this point.
                input_params["scpu_path"] = os.path.join(FW_DIR, dongle, "fw_scpu.bin")
                input_params["ncpu_path"] = os.path.join(FW_DIR, dongle, "fw_ncpu.bin")
            else:
                input_params["usb_port_id"] = self._fallback_usb_port_id()

            # Image/voice tools consume an uploaded file rather than camera frames.
            if tool_type in ["image", "voice"]:
                if hasattr(self.main_window, "destination") and self.main_window.destination:
                    input_params["file_path"] = self.main_window.destination
                    if tool_type == "image":
                        uploaded_img = cv2.imread(self.main_window.destination)
                        if uploaded_img is None:
                            print("Warning: Unable to read uploaded image")
                        elif self.inference_queue.full():
                            print("Warning: inference queue is full")
                        else:
                            self.inference_queue.put(uploaded_img)
                            print("Uploaded image added to inference queue")
                else:
                    input_params["file_path"] = ""
                    print(f"Warning: {tool_type} mode requires a file input, but no file has been uploaded.")

            # Upload the model file to the device, if one is configured.
            if "model_file" in tool_config:
                model_file_path = os.path.join(model_path, tool_config["model_file"])
                input_params["model"] = model_file_path
                if device_group:
                    try:
                        import kp
                        print('[Uploading model]')
                        self.model_descriptor = kp.core.load_model_from_file(
                            device_group=device_group,
                            file_path=model_file_path
                        )
                        print(' - Upload successful')
                        input_params["model_descriptor"] = self.model_descriptor
                    except Exception as e:
                        print(f"Error uploading model: {e}")
                        self.model_descriptor = None

            print("Input parameters:", input_params)

            # Replace any running worker with one configured for this tool.
            self.stop_inference()
            self.inference_worker = InferenceWorkerThread(
                self.inference_queue,
                mode,
                model_name,
                min_interval=2,
                mse_threshold=500,
                once_mode=once_mode
            )
            self.inference_worker.input_params = input_params
            self.inference_worker.inference_result_signal.connect(self.main_window.handle_inference_result)
            self.inference_worker.start()
            print(f"Inference worker started for module: {mode}/{model_name}")

            self._manage_camera_for_tool(tool_type)
            return True
        except Exception as e:
            print(f"Error selecting tool: {e}")
            import traceback
            print(traceback.format_exc())
            return False

    def _clear_inference_queue(self):
        """Drain all pending frames from the inference queue."""
        try:
            while not self.inference_queue.empty():
                try:
                    self.inference_queue.get_nowait()
                except queue.Empty:
                    # Another consumer emptied the queue between checks.
                    break
            print("Inference queue cleared")
        except Exception as e:
            print(f"Error clearing inference queue: {e}")

    def add_frame_to_queue(self, frame):
        """
        Add a frame to the inference queue, dropping it if the queue is full.

        Also records the frame's dimensions so bounding boxes from inference
        results can be scaled back to the original resolution.

        Args:
            frame: The image frame to add (numpy array).
        """
        try:
            if frame is not None and hasattr(frame, 'shape'):
                height, width = frame.shape[:2]
                self.original_frame_width = width
                self.original_frame_height = height
            if not self.inference_queue.full():
                self.inference_queue.put(frame)
        except Exception as e:
            print(f"Error adding frame to queue: {e}")
            import traceback
            print(traceback.format_exc())

    def stop_inference(self):
        """Stop the inference worker thread, if one is running."""
        if self.inference_worker:
            self.inference_worker.stop()
            self.inference_worker = None

    def process_uploaded_image(self, file_path):
        """
        Process an uploaded image and run inference.

        Args:
            file_path (str): Path to the uploaded image file.
        """
        try:
            if not os.path.exists(file_path):
                print(f"Error: File does not exist {file_path}")
                return

            # Clear the queue so only the latest upload is processed.
            self._clear_inference_queue()

            img = cv2.imread(file_path)
            if img is None:
                print(f"Error: Unable to read image {file_path}")
                return

            if self.inference_worker:
                self.inference_worker.input_params["file_path"] = file_path
                if not self.inference_queue.full():
                    self.inference_queue.put(img)
                    print(f"Added image {file_path} to inference queue")
                else:
                    print("Warning: Inference queue is full")
            else:
                print("Error: Inference worker not initialized")
        except Exception as e:
            print(f"Error processing uploaded image: {e}")
            import traceback
            print(traceback.format_exc())

    def select_custom_tool(self, tool_config):
        """
        Select a custom model tool and configure inference.

        Args:
            tool_config (dict): Configuration dictionary containing:
                - custom_model_path: Path to .nef model file
                - custom_scpu_path: Path to SCPU firmware
                - custom_ncpu_path: Path to NCPU firmware
                - custom_labels: Optional list of class labels

        Returns:
            bool: True if custom tool selection successful, False otherwise.
        """
        try:
            print("Selecting custom model:", tool_config.get("display_name"))
            self.current_tool_config = tool_config
            self._clear_inference_queue()
            self.previous_tool_config = tool_config

            input_params = tool_config.get("input_parameters", {}).copy()
            input_params["custom_model_path"] = tool_config.get("custom_model_path")
            input_params["custom_scpu_path"] = tool_config.get("custom_scpu_path")
            input_params["custom_ncpu_path"] = tool_config.get("custom_ncpu_path")
            input_params["custom_labels"] = tool_config.get("custom_labels")
            # Reuse the existing device group so the worker does not attempt a
            # second kp.connect_devices() on the same USB port.
            input_params["device_group"] = self.device_controller.get_device_group()

            selected_device = self.device_controller.get_selected_device()
            if selected_device:
                if isinstance(selected_device, dict):
                    input_params["usb_port_id"] = selected_device.get("usb_port_id", 0)
                else:
                    input_params["usb_port_id"] = getattr(selected_device, "usb_port_id", 0)
            else:
                input_params["usb_port_id"] = self._fallback_usb_port_id()

            print("Custom model input parameters:", input_params)

            # Replace any running worker with the custom-model worker.
            self.stop_inference()
            self.inference_worker = CustomInferenceWorkerThread(
                self.inference_queue,
                min_interval=0.5,
                mse_threshold=500
            )
            self.inference_worker.input_params = input_params
            self.inference_worker.inference_result_signal.connect(
                self.main_window.handle_inference_result
            )
            self.inference_worker.start()
            print("Custom model inference worker started")

            # Custom models default to video input; non-video types leave the
            # camera state untouched.
            tool_type = tool_config.get("input_info", {}).get("type", "video")
            if tool_type == "video":
                self._manage_camera_for_tool("video")
            return True
        except Exception as e:
            print(f"Error selecting custom model: {e}")
            import traceback
            print(traceback.format_exc())
            return False