""" Generic Redis Stream queue consumer for workers. 每個 Worker(ONNX/BIE/NEF)使用此模組作為進入點: 1. 從指定的 Redis Stream queue 拉取任務(XREADGROUP) 2. 從 S3/MinIO 下載輸入檔案到本地暫存目錄 3. 呼叫對應的 core function 處理 4. 將結果上傳到 S3/MinIO 5. 將結果推送到 queue:done """ import json import logging import os import shutil import signal import socket import tempfile import time from typing import Any, Callable, Dict import redis logger = logging.getLogger(__name__) class WorkerConsumer: """Redis Stream based queue consumer with S3/MinIO storage.""" def __init__( self, stage: str, process_fn: Callable[[Dict[str, str], str, Dict[str, Any]], Dict[str, Any]], queue_name: str, group_name: str, redis_url: str = None, job_data_dir: str = None, ): self.stage = stage self.process_fn = process_fn self.queue_name = queue_name self.group_name = group_name self.redis_url = redis_url or os.environ.get("REDIS_URL", "redis://localhost:6379") self.job_data_dir = job_data_dir or os.environ.get("JOB_DATA_DIR", "/data/jobs") self.consumer_name = f"{stage}-worker-{socket.gethostname()}-{os.getpid()}" self.running = True self.client = redis.Redis.from_url(self.redis_url, decode_responses=True) # Initialize MinIO storage self.minio = None if os.environ.get("STORAGE_BACKEND", "local") == "minio": from services.workers.s3_storage import MinIOStorage self.minio = MinIOStorage() logger.info("Using MinIO storage backend") else: logger.info("Using local filesystem storage backend") def _ensure_group(self): """Create consumer group if it doesn't exist.""" try: self.client.xgroup_create(self.queue_name, self.group_name, id="0", mkstream=True) logger.info(f"Created consumer group '{self.group_name}' on '{self.queue_name}'") except redis.ResponseError as e: if "BUSYGROUP" not in str(e): raise # Group already exists — OK def _prepare_local_dir(self, job_id: str) -> str: """Prepare a local working directory for the job. For S3 mode: downloads required files from S3 to a temp dir. For local mode: returns the existing job dir on shared volume. """ if not self.minio: return os.path.join(self.job_data_dir, job_id) # MinIO mode: use a local temp dir (isolated per worker, no shared volume conflict) local_dir = os.path.join(tempfile.gettempdir(), "kneron-jobs", f"{job_id}-{self.stage}") os.makedirs(local_dir, exist_ok=True) s3_prefix = f"jobs/{job_id}" if self.stage == "onnx": # Download input/ directory (model file + ref_images) self.minio.download_prefix(f"{s3_prefix}/input", os.path.join(local_dir, "input")) logger.info(f"Downloaded input files from S3 for job {job_id}") elif self.stage == "bie": # Download out.onnx from previous stage self.minio.download_file(f"{s3_prefix}/out.onnx", os.path.join(local_dir, "out.onnx")) # Download ref_images for quantization self.minio.download_prefix( f"{s3_prefix}/input/ref_images", os.path.join(local_dir, "input", "ref_images"), ) logger.info(f"Downloaded ONNX + ref_images from S3 for job {job_id}") elif self.stage == "nef": # Download out.bie from previous stage self.minio.download_file(f"{s3_prefix}/out.bie", os.path.join(local_dir, "out.bie")) logger.info(f"Downloaded BIE from S3 for job {job_id}") return local_dir def _upload_output(self, job_id: str, job_dir: str): """Upload the output file to S3 after processing.""" if not self.minio: return output_files = { "onnx": "out.onnx", "bie": "out.bie", "nef": "out.nef", } output_name = output_files[self.stage] local_path = os.path.join(job_dir, output_name) s3_key = f"jobs/{job_id}/{output_name}" if os.path.exists(local_path): self.minio.upload_file(local_path, s3_key) logger.info(f"Uploaded {output_name} to S3 for job {job_id}") def _cleanup_local(self, job_dir: str): """Clean up local temp directory after S3 upload.""" if not self.minio: return try: shutil.rmtree(job_dir, ignore_errors=True) logger.debug(f"Cleaned up local dir: {job_dir}") except Exception as e: logger.warning(f"Failed to clean up {job_dir}: {e}") def _build_input_paths(self, job_dir: str, parameters: dict) -> dict: """Build input_paths dict based on stage and job directory contents.""" input_dir = os.path.join(job_dir, "input") if self.stage == "onnx": # Find the single input file in input/ input_file = None if os.path.isdir(input_dir): for f in os.listdir(input_dir): fpath = os.path.join(input_dir, f) if os.path.isfile(fpath): input_file = fpath break if not input_file: raise FileNotFoundError(f"No input file found in {input_dir}") return {"file_path": input_file} elif self.stage == "bie": onnx_path = os.path.join(job_dir, "out.onnx") ref_images_dir = os.path.join(input_dir, "ref_images") return { "onnx_file_path": onnx_path, "data_dir": ref_images_dir, } elif self.stage == "nef": bie_path = os.path.join(job_dir, "out.bie") return {"bie_file_path": bie_path} else: raise ValueError(f"Unknown stage: {self.stage}") def _get_output_path(self, job_dir: str) -> str: """Get the expected output file path for this stage.""" output_files = { "onnx": "out.onnx", "bie": "out.bie", "nef": "out.nef", } return os.path.join(job_dir, output_files[self.stage]) def _push_done(self, job_id: str, result: str, reason: str = None): """Push a done event to queue:done.""" message = { "job_id": job_id, "step": self.stage, "result": result, "completed_at": time.strftime("%Y-%m-%dT%H:%M:%S%z"), } if reason: message["reason"] = reason self.client.xadd("queue:done", {"data": json.dumps(message)}) logger.info(f"Pushed done: job={job_id} step={self.stage} result={result}") def _process_message(self, message_id: str, data: dict): """Process a single task message.""" job_id = data["job_id"] parameters = data.get("parameters", {}) logger.info(f"Processing job {job_id} (stage={self.stage})") job_dir = None try: # Prepare local working directory (download from S3 if needed) job_dir = self._prepare_local_dir(job_id) input_paths = self._build_input_paths(job_dir, parameters) output_path = self._get_output_path(job_dir) # Add work_dir to parameters so core can set up toolchain paths parameters["work_dir"] = job_dir result = self.process_fn(input_paths, output_path, parameters) # Upload output to S3 self._upload_output(job_id, job_dir) logger.info(f"Job {job_id} completed: {result.get('file_path', 'N/A')}") self._push_done(job_id, "ok") except Exception as e: logger.error(f"Job {job_id} failed: {e}", exc_info=True) self._push_done(job_id, "fail", reason=str(e)) finally: # Clean up local temp files in S3 mode if job_dir: self._cleanup_local(job_dir) # ACK the message regardless of success/failure self.client.xack(self.queue_name, self.group_name, message_id) def run(self): """Main loop: pull tasks from queue and process them.""" self._ensure_group() logger.info( f"[{self.consumer_name}] Listening on {self.queue_name} " f"(group={self.group_name})" ) # Handle graceful shutdown def handle_signal(signum, frame): logger.info(f"[{self.consumer_name}] Received signal {signum}, shutting down...") self.running = False signal.signal(signal.SIGTERM, handle_signal) signal.signal(signal.SIGINT, handle_signal) while self.running: try: results = self.client.xreadgroup( self.group_name, self.consumer_name, {self.queue_name: ">"}, count=1, block=5000, # 5 second timeout ) if not results: continue for stream_name, messages in results: for message_id, fields in messages: data = json.loads(fields["data"]) self._process_message(message_id, data) except redis.ConnectionError: logger.error("Redis connection lost, retrying in 3s...") time.sleep(3) except Exception as e: logger.error(f"Unexpected error: {e}", exc_info=True) time.sleep(1) logger.info(f"[{self.consumer_name}] Stopped")