"""ONNX processing service.

FastAPI application that accepts ONNX/TFLite model uploads and runs
optimization tasks on them through a ``WorkerFacade`` built around
``core.process_onnx_core`` (outputs go to ``OUTPUT_FOLDER``).
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
import os
import uuid
import logging
from datetime import datetime
import httpx

from core import process_onnx_core
from facade import WorkerFacade

# --- Logging -----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="ONNX Processing Service",
    # NOTE(review): the original description text was lost to an encoding
    # error ("?" characters); replace with a real description.
    description="ONNX????????????",
    version="1.0.0"
)

# --- CORS ----------------------------------------------------------------------
# Only the listed browser origins may call this API cross-origin. The default
# matches the previously hard-coded web front-end; override with a
# comma-separated ONNX_CORS_ORIGINS environment variable.
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("ONNX_CORS_ORIGINS", "http://localhost:4000").split(",")
    if origin.strip()
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE"],
    allow_headers=["*"],
)

# --- Configuration -------------------------------------------------------------
UPLOAD_FOLDER = os.getenv("ONNX_UPLOAD_FOLDER", "/tmp/onnx_uploads")
OUTPUT_FOLDER = os.getenv("ONNX_OUTPUT_FOLDER", "/tmp/onnx_outputs")
MAX_FILE_SIZE = 500 * 1024 * 1024  # 500MB

_DEFAULT_API_KEY = "onnx-secret-key"
API_KEY = os.getenv("ONNX_API_KEY", _DEFAULT_API_KEY)
if not API_KEY or API_KEY == _DEFAULT_API_KEY:
    # The fallback key is visible in the source code, and an empty key can be
    # matched by an empty client value, so neither offers real protection.
    logger.warning(
        "ONNX_API_KEY is unset, empty, or equal to the built-in default; "
        "set a strong secret before exposing this service."
    )

# Make sure the working directories exist before any request arrives.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# Shared task runner; the endpoints submit, run, query and cancel tasks via it.
worker = WorkerFacade(
    process_core=process_onnx_core,
    output_folder=OUTPUT_FOLDER,
    output_prefix="optimized_",
    output_extension=".onnx",
    step_name="onnx",
    logger=logger,
)


# --- Pydantic models -------------------------------------------------------------
class HealthResponse(BaseModel):
    """Payload returned by the health-check endpoint."""

    service: str
    status: str
    timestamp: str  # ISO-8601 string produced by datetime.isoformat()
    active_tasks: int


class FileUploadResponse(BaseModel):
    """Payload returned after a model file has been stored."""

    success: bool
    file_id: str  # stored file name; pass it back to the process endpoint
    file_path: str
    message: str


class ONNXProcessRequest(BaseModel):
    """Parameters for an optimization task on a previously uploaded file."""

    file_id: str
    model_id: int = Field(..., ge=1, le=65535)  # 1..65535 inclusive
    version: str = Field(..., regex=r'^[0-9a-fA-F]{4}$')  # exactly 4 hex digits
    platform: str = Field(..., regex=r'^(520|720|530|630|730)$')  # supported platform codes


class TaskStatusResponse(BaseModel):
    """Status snapshot of a single task."""

    task_id: str
    status: str
    progress: float
    message: str
    result: Optional[Dict[str, Any]] = None
    created_at: str  # ISO-8601
    updated_at: str  # ISO-8601
# --- Authentication --------------------------------------------------------------
# Imported here because these names are only needed by this section.
import secrets

from fastapi import Header, Query


def verify_api_key(
    x_api_key: Optional[str] = Header(None),
    api_key_query: Optional[str] = Query(None, alias="x_api_key"),
):
    """Validate the caller's API key against ``API_KEY``.

    The key is read from the ``X-API-Key`` request header. For backward
    compatibility, an ``x_api_key`` query parameter is still accepted. The
    previous signature (``x_api_key: str = None`` without ``Header()``) made
    FastAPI read the key from the query string, so existing clients may
    send it there.

    Returns:
        True when the key matches.

    Raises:
        HTTPException: 401 when no key is supplied or it does not match.
    """
    # isinstance checks guard direct (non-Depends) calls, where the defaults
    # are FastAPI marker objects rather than strings.
    supplied = x_api_key if isinstance(x_api_key, str) else None
    if supplied is None and isinstance(api_key_query, str):
        supplied = api_key_query

    # compare_digest takes time independent of where the values differ, so
    # response timing does not reveal how much of a guessed key is correct.
    if supplied is None or not secrets.compare_digest(
        supplied.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    return True


# --- Health check ------------------------------------------------------------------
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service liveness and the number of tasks the worker is running."""
    active_tasks = worker.active_tasks_count()
    return HealthResponse(
        service="onnx-processing",
        status="healthy",
        timestamp=datetime.now().isoformat(),
        active_tasks=active_tasks
    )


