KNEO-Academy/src/models/inference_worker.py
HuangMason320 17deba3bdb Refactor: Clean up codebase and improve documentation
- Remove unused files (model_controller.py, model_service.py, device_connection_popup.py)
- Clean up commented code in device_service.py, device_popup.py, config.py
- Update docstrings and comments across all modules
- Improve code organization and readability
2025-12-30 16:47:31 +08:00

149 lines
5.5 KiB
Python

"""
inference_worker.py - Inference Worker Thread
This module provides a QThread-based worker for running AI model inference
on video frames. It supports dynamic module loading and frame caching.
"""
import os
import time
import queue
import numpy as np
import importlib.util
from PyQt5.QtCore import QThread, pyqtSignal
from src.config import UTILS_DIR
def load_inference_module(mode, model_name, base_dir=None):
    """
    Dynamically load an inference module from the utils directory.

    The module source is read from
    ``<base_dir>/<mode>/<model_name>/script.py`` and executed as a fresh
    module named ``"<mode>_<model_name>"``.

    Args:
        mode (str): The inference mode/category (e.g., 'object_detection').
        model_name (str): The name of the model.
        base_dir (str, optional): Root directory containing the mode
            sub-directories. Defaults to ``UTILS_DIR`` from ``src.config``.

    Returns:
        module: The loaded Python module (callers expect it to expose an
        ``inference`` function).

    Raises:
        FileNotFoundError: If the script file does not exist.
        ImportError: If Python cannot build a loader for the script path.
    """
    root = UTILS_DIR if base_dir is None else base_dir
    script_path = os.path.join(root, mode, model_name, "script.py")
    # Fail early with the offending path instead of an opaque error
    # raised from inside the import machinery.
    if not os.path.isfile(script_path):
        raise FileNotFoundError(f"Inference script not found: {script_path}")

    module_name = f"{mode}_{model_name}"
    spec = importlib.util.spec_from_file_location(module_name, script_path)
    # spec_from_file_location() may return None, and spec.loader is
    # Optional; both must be present before the module can be executed.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create a module loader for {script_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class InferenceWorkerThread(QThread):
    """
    Worker thread for running AI model inference on frames.

    Frames are pulled from ``frame_queue``. A frame arriving less than
    ``min_interval`` seconds after the last inference is dropped. Otherwise,
    if it is close enough (by mean-squared error) to the last inferred frame
    and a cached result exists, the cached result is re-emitted; if not, the
    dynamically loaded module's ``inference(frame, params=...)`` is called
    and its result cached. Only non-None results are emitted.

    Attributes:
        inference_result_signal: Signal emitted with each non-None result.
        frame_queue: Queue the frames are read from.
        mode: Inference mode/category (sub-directory under UTILS_DIR).
        model_name: Name of the model (sub-directory under ``mode``).
        min_interval: Minimum seconds between inferences; frames arriving
            sooner are dropped.
        mse_threshold: MSE below which a frame is treated as unchanged.
        once_mode: If True, stop after the first non-None result is emitted.
        input_params: Dict passed as ``params`` to the inference function.
    """

    inference_result_signal = pyqtSignal(object)

    def __init__(self, frame_queue, mode, model_name, min_interval=0.5, mse_threshold=500, once_mode=False):
        """
        Initialize the InferenceWorkerThread.

        Args:
            frame_queue: Queue containing frames to process.
            mode (str): Inference mode/category.
            model_name (str): Name of the model.
            min_interval (float): Minimum seconds between inferences.
            mse_threshold (float): MSE threshold for detecting frame changes.
            once_mode (bool): If True, stop after the first non-None result.
        """
        super().__init__()
        self.frame_queue = frame_queue
        self.mode = mode
        self.model_name = model_name
        self.min_interval = min_interval
        self.mse_threshold = mse_threshold
        self._running = True
        self.once_mode = once_mode
        # Timestamp (time.monotonic) of the last real inference. -inf makes
        # the very first frame always pass the rate limit, independent of
        # the monotonic clock's undefined reference point.
        self.last_inference_time = float("-inf")
        self.last_frame = None
        self.cached_result = None
        self.input_params = {}
        # Dynamically load inference module
        self.inference_module = load_inference_module(mode, model_name)

    def _matches_last_frame(self, frame):
        """
        Return True when ``frame`` may reuse the cached inference result.

        A frame matches when it has the same shape as the last inferred
        frame, their MSE is below ``mse_threshold``, and a cached result
        exists. On a shape change or an MSE computation error, the stored
        frame and cached result are cleared so the next inference starts
        fresh.
        """
        if self.last_frame is None:
            return False

        if frame.shape != self.last_frame.shape:
            print(f"Frame size changed: from {self.last_frame.shape} to {frame.shape}")
            self.last_frame = None
            self.cached_result = None
            return False

        try:
            # float32 avoids unsigned-integer wrap-around in the subtraction.
            mse = np.mean((frame.astype(np.float32) - self.last_frame.astype(np.float32)) ** 2)
            return mse < self.mse_threshold and self.cached_result is not None
        except Exception as e:
            print(f"Error calculating MSE: {e}")
            self.last_frame = None
            self.cached_result = None
            return False

    def run(self):
        """
        Main thread loop: process queued frames until stopped.

        Polls the queue with a 0.1 s timeout so ``stop()`` is noticed
        promptly. Unchanged frames re-emit the cached result instead of
        running inference again.
        """
        while self._running:
            try:
                frame = self.frame_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            # Monotonic clock: elapsed-time checks must not be affected by
            # wall-clock adjustments (NTP, DST, manual changes).
            current_time = time.monotonic()
            if current_time - self.last_inference_time < self.min_interval:
                continue

            if self._matches_last_frame(frame):
                # _matches_last_frame guarantees cached_result is not None.
                self.inference_result_signal.emit(self.cached_result)
                if self.once_mode:
                    self._running = False
                    break
                continue

            try:
                result = self.inference_module.inference(frame, params=self.input_params)
            except Exception as e:
                print(f"Inference error: {e}")
                result = None

            self.last_inference_time = current_time
            self.last_frame = frame.copy()
            self.cached_result = result

            # Only emit signal if result is not None
            if result is not None:
                self.inference_result_signal.emit(result)
                if self.once_mode:
                    self._running = False
                    break
        self.quit()

    def stop(self):
        """Ask the worker loop to exit and block until the thread finishes."""
        self._running = False
        self.wait()