2026-01-28 06:16:04 +00:00

270 lines
7.5 KiB
Python

"""
ONNX????????
??????ONNX????????????????
"""
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from typing import Optional, Dict, Any
import os
import uuid
import logging
from datetime import datetime
import httpx
from core import process_onnx_core
from facade import WorkerFacade
# Logging setup: configure the root logger at INFO and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application object that every route below is registered on.
# NOTE(review): the description text looks mis-encoded - restore the intended wording.
app = FastAPI(
    title="ONNX Processing Service",
    description="ONNX????????????",
    version="1.0.0"
)
# CORS configuration - only the single origin listed below is allowed
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:4000"],  # NOTE(review): presumably the web front end's origin - confirm
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE"],
    allow_headers=["*"],
)
# Service configuration; folders and API key can be overridden through environment variables
UPLOAD_FOLDER = os.environ.get("ONNX_UPLOAD_FOLDER", "/tmp/onnx_uploads")
OUTPUT_FOLDER = os.environ.get("ONNX_OUTPUT_FOLDER", "/tmp/onnx_outputs")
MAX_FILE_SIZE = 500 * 1024 ** 2  # 500MB
# NOTE(review): falls back to a hard-coded key when ONNX_API_KEY is unset
API_KEY = os.environ.get("ONNX_API_KEY", "onnx-secret-key")

# Create the upload and output directories up front if they do not exist yet
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
# Shared worker facade: the endpoints below use it to submit, run, query,
# list and cancel tasks. It is handed the ONNX core routine, the output
# folder and the naming parts ("optimized_" prefix, ".onnx" extension).
worker = WorkerFacade(
    process_core=process_onnx_core,
    output_folder=OUTPUT_FOLDER,
    output_prefix="optimized_",
    output_extension=".onnx",
    step_name="onnx",
    logger=logger,
)
# Pydantic models
# Response body of GET /health.
class HealthResponse(BaseModel):
    service: str  # service identifier
    status: str  # health state string
    timestamp: str  # ISO-8601 timestamp taken when the check ran
    active_tasks: int  # value reported by worker.active_tasks_count()
# Response body of POST /api/onnx/upload.
class FileUploadResponse(BaseModel):
    success: bool
    file_id: str  # stored file name; passed back later as ONNXProcessRequest.file_id
    file_path: str  # path of the stored file on the server
    message: str
# Request body of POST /api/onnx/process.
class ONNXProcessRequest(BaseModel):
    file_id: str  # name of a previously uploaded file inside UPLOAD_FOLDER
    model_id: int = Field(..., ge=1, le=65535)  # required, 1..65535 inclusive
    version: str = Field(..., regex=r'^[0-9a-fA-F]{4}$')  # required, exactly four hex digits
    platform: str = Field(..., regex=r'^(520|720|530|630|730)$')  # required, one of the listed platform codes
# Response body of GET /api/onnx/tasks/{task_id}/status.
class TaskStatusResponse(BaseModel):
    task_id: str
    status: str
    progress: float
    message: str
    result: Optional[Dict[str, Any]] = None  # may be absent; defaults to None
    created_at: str  # ISO-8601 string
    updated_at: str  # ISO-8601 string
# Authentication dependency
def verify_api_key(x_api_key: Optional[str] = None):
    """Check the caller-supplied API key against the configured ``API_KEY``.

    Args:
        x_api_key: Key presented by the client, or ``None`` when absent.

    Returns:
        ``True`` when the key matches.

    Raises:
        HTTPException: 401 when the key is missing or does not match.
    """
    # Function-scope import so the block does not depend on new file-level imports.
    import hmac

    # NOTE(review): the parameter is declared without ``Header(...)``, so FastAPI
    # presumably binds it to a query parameter named ``x_api_key`` rather than an
    # ``X-API-Key`` header - confirm which one clients are expected to send.

    # A missing key can never match; checked first because compare_digest
    # needs two values of the same type.
    key_matches = isinstance(x_api_key, str) and hmac.compare_digest(
        # Compare as bytes: compare_digest rejects non-ASCII str arguments.
        x_api_key.encode("utf-8"),
        API_KEY.encode("utf-8"),
    )
    if not key_matches:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    return True
# Health-check endpoint
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Return a HealthResponse with the current time and the worker's active task count."""
    payload = {
        "service": "onnx-processing",
        "status": "healthy",
        "active_tasks": worker.active_tasks_count(),
    }
    # Timestamp is taken after the task count has been read.
    payload["timestamp"] = datetime.now().isoformat()
    return HealthResponse(**payload)
