forked from masonhuang/KNEO-Academy
92 lines
3.6 KiB
Python
92 lines
3.6 KiB
Python
import os, time, queue, numpy as np, importlib.util
|
|
from PyQt5.QtCore import QThread, pyqtSignal
|
|
from src.config import UTILS_DIR
|
|
|
|
def load_inference_module(mode, model_name, base_dir=None):
    """Dynamically load an inference module from ``<base_dir>/<mode>/<model_name>/script.py``.

    The script is executed as a fresh module object each call; it is not
    inserted into ``sys.modules``.

    Args:
        mode: First directory component below ``base_dir``.
        model_name: Directory below ``mode`` that contains ``script.py``.
        base_dir: Root directory to search. Defaults to ``UTILS_DIR`` from
            ``src.config`` so existing callers keep their behavior.

    Returns:
        The executed module object, named ``f"{mode}_{model_name}"``.

    Raises:
        FileNotFoundError: If ``script.py`` does not exist at the computed path.
        ImportError: If no import spec/loader can be created for the path.
        Any exception raised while executing the script propagates unchanged.
    """
    # Only fall back to the project-wide directory when no root is given.
    root = UTILS_DIR if base_dir is None else base_dir
    script_path = os.path.join(root, mode, model_name, "script.py")
    if not os.path.isfile(script_path):
        # Fail early with the full path instead of an error deep inside the loader.
        raise FileNotFoundError(f"Inference script not found: {script_path}")

    module_name = f"{mode}_{model_name}"
    spec = importlib.util.spec_from_file_location(module_name, script_path)
    if spec is None or spec.loader is None:
        # spec_from_file_location may return None (or a loader-less spec) for
        # paths it cannot handle; report that as an import failure.
        raise ImportError(f"Cannot create an import spec for {script_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
class InferenceWorkerThread(QThread):
    """Worker thread that runs a dynamically loaded inference module on queued frames.

    Frames are taken from ``frame_queue``. Inference runs at most once per
    ``min_interval`` seconds; frames that arrive sooner are taken off the queue
    and discarded. When a frame's mean squared error (MSE) against the frame
    that produced the cached result is below ``mse_threshold``, the cached
    result is re-emitted instead of running inference again. Every non-None
    result is delivered through ``inference_result_signal``.
    """

    # Carries each non-None result, whether freshly computed or re-emitted from the cache.
    inference_result_signal = pyqtSignal(object)

    def __init__(self, frame_queue, mode, model_name, min_interval=0.5, mse_threshold=500, once_mode=False):
        """Store the configuration and load the inference module.

        Args:
            frame_queue: Source of frames. It must support ``get(timeout=...)``
                and raise ``queue.Empty`` on timeout (presumably a
                ``queue.Queue``).
            mode: First path component passed to ``load_inference_module``.
            model_name: Model directory passed to ``load_inference_module``.
            min_interval: Minimum number of seconds between two inference runs.
            mse_threshold: Frames whose MSE against the last inferred frame is
                below this value reuse the cached result.
            once_mode: If True, the thread stops after the first frame that
                either reuses the cached result or runs inference (even if
                that inference fails and nothing is emitted).

        Raises:
            Propagates any exception raised by ``load_inference_module``.
        """
        super().__init__()
        self.frame_queue = frame_queue
        self.mode = mode
        self.model_name = model_name
        self.min_interval = min_interval
        self.mse_threshold = mse_threshold
        self._running = True  # cleared by stop() or by once_mode to end run()
        self.once_mode = once_mode
        self.last_inference_time = 0  # time.time() of the last inference run; 0 means "never"
        self.last_frame = None  # copy of the frame the cached result was computed on
        self.cached_result = None  # result of the last inference run (None if it failed)
        self.input_params = {}  # forwarded as ``params=`` to the module's inference()

        # Dynamically load the inference module selected by mode/model_name.
        self.inference_module = load_inference_module(mode, model_name)

    def run(self):
        """Thread body: process frames until stopped (or after one frame in once_mode)."""
        while self._running:
            try:
                frame = self.frame_queue.get(timeout=0.1)
            except queue.Empty:
                # Time out periodically so a stop() request is noticed promptly.
                continue

            current_time = time.time()
            if self._is_throttled(current_time):
                continue

            if self._emit_cached_if_similar(frame):
                if self.once_mode:
                    self._running = False
                    break
                continue

            result = self._run_inference(frame)

            self.last_inference_time = current_time
            # Copy so later mutation of the caller's frame buffer cannot change the baseline.
            self.last_frame = frame.copy()
            self.cached_result = result

            # Only emit real results; a failed inference yields None.
            if result is not None:
                self.inference_result_signal.emit(result)

            if self.once_mode:
                self._running = False
                break

        # QThread.quit() does nothing when no event loop is running (run() never calls exec_()).
        self.quit()

    def stop(self):
        """Ask the processing loop to exit and block until the thread has finished."""
        self._running = False
        self.wait()

    def _is_throttled(self, now):
        """Return True if the last inference ran less than ``min_interval`` seconds before ``now``."""
        elapsed = now - self.last_inference_time
        # A negative elapsed time means the wall clock was set backwards; treat
        # that as "not too soon", otherwise inference would stall until the
        # clock caught up with the stored timestamp.
        return 0 <= elapsed < self.min_interval

    def _reset_comparison_state(self):
        """Forget the baseline frame and cached result so the next frame runs inference."""
        self.last_frame = None
        self.cached_result = None

    def _emit_cached_if_similar(self, frame):
        """Re-emit the cached result when ``frame`` is close to the last inferred frame.

        Returns True if the cached result was emitted (the caller should skip
        inference), False otherwise. The baseline and cache are reset when the
        frame shape changes or the MSE cannot be computed.
        """
        if self.last_frame is None:
            return False

        if frame.shape != self.last_frame.shape:
            print(f"幀尺寸變更: 從 {self.last_frame.shape} 變更為 {frame.shape}")
            # A different frame size makes the old baseline and result unusable.
            self._reset_comparison_state()
            return False

        try:
            # Cast to float32 before subtracting so integer frames cannot wrap around.
            mse = np.mean((frame.astype(np.float32) - self.last_frame.astype(np.float32)) ** 2)
            is_similar = mse < self.mse_threshold
        except Exception as e:
            print(f"計算 MSE 時發生錯誤: {e}")
            # On any comparison failure, fall back to running inference on a clean slate.
            self._reset_comparison_state()
            return False

        if is_similar and self.cached_result is not None:
            self.inference_result_signal.emit(self.cached_result)
            return True
        return False

    def _run_inference(self, frame):
        """Call the loaded module's ``inference`` on ``frame``; return None if it raises."""
        try:
            return self.inference_module.inference(frame, params=self.input_params)
        except Exception as e:
            # Best-effort: a failing model must not terminate the worker loop.
            print(f"Inference error: {e}")
            return None