# Forked from masonhuang/KNEO-Academy
# Changelog:
# - Remove unused files (model_controller.py, model_service.py, device_connection_popup.py)
# - Clean up commented code in device_service.py, device_popup.py, config.py
# - Update docstrings and comments across all modules
# - Improve code organization and readability
"""
|
|
inference_worker.py - Inference Worker Thread
|
|
|
|
This module provides a QThread-based worker for running AI model inference
|
|
on video frames. It supports dynamic module loading and frame caching.
|
|
"""
|
|
|
|
import os
|
|
import time
|
|
import queue
|
|
import numpy as np
|
|
import importlib.util
|
|
from PyQt5.QtCore import QThread, pyqtSignal
|
|
from src.config import UTILS_DIR
|
|
|
|
|
|
def load_inference_module(mode, model_name, base_dir=None):
    """
    Dynamically load an inference module from the utils directory.

    Args:
        mode (str): The inference mode/category (e.g., 'object_detection').
        model_name (str): The name of the model.
        base_dir (str, optional): Root directory containing the model
            scripts. Defaults to UTILS_DIR when not given, preserving
            the original behavior.

    Returns:
        module: The loaded Python module containing the inference function.

    Raises:
        FileNotFoundError: If the expected script.py does not exist.
        ImportError: If an import spec cannot be created for the script.
    """
    if base_dir is None:
        base_dir = UTILS_DIR
    script_path = os.path.join(base_dir, mode, model_name, "script.py")

    # Fail early with a clear message instead of the opaque AttributeError
    # that spec_from_file_location(None) would otherwise cause downstream.
    if not os.path.isfile(script_path):
        raise FileNotFoundError(f"Inference script not found: {script_path}")

    module_name = f"{mode}_{model_name}"
    spec = importlib.util.spec_from_file_location(module_name, script_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create import spec for: {script_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
|
|
class InferenceWorkerThread(QThread):
    """
    Worker thread for running AI model inference on frames.

    This thread processes frames from a queue, runs inference using a
    dynamically loaded module, and emits results via Qt signals.

    Attributes:
        inference_result_signal: Signal emitted with inference results.
        frame_queue: Queue containing frames to process.
        mode: Inference mode/category.
        model_name: Name of the model to use.
        min_interval: Minimum time between inferences (seconds).
        mse_threshold: MSE threshold for frame change detection.
        once_mode: If True, stop after one inference.
    """

    # Emitted with the raw inference result; payload type depends on the
    # dynamically loaded model module.
    inference_result_signal = pyqtSignal(object)

    def __init__(self, frame_queue, mode, model_name, min_interval=0.5,
                 mse_threshold=500, once_mode=False):
        """
        Initialize the InferenceWorkerThread.

        Args:
            frame_queue: Queue containing frames to process.
            mode (str): Inference mode/category.
            model_name (str): Name of the model.
            min_interval (float): Minimum seconds between inferences.
            mse_threshold (float): MSE threshold for detecting frame changes.
            once_mode (bool): If True, stop after processing one frame.
        """
        super().__init__()
        self.frame_queue = frame_queue
        self.mode = mode
        self.model_name = model_name
        self.min_interval = min_interval
        self.mse_threshold = mse_threshold
        self._running = True
        self.once_mode = once_mode
        self.last_inference_time = 0
        self.last_frame = None
        self.cached_result = None
        # Extra keyword parameters forwarded to the model's inference()
        # call; empty by default, may be filled in by callers.
        self.input_params = {}

        # Dynamically load the inference module for the requested model.
        self.inference_module = load_inference_module(mode, model_name)

    def run(self):
        """
        Main thread execution loop.

        Continuously processes frames from the queue, runs inference,
        and emits results. Uses MSE-based frame change detection to
        optimize performance by skipping similar frames.
        """
        while self._running:
            try:
                frame = self.frame_queue.get(timeout=0.1)
            except queue.Empty:
                # Short timeout keeps the loop responsive to stop().
                continue

            # Rate limiting: silently drop frames arriving too soon
            # after the previous inference.
            current_time = time.time()
            if current_time - self.last_inference_time < self.min_interval:
                continue

            if self.last_frame is not None:
                if frame.shape != self.last_frame.shape:
                    # Frame dimensions changed (e.g. resolution switch):
                    # the comparison baseline and cached result are stale.
                    print(f"Frame size changed: from {self.last_frame.shape} to {frame.shape}")
                    self.last_frame = None
                    self.cached_result = None
                else:
                    # Dimensions match, so an element-wise MSE is valid.
                    try:
                        mse = np.mean((frame.astype(np.float32) - self.last_frame.astype(np.float32)) ** 2)
                        if mse < self.mse_threshold and self.cached_result is not None:
                            # Frame is nearly unchanged: reuse the cached
                            # result instead of re-running inference.
                            # (The condition above already guarantees the
                            # cached result is not None, so it is emitted
                            # unconditionally here.)
                            self.inference_result_signal.emit(self.cached_result)
                            if self.once_mode:
                                self._running = False
                                break
                            continue
                    except Exception as e:
                        # Best-effort optimization only: on any MSE failure,
                        # drop the cache and fall through to real inference.
                        print(f"Error calculating MSE: {e}")
                        self.last_frame = None
                        self.cached_result = None

            try:
                result = self.inference_module.inference(frame, params=self.input_params)
            except Exception as e:
                # Inference failures are logged and treated as "no result"
                # so a single bad frame does not kill the worker thread.
                print(f"Inference error: {e}")
                result = None

            self.last_inference_time = current_time
            self.last_frame = frame.copy()
            self.cached_result = result

            # Only emit when inference actually produced a result.
            if result is not None:
                self.inference_result_signal.emit(result)

            if self.once_mode:
                self._running = False
                break

        self.quit()

    def stop(self):
        """Stop the inference worker thread and wait for it to finish."""
        self._running = False
        self.wait()