""" inference_controller.py - Inference Controller This module handles AI model inference operations including tool selection, model loading, and frame processing for the Kneron AI Playground. """ import os import queue import cv2 import json from PyQt5.QtWidgets import QMessageBox, QApplication from PyQt5.QtCore import QTimer, Qt from src.models.inference_worker import InferenceWorkerThread from src.models.custom_inference_worker import CustomInferenceWorkerThread from src.config import UTILS_DIR, FW_DIR, DongleModelMap class InferenceController: """ Controller class for managing AI model inference operations. Attributes: main_window: Reference to the main application window device_controller: Reference to the device controller inference_worker: Thread worker for running inference inference_queue: Queue for frames to be processed current_tool_config: Current AI tool configuration model_descriptor: Loaded model descriptor """ def __init__(self, main_window, device_controller): """ Initialize the InferenceController. Args: main_window: Reference to the main application window. device_controller: Reference to the device controller. """ self.main_window = main_window self.device_controller = device_controller self.inference_worker = None self.inference_queue = queue.Queue(maxsize=5) self.current_tool_config = None self.previous_tool_config = None self._camera_was_active = False # Store original frame dimensions for bounding box scaling self.original_frame_width = 640 # Default value self.original_frame_height = 480 # Default value self.model_descriptor = None def select_tool(self, tool_config): """ Select an AI tool and configure inference. Args: tool_config (dict): Configuration dictionary for the AI tool. Returns: bool: True if tool selection successful, False otherwise. """ try: print("Selecting tool:", tool_config.get("display_name")) self.current_tool_config = tool_config # Get mode and model name mode = tool_config.get("mode", "") model_name = tool_config.get("model_name", "") # Load detailed model configuration model_path = os.path.join(UTILS_DIR, mode, model_name) model_config_path = os.path.join(model_path, "config.json") if os.path.exists(model_config_path): try: with open(model_config_path, "r", encoding="utf-8") as f: detailed_config = json.load(f) tool_config = {**tool_config, **detailed_config} except Exception as e: print(f"Error reading model config: {e}") # Get tool input type input_info = tool_config.get("input_info", {}) tool_type = input_info.get("type", "video") once_mode = True if tool_type == "image" else False # Check if switching from video mode to image mode or vice versa previous_tool_type = "video" if hasattr(self, 'previous_tool_config') and self.previous_tool_config: previous_input_info = self.previous_tool_config.get("input_info", {}) previous_tool_type = previous_input_info.get("type", "video") # Clear inference queue to avoid using old data when switching modes self._clear_inference_queue() # Store current tool type for next comparison self.previous_tool_config = tool_config # Prepare input parameters input_params = tool_config.get("input_parameters", {}).copy() # Get connected device group device_group = self.device_controller.get_device_group() # Add device group to input parameters input_params["device_group"] = device_group # Configure device-related settings selected_device = self.device_controller.get_selected_device() if selected_device: # Get usb_port_id (check if it's a dictionary or object) if isinstance(selected_device, dict): input_params["usb_port_id"] = selected_device.get("usb_port_id", 0) product_id = selected_device.get("product_id", "unknown") else: input_params["usb_port_id"] = getattr(selected_device, "usb_port_id", 0) product_id = getattr(selected_device, "product_id", "unknown") # Ensure product_id is in the right format for lookup # Convert to lowercase hex string if it's a number if isinstance(product_id, int): product_id = hex(product_id).lower() # If it's a string but doesn't start with '0x', add it elif isinstance(product_id, str) and not product_id.startswith('0x'): try: # Try to convert to int first, then to hex format product_id = hex(int(product_id, 0)).lower() except ValueError: # If conversion fails, keep as is pass # Map product_id to dongle type/series dongle = DongleModelMap.get(product_id, "unknown") print(f"Selected device: product_id={product_id}, mapped to={dongle}") # Verify device compatibility compatible_devices = tool_config.get("compatible_devices", []) if compatible_devices and dongle not in compatible_devices: msgBox = QMessageBox(self.main_window) msgBox.setIcon(QMessageBox.Warning) msgBox.setWindowTitle("Device Incompatible") msgBox.setText(f"The selected model does not support {dongle} device.\nSupported devices: {', '.join(compatible_devices)}") msgBox.setStyleSheet("QLabel { color: white; } QMessageBox { background-color: #2b2b2b; }") msgBox.exec_() return False # Add firmware paths as reference only (device already connected during selection) scpu_path = os.path.join(FW_DIR, dongle, "fw_scpu.bin") ncpu_path = os.path.join(FW_DIR, dongle, "fw_ncpu.bin") input_params["scpu_path"] = scpu_path input_params["ncpu_path"] = ncpu_path else: # Default device handling devices = self.device_controller.connected_devices if devices and len(devices) > 0: input_params["usb_port_id"] = devices[0].get("usb_port_id", 0) print("Warning: No device specifically selected, using first available device") else: input_params["usb_port_id"] = 0 print("Warning: No connected devices, using default usb_port_id 0") # Handle file inputs for image/voice modes if tool_type in ["image", "voice"]: if hasattr(self.main_window, "destination") and self.main_window.destination: input_params["file_path"] = self.main_window.destination if tool_type == "image": uploaded_img = cv2.imread(self.main_window.destination) if uploaded_img is not None: if not self.inference_queue.full(): self.inference_queue.put(uploaded_img) print("Uploaded image added to inference queue") else: print("Warning: inference queue is full") else: print("Warning: Unable to read uploaded image") else: input_params["file_path"] = "" print(f"Warning: {tool_type} mode requires a file input, but no file has been uploaded.") # Add model file path if "model_file" in tool_config: model_file = tool_config["model_file"] model_file_path = os.path.join(model_path, model_file) input_params["model"] = model_file_path # Upload model to device if device_group: try: import kp print('[Uploading model]') self.model_descriptor = kp.core.load_model_from_file( device_group=device_group, file_path=model_file_path ) print(' - Upload successful') # Add model descriptor to input parameters input_params["model_descriptor"] = self.model_descriptor except Exception as e: print(f"Error uploading model: {e}") self.model_descriptor = None print("Input parameters:", input_params) # Stop existing inference worker if running if self.inference_worker: self.inference_worker.stop() self.inference_worker = None # Create new inference worker self.inference_worker = InferenceWorkerThread( self.inference_queue, mode, model_name, min_interval=2, mse_threshold=500, once_mode=once_mode ) self.inference_worker.input_params = input_params self.inference_worker.inference_result_signal.connect(self.main_window.handle_inference_result) self.inference_worker.start() print(f"Inference worker started for module: {mode}/{model_name}") # Start camera if needed if tool_type == "video": # If camera was previously active but disconnected for image processing if self._camera_was_active and self.main_window.media_controller.video_thread is not None: # Reconnect the signal self.main_window.media_controller.video_thread.change_pixmap_signal.connect( self.main_window.media_controller.update_image ) print("Camera reconnected for video processing") else: # Start camera normally self.main_window.media_controller.start_camera() else: # For image tools, temporarily pause the camera but don't stop it completely # This allows switching back to video tools without restarting the camera if self.main_window.media_controller.video_thread is not None: # Save current state to indicate camera was running self._camera_was_active = True # Disconnect signal to prevent processing frames during image inference self.main_window.media_controller.video_thread.change_pixmap_signal.disconnect() print("Camera paused for image processing") else: self._camera_was_active = False return True except Exception as e: print(f"Error selecting tool: {e}") import traceback print(traceback.format_exc()) return False def _clear_inference_queue(self): """Clear all data from the inference queue.""" try: # Clear existing queue while not self.inference_queue.empty(): try: self.inference_queue.get_nowait() except queue.Empty: break print("Inference queue cleared") except Exception as e: print(f"Error clearing inference queue: {e}") def add_frame_to_queue(self, frame): """ Add a frame to the inference queue. Args: frame: The image frame to add (numpy array). """ try: # Update original frame dimensions if frame is not None and hasattr(frame, 'shape'): height, width = frame.shape[:2] self.original_frame_width = width self.original_frame_height = height # Add to queue if not self.inference_queue.full(): self.inference_queue.put(frame) except Exception as e: print(f"Error adding frame to queue: {e}") import traceback print(traceback.format_exc()) def stop_inference(self): """Stop the inference worker thread.""" if self.inference_worker: self.inference_worker.stop() self.inference_worker = None def process_uploaded_image(self, file_path): """ Process an uploaded image and run inference. Args: file_path (str): Path to the uploaded image file. """ try: if not os.path.exists(file_path): print(f"Error: File does not exist {file_path}") return # Clear inference queue to ensure only the latest image is processed self._clear_inference_queue() # Read image img = cv2.imread(file_path) if img is None: print(f"Error: Unable to read image {file_path}") return # Update inference worker parameters if self.inference_worker: self.inference_worker.input_params["file_path"] = file_path # Add image to inference queue if not self.inference_queue.full(): self.inference_queue.put(img) print(f"Added image {file_path} to inference queue") else: print("Warning: Inference queue is full") else: print("Error: Inference worker not initialized") except Exception as e: print(f"Error processing uploaded image: {e}") import traceback print(traceback.format_exc()) def select_custom_tool(self, tool_config): """ Select a custom model tool and configure inference. Args: tool_config (dict): Configuration dictionary containing: - custom_model_path: Path to .nef model file - custom_scpu_path: Path to SCPU firmware - custom_ncpu_path: Path to NCPU firmware - custom_labels: Optional list of class labels Returns: bool: True if custom tool selection successful, False otherwise. """ try: print("Selecting custom model:", tool_config.get("display_name")) self.current_tool_config = tool_config # Clear inference queue self._clear_inference_queue() # Store current tool type for next comparison self.previous_tool_config = tool_config # Prepare input parameters input_params = tool_config.get("input_parameters", {}).copy() # Add custom model paths input_params["custom_model_path"] = tool_config.get("custom_model_path") input_params["custom_scpu_path"] = tool_config.get("custom_scpu_path") input_params["custom_ncpu_path"] = tool_config.get("custom_ncpu_path") input_params["custom_labels"] = tool_config.get("custom_labels") # Pass existing device_group to avoid double connection in the worker input_params["device_group"] = self.device_controller.get_device_group() # Get device-related settings selected_device = self.device_controller.get_selected_device() if selected_device: if isinstance(selected_device, dict): input_params["usb_port_id"] = selected_device.get("usb_port_id", 0) else: input_params["usb_port_id"] = getattr(selected_device, "usb_port_id", 0) else: devices = self.device_controller.connected_devices if devices and len(devices) > 0: input_params["usb_port_id"] = devices[0].get("usb_port_id", 0) print("Warning: No device selected, using first available device") else: input_params["usb_port_id"] = 0 print("Warning: No connected devices, using default usb_port_id 0") print("Custom model input parameters:", input_params) # Stop existing inference worker if self.inference_worker: self.inference_worker.stop() self.inference_worker = None # Create new custom inference worker self.inference_worker = CustomInferenceWorkerThread( self.inference_queue, min_interval=0.5, mse_threshold=500 ) self.inference_worker.input_params = input_params self.inference_worker.inference_result_signal.connect( self.main_window.handle_inference_result ) self.inference_worker.start() print("Custom model inference worker started") # Start camera (custom model defaults to video mode) tool_type = tool_config.get("input_info", {}).get("type", "video") if tool_type == "video": if self._camera_was_active and self.main_window.media_controller.video_thread is not None: self.main_window.media_controller.video_thread.change_pixmap_signal.connect( self.main_window.media_controller.update_image ) print("Camera reconnected for video processing") else: self.main_window.media_controller.start_camera() return True except Exception as e: print(f"Error selecting custom model: {e}") import traceback print(traceback.format_exc()) return False