# File upload endpoint
@app.post("/api/onnx/upload", response_model=FileUploadResponse)
async def upload_onnx(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Store an uploaded ONNX/TFLite file in ``UPLOAD_FOLDER``.

    The upload is copied to disk in fixed-size chunks so that an oversized
    file is rejected without first being loaded into memory in full.

    Args:
        file: The uploaded file; its name must end in ``.onnx`` or ``.tflite``.
        api_key: Result of the ``verify_api_key`` dependency (unused here).

    Returns:
        FileUploadResponse carrying the stored file name as ``file_id`` and
        the path it was written to.

    Raises:
        HTTPException: 400 for an unsupported or missing file name, 413 when
            the upload exceeds ``MAX_FILE_SIZE``, 500 for any other failure.
    """
    chunk_size = 1024 * 1024  # copy 1 MiB at a time
    try:
        # Keep only the final path component of the client-supplied name so it
        # cannot introduce directory separators into the stored path.
        raw_name = file.filename or ""
        safe_name = os.path.basename(raw_name.replace("\\", "/"))

        # Validate the file type by extension
        if not safe_name.lower().endswith((".onnx", ".tflite")):
            raise HTTPException(
                status_code=400,
                detail="Invalid file type. Only ONNX and TFLite files are supported."
            )

        # A UUID prefix keeps stored names unique across uploads
        filename = f"{uuid.uuid4()}_{safe_name}"
        file_path = os.path.join(UPLOAD_FOLDER, filename)

        # Stream to disk, enforcing the size limit as bytes arrive
        bytes_written = 0
        try:
            with open(file_path, "wb") as f:
                while True:
                    chunk = await file.read(chunk_size)
                    if not chunk:
                        break
                    bytes_written += len(chunk)
                    if bytes_written > MAX_FILE_SIZE:
                        raise HTTPException(
                            status_code=413,
                            detail="File too large. Maximum size is 500MB."
                        )
                    f.write(chunk)
        except Exception:
            # Do not leave a partial file behind; removal is best effort and
            # must not mask the original error.
            try:
                if os.path.exists(file_path):
                    os.remove(file_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial upload {file_path}: {cleanup_error}")
            raise

        logger.info(f"ONNX file uploaded: {filename}")
        return FileUploadResponse(
            success=True,
            file_id=filename,
            file_path=file_path,
            message="ONNX file uploaded successfully"
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error uploading ONNX file: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# ONNX processing endpoint
@app.post("/api/onnx/process")
async def process_onnx(
    request: ONNXProcessRequest,
    background_tasks: BackgroundTasks,
    api_key: str = Depends(verify_api_key)
):
    """Submit a processing task for a previously uploaded file.

    Args:
        request: Validated request body; ``file_id`` names a file inside
            ``UPLOAD_FOLDER``.
        background_tasks: FastAPI background task queue that runs the task
            after the response is sent.
        api_key: Result of the ``verify_api_key`` dependency (unused here).

    Returns:
        Dict with ``success``, the new ``task_id`` and a ``message``.

    Raises:
        HTTPException: 404 when ``file_id`` does not name a regular file
            directly inside ``UPLOAD_FOLDER``, 500 for any other failure.
    """
    try:
        file_id = request.file_id
        # file_id is client-controlled: accept only a bare file name so that
        # "../" segments or an absolute path cannot escape UPLOAD_FOLDER.
        if not file_id or os.path.basename(file_id) != file_id:
            raise HTTPException(status_code=404, detail="ONNX file not found")

        # isfile (rather than exists) also rejects directories such as "..".
        file_path = os.path.join(UPLOAD_FOLDER, file_id)
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="ONNX file not found")

        task_id = worker.submit(
            parameters=request.dict(),
            input_paths={"file_path": file_path},
        )
        background_tasks.add_task(worker.run_task, task_id)
        logger.info(f"ONNX processing task submitted: {task_id}")
        return {
            "success": True,
            "task_id": task_id,
            "message": "ONNX optimization task submitted successfully"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error processing ONNX: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# Task status endpoint
@app.get("/api/onnx/tasks/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return the status record of one task, or 404 when the worker does not know it."""
    task = worker.get_task(task_id)
    if task:
        # Plain fields are copied through; timestamps are rendered as ISO-8601.
        fields = {key: task[key] for key in ("status", "progress", "message", "result")}
        stamps = {key: task[key].isoformat() for key in ("created_at", "updated_at")}
        return TaskStatusResponse(task_id=task_id, **fields, **stamps)
    raise HTTPException(status_code=404, detail="Task not found")
# Task result endpoint
@app.get("/api/onnx/tasks/{task_id}/result")
async def get_task_result(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return the result of a completed task.

    Args:
        task_id: Identifier of the task to look up.
        api_key: Result of the ``verify_api_key`` dependency (unused here).

    Returns:
        A FileResponse streaming the output file when the result contains a
        ``file_path``; otherwise the stored result value as-is.

    Raises:
        HTTPException: 404 when the task is unknown or its output file is no
            longer on disk, 400 when the task has not completed.
    """
    task = worker.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="Task not completed yet")

    result = task["result"]
    # The result may be None, so check its type before the membership test.
    if isinstance(result, dict) and "file_path" in result:
        output_path = result["file_path"]
        # Report a missing file as 404 instead of failing while the response is sent.
        if not output_path or not os.path.isfile(output_path):
            raise HTTPException(status_code=404, detail="Result file not found")
        return FileResponse(
            output_path,
            media_type="application/octet-stream",
            filename=f"optimized_{task_id}.onnx"
        )
    return result
# Task list endpoint
@app.get("/api/onnx/tasks")
async def list_tasks(api_key: str = Depends(verify_api_key)):
    """Return a summary entry for every task the worker reports."""
    summaries = []
    for entry in worker.list_tasks():
        summaries.append(
            {
                "task_id": entry["task_id"],
                "status": entry["status"],
                "progress": entry["progress"],
                "created_at": entry["created_at"].isoformat(),
                "updated_at": entry["updated_at"].isoformat(),
            }
        )
    return {"tasks": summaries}
# Task cancellation endpoint
@app.delete("/api/onnx/tasks/{task_id}")
async def cancel_task(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Cancel a task; a KeyError from the worker maps to 404 and a ValueError to 400."""
    try:
        worker.cancel_task(task_id)
    except (KeyError, ValueError) as failure:
        # KeyError is tested first, matching the precedence of separate handlers.
        if isinstance(failure, KeyError):
            status_code, detail = 404, "Task not found"
        else:
            status_code, detail = 400, "Task cannot be cancelled"
        raise HTTPException(status_code=status_code, detail=detail)

    logger.info(f"ONNX task cancelled: {task_id}")
    return {"success": True, "message": "ONNX task cancelled successfully"}
# Script entry point: serve the app with uvicorn on all interfaces, port 5001.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=5001)