KNEO-Academy/src/models/inference_worker.py

92 lines
3.6 KiB
Python

import os, time, queue, numpy as np, importlib.util
from PyQt5.QtCore import QThread, pyqtSignal
from src.config import UTILS_DIR
def load_inference_module(mode, model_name, base_dir=None):
    """Dynamically load an inference module from ``<base_dir>/<mode>/<model_name>/script.py``.

    The script is executed as a new module object named ``f"{mode}_{model_name}"``.
    It is not registered in ``sys.modules``, so every call re-executes the script.

    Args:
        mode: First directory component below ``base_dir``.
        model_name: Directory below ``mode`` that contains ``script.py``.
        base_dir: Root directory to search. When omitted, ``UTILS_DIR`` from
            ``src.config`` is used (looked up lazily so the default stays in
            sync with the config module).

    Returns:
        The executed module object.

    Raises:
        FileNotFoundError: If ``script.py`` does not exist at the expected path.
        ImportError: If no import spec/loader can be created for the script.
        Exception: Any error raised while executing the script propagates unchanged.
    """
    root = UTILS_DIR if base_dir is None else base_dir
    script_path = os.path.join(root, mode, model_name, "script.py")
    # Fail early with the offending path rather than a generic loader error.
    if not os.path.isfile(script_path):
        raise FileNotFoundError(f"Inference script not found: {script_path}")
    module_name = f"{mode}_{model_name}"
    spec = importlib.util.spec_from_file_location(module_name, script_path)
    # spec_from_file_location may return None (or a spec without a loader);
    # surface that as an import problem instead of an AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create import spec for {script_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class InferenceWorkerThread(QThread):
    """Background thread that pulls frames from a queue and emits inference results.

    Each frame taken from ``frame_queue`` may be passed to the dynamically
    loaded module's ``inference(frame, params=...)`` function. Two mechanisms
    reduce work:

    * Throttling: a frame arriving less than ``min_interval`` seconds after the
      last real inference is dropped.
    * Change detection: if a frame has the same shape as the last inferred frame
      and their mean squared difference is below ``mse_threshold``, the cached
      result of the last inference is re-emitted instead of running inference.

    Results that are ``None`` (including failed inferences) are never emitted.
    With ``once_mode=True`` the thread stops after the first emitted result.
    """

    # Emits each non-None result; the payload is whatever the loaded module's
    # ``inference`` function returns.
    inference_result_signal = pyqtSignal(object)

    def __init__(self, frame_queue, mode, model_name, min_interval=0.5, mse_threshold=500, once_mode=False):
        super().__init__()
        self.frame_queue = frame_queue  # read with get(timeout=0.1) in run()
        self.mode = mode
        self.model_name = model_name
        self.min_interval = min_interval  # minimum seconds between real inferences
        self.mse_threshold = mse_threshold  # MSE below this reuses the cached result
        self._running = True
        self.once_mode = once_mode
        # time.monotonic() timestamp of the last real inference; -inf means
        # "never", so the first frame is always eligible for inference.
        self.last_inference_time = float("-inf")
        self.last_frame = None  # copy of the frame passed to the last inference
        self.cached_result = None  # result of the last inference (may be None)
        self.input_params = {}  # forwarded as ``params=`` to inference()
        # Load the model-specific script now; any load error propagates from the constructor.
        self.inference_module = load_inference_module(mode, model_name)

    def run(self):
        """Thread entry point: consume frames until stopped (or after one result in once_mode)."""
        while self._running:
            try:
                frame = self.frame_queue.get(timeout=0.1)
            except queue.Empty:
                # Short timeout keeps the loop responsive to stop().
                continue
            # Monotonic clock: wall-clock adjustments must not stall or bypass throttling.
            current_time = time.monotonic()
            if current_time - self.last_inference_time < self.min_interval:
                continue
            # NOTE(review): re-emitting a cached result does not update
            # last_inference_time, so once min_interval has elapsed every
            # near-duplicate frame re-emits the cached result. Confirm this is intended.
            if self._can_reuse_cached_result(frame):
                result = self.cached_result
            else:
                result = self._infer(frame, current_time)
            if result is not None:
                self.inference_result_signal.emit(result)
                if self.once_mode:
                    self._running = False
                    break
        # run() never starts an event loop (no exec_()), so per the Qt docs quit()
        # does nothing here; kept for parity with the original behavior.
        self.quit()

    def _can_reuse_cached_result(self, frame):
        """Return True if ``frame`` is similar enough to the last inferred frame to reuse its result.

        Side effect: when the shapes differ or the MSE cannot be computed, the
        stored last frame and cached result are cleared.
        """
        if self.last_frame is None:
            return False
        if frame.shape != self.last_frame.shape:
            # The frame size changed, so the old frame and its result are not comparable.
            print(f"幀尺寸變更: 從 {self.last_frame.shape} 變更為 {frame.shape}")
            self._reset_cache()
            return False
        try:
            # Cast to float32 before subtracting so unsigned integer frames
            # (e.g. uint8) cannot wrap around.
            mse = np.mean((frame.astype(np.float32) - self.last_frame.astype(np.float32)) ** 2)
            is_similar = bool(mse < self.mse_threshold)
        except Exception as e:
            print(f"計算 MSE 時發生錯誤: {e}")
            self._reset_cache()
            return False
        return is_similar and self.cached_result is not None

    def _reset_cache(self):
        """Forget the last inferred frame and its cached result."""
        self.last_frame = None
        self.cached_result = None

    def _infer(self, frame, current_time):
        """Run inference on ``frame``, store it as the new reference, and return the result.

        An exception from the inference module is printed and becomes a ``None``
        result. The attempt still resets the throttle timer and replaces the
        cached frame/result.
        """
        try:
            result = self.inference_module.inference(frame, params=self.input_params)
        except Exception as e:
            print(f"Inference error: {e}")
            result = None
        self.last_inference_time = current_time
        # Copy so later mutation of the producer's buffer cannot affect comparisons.
        self.last_frame = frame.copy()
        self.cached_result = result
        return result

    def stop(self):
        """Ask the run loop to exit, then block until the thread has finished."""
        self._running = False
        self.wait()