# --- Upload ----------------------------------------------------------------------------
_UPLOAD_CHUNK_SIZE = 1024 * 1024  # bytes read per iteration while streaming an upload


def _remove_partial_upload(path):
    """Best-effort delete of a partially written upload.

    A file that is already gone is ignored. Other OS errors are logged, not
    raised, so they cannot mask the error that triggered the cleanup.
    """
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as err:
        logger.warning(f"Could not remove partial upload {path}: {err}")


@app.post("/api/onnx/upload", response_model=FileUploadResponse)
async def upload_onnx(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Store an uploaded ONNX/TFLite model under ``UPLOAD_FOLDER``.

    The upload is streamed to disk in chunks, so it is never held in memory
    all at once. An upload larger than ``MAX_FILE_SIZE`` is rejected and the
    partial file is removed.

    Returns:
        FileUploadResponse whose ``file_id`` is the stored file name, to be
        passed to ``/api/onnx/process``.

    Raises:
        HTTPException: 400 for a missing or unsupported file name, 413 when
            the upload exceeds MAX_FILE_SIZE, 500 for unexpected errors.
    """
    file_path = None
    try:
        # The client controls the file name: drop any directory components
        # (either separator style) so the stored file stays in UPLOAD_FOLDER.
        # A missing name becomes "" and fails the extension check (400).
        original_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if not original_name.lower().endswith((".onnx", ".tflite")):
            raise HTTPException(
                status_code=400,
                detail="Invalid file type. Only ONNX and TFLite files are supported."
            )

        # The uuid prefix keeps concurrent uploads with the same name apart.
        filename = f"{uuid.uuid4()}_{original_name}"
        file_path = os.path.join(UPLOAD_FOLDER, filename)

        size = 0
        with open(file_path, "wb") as f:
            while True:
                chunk = await file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                size += len(chunk)
                # Check before writing so at most MAX_FILE_SIZE bytes reach disk.
                if size > MAX_FILE_SIZE:
                    raise HTTPException(
                        status_code=413,
                        detail="File too large. Maximum size is 500MB."
                    )
                f.write(chunk)

        logger.info(f"ONNX file uploaded: {filename}")
        return FileUploadResponse(
            success=True,
            file_id=filename,
            file_path=file_path,
            message="ONNX file uploaded successfully"
        )
    except HTTPException:
        _remove_partial_upload(file_path)
        raise
    except Exception as e:
        _remove_partial_upload(file_path)
        logger.exception(f"Error uploading ONNX file: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# --- Processing -------------------------------------------------------------------------
@app.post("/api/onnx/process")
async def process_onnx(
    request: ONNXProcessRequest,
    background_tasks: BackgroundTasks,
    api_key: str = Depends(verify_api_key)
):
    """Queue an optimization task for a previously uploaded model.

    Returns:
        dict with ``success``, the new ``task_id`` and a message. Progress is
        available from ``/api/onnx/tasks/{task_id}/status``.

    Raises:
        HTTPException: 400 when ``file_id`` is not a bare file name, 404 when
            no such uploaded file exists, 500 for unexpected errors.
    """
    try:
        file_id = request.file_id
        # file_id comes from the client. os.path.join discards the base for
        # absolute paths and keeps "..", so only a bare file name is accepted.
        # Otherwise the worker could be pointed at files outside UPLOAD_FOLDER.
        if file_id in ("", ".", "..") or os.path.basename(file_id) != file_id:
            raise HTTPException(status_code=400, detail="Invalid file_id")

        file_path = os.path.join(UPLOAD_FOLDER, file_id)
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="ONNX file not found")

        task_id = worker.submit(
            parameters=request.dict(),
            input_paths={"file_path": file_path},
        )
        # FastAPI runs background tasks after the response has been sent.
        background_tasks.add_task(worker.run_task, task_id)

        logger.info(f"ONNX processing task submitted: {task_id}")
        return {
            "success": True,
            "task_id": task_id,
            "message": "ONNX optimization task submitted successfully"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error processing ONNX: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e


# --- Task status ------------------------------------------------------------------------
@app.get("/api/onnx/tasks/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return the current status of one task.

    Raises:
        HTTPException: 404 when the worker does not know the task.
    """
    task = worker.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    return TaskStatusResponse(
        task_id=task_id,
        status=task["status"],
        progress=task["progress"],
        message=task["message"],
        result=task["result"],
        created_at=task["created_at"].isoformat(),
        updated_at=task["updated_at"].isoformat()
    )


# --- Task result ------------------------------------------------------------------------
@app.get("/api/onnx/tasks/{task_id}/result")
async def get_task_result(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return a completed task's result.

    If the result dict carries a ``file_path``, that file is sent as a
    download named ``optimized_<task_id>.onnx``. Otherwise the result is
    returned as-is (which may be ``None``).

    Raises:
        HTTPException: 404 for an unknown task or a missing result file,
            400 when the task has not completed yet.
    """
    task = worker.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="Task not completed yet")

    result = task["result"]
    # result is Optional[Dict] (see TaskStatusResponse); guard against None.
    if isinstance(result, dict) and "file_path" in result:
        result_path = result["file_path"]
        # FileResponse only detects a missing file while sending (-> 500);
        # report a clean 404 up front instead.
        if not os.path.isfile(result_path):
            raise HTTPException(status_code=404, detail="Result file not found")
        return FileResponse(
            result_path,
            media_type="application/octet-stream",
            filename=f"optimized_{task_id}.onnx"
        )

    return result
# --- Task listing ----------------------------------------------------------------------
@app.get("/api/onnx/tasks")
async def list_tasks(api_key: str = Depends(verify_api_key)):
    """List every task the worker knows about.

    Each entry carries the task id, status, progress and ISO-8601
    creation/update timestamps.
    """
    summaries = []
    for entry in worker.list_tasks():
        summaries.append({
            "task_id": entry["task_id"],
            "status": entry["status"],
            "progress": entry["progress"],
            "created_at": entry["created_at"].isoformat(),
            "updated_at": entry["updated_at"].isoformat(),
        })
    return {"tasks": summaries}


# --- Task cancellation -----------------------------------------------------------------
@app.delete("/api/onnx/tasks/{task_id}")
async def cancel_task(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Ask the worker to cancel a task.

    Raises:
        HTTPException: 404 when the worker raises KeyError (unknown task),
            400 when it raises ValueError (task cannot be cancelled).
    """
    try:
        worker.cancel_task(task_id)
    except (KeyError, ValueError) as exc:
        unknown_task = isinstance(exc, KeyError)
        raise HTTPException(
            status_code=404 if unknown_task else 400,
            detail="Task not found" if unknown_task else "Task cannot be cancelled",
        )

    logger.info("ONNX task cancelled: %s", task_id)
    outcome = {"success": True}
    outcome["message"] = "ONNX task cancelled successfully"
    return outcome


if __name__ == "__main__":
    # Local development entry point; serves on every interface, port 5001.
    import uvicorn

    uvicorn.run(app, port=5001, host="0.0.0.0")