270 lines
7.5 KiB
Python
270 lines
7.5 KiB
Python
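"""
ONNX processing service.

FastAPI service that accepts ONNX/TFLite model uploads and runs ONNX
optimization tasks in the background through a WorkerFacade.
"""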
|
|
|
|
import hmac
import logging
import os
import uuid
from datetime import datetime
from typing import Optional, Dict, Any

import httpx
from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field

from core import process_onnx_core
from facade import WorkerFacade
|
|
|
|
# Logging: the level defaults to INFO and can be overridden with the
# ONNX_LOG_LEVEL environment variable (a standard level name such as DEBUG).
_log_level = logging.getLevelName(os.getenv("ONNX_LOG_LEVEL", "INFO").strip().upper())
# getLevelName maps a known level name to its int value and returns a
# "Level X" string for unknown names; fall back to INFO instead of crashing.
logging.basicConfig(level=_log_level if isinstance(_log_level, int) else logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
|
# FastAPI application; title/description/version are published in the
# generated OpenAPI schema and the /docs page.
app = FastAPI(
    title="ONNX Processing Service",
    # The previous literal was mis-encoded ("????") and rendered as garbage in /docs.
    description="ONNX model upload and optimization service",
    version="1.0.0",
)
|
|
|
|
# CORS: lets browser clients on the listed origins call this API.
# ONNX_CORS_ORIGINS is a comma-separated list of origins; the default keeps the
# previously hard-coded development origin.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("ONNX_CORS_ORIGINS", "http://localhost:4000").split(",")
    if origin.strip()  # ignore empty entries such as a trailing comma
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "DELETE"],
    allow_headers=["*"],
)
|
|
|
|
# Service configuration; folders and the API key can be overridden via env vars.
UPLOAD_FOLDER = os.getenv("ONNX_UPLOAD_FOLDER", "/tmp/onnx_uploads")
OUTPUT_FOLDER = os.getenv("ONNX_OUTPUT_FOLDER", "/tmp/onnx_outputs")
MAX_FILE_SIZE = 500 * 1024 * 1024  # 500MB, enforced by the upload endpoint

# Fallback key used when ONNX_API_KEY is unset. It is committed in the source,
# so anyone who can read this file can authenticate against a deployment that
# relies on it -- warn loudly instead of failing silently.
_DEFAULT_API_KEY = "onnx-secret-key"
API_KEY = os.getenv("ONNX_API_KEY", _DEFAULT_API_KEY)
if API_KEY == _DEFAULT_API_KEY:
    logger.warning(
        "ONNX_API_KEY is unset or equals the built-in default; "
        "set a private ONNX_API_KEY before exposing this service."
    )

# Create the working directories up front so request handlers can assume they exist.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
|
|
|
|
# Shared task runner used by every endpoint below. It is handed the ONNX core
# processing function plus an output naming scheme (presumably results are
# written to OUTPUT_FOLDER as "optimized_*.onnx" -- see WorkerFacade to confirm).
worker = WorkerFacade(
    # identity / logging
    step_name="onnx",
    logger=logger,
    # what to run
    process_core=process_onnx_core,
    # where and how outputs are named
    output_folder=OUTPUT_FOLDER,
    output_prefix="optimized_",
    output_extension=".onnx",
)
|
|
|
|
# Pydantic request/response models
class HealthResponse(BaseModel):
    # Response body of GET /health.
    service: str  # fixed identifier; health_check sends "onnx-processing"
    status: str  # health_check always sends "healthy" when it responds
    timestamp: str  # datetime.now().isoformat(), i.e. naive server-local time
    active_tasks: int  # worker.active_tasks_count() at request time
|
|
|
class FileUploadResponse(BaseModel):
    # Response body of POST /api/onnx/upload.
    success: bool
    file_id: str  # stored file name ("<uuid4>_<client file name>"); pass it to /api/onnx/process
    file_path: str  # server-side path of the stored file inside UPLOAD_FOLDER
    message: str
|
|
|
|
class ONNXProcessRequest(BaseModel):
    # Request body of POST /api/onnx/process; process_onnx forwards the whole
    # model to the worker as its parameters via .dict().
    file_id: str  # the file_id returned by the upload endpoint
    model_id: int = Field(..., ge=1, le=65535)  # validated to 1..65535 inclusive
    # Exactly four hexadecimal digits.
    # NOTE(review): `regex=` is Pydantic v1 syntax (v2 renamed it to `pattern=`);
    # confirm the pinned Pydantic version before upgrading.
    version: str = Field(..., regex=r'^[0-9a-fA-F]{4}$')
    # One of the accepted platform codes: 520, 720, 530, 630 or 730.
    platform: str = Field(..., regex=r'^(520|720|530|630|730)$')
|
|
|
|
class TaskStatusResponse(BaseModel):
    # Response body of GET /api/onnx/tasks/{task_id}/status, built from the
    # worker's task record.
    task_id: str
    status: str  # e.g. "completed", the value get_task_result checks for
    progress: float  # NOTE(review): scale (0-1 vs 0-100) is set by WorkerFacade -- confirm
    message: str
    result: Optional[Dict[str, Any]] = None  # may contain "file_path" for a downloadable output
    created_at: str  # task["created_at"].isoformat()
    updated_at: str  # task["updated_at"].isoformat()
|
|
|
|
# Authentication dependency
def verify_api_key(
    x_api_key: Optional[str] = None,
    x_api_key_header: Optional[str] = Header(None, alias="X-API-Key"),
):
    """Reject the request with 401 unless it carries the configured API key.

    The key is read from the ``X-API-Key`` request header. For backward
    compatibility it is also accepted as the ``x_api_key`` query parameter
    (FastAPI binds a plain scalar dependency parameter to the query string);
    the header wins when both are present. Prefer the header, since query
    strings tend to end up in access logs.

    Returns:
        True when the key matches API_KEY.

    Raises:
        HTTPException: 401 when the key is missing or does not match.
    """
    provided = x_api_key_header if x_api_key_header is not None else x_api_key
    # compare_digest avoids content-based short-circuiting, so response timing
    # does not reveal how many leading characters of a guessed key were right.
    if provided is None or not hmac.compare_digest(
        provided.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    return True
|
|
|
|
# Health check
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Liveness probe: service name, status, server time and active task count."""
    # active_tasks is listed first so the worker is queried before the clock is
    # read, matching the order in which the values are gathered.
    return HealthResponse(
        active_tasks=worker.active_tasks_count(),
        service="onnx-processing",
        status="healthy",
        timestamp=datetime.now().isoformat(),
    )
|
|
|
|
# Size of each read when streaming an upload to disk.
_UPLOAD_CHUNK_SIZE = 1024 * 1024  # 1 MiB


def _discard_partial_upload(path: Optional[str]) -> None:
    """Best-effort removal of a partially written upload file."""
    if not path:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass  # nothing was written yet
    except OSError as err:
        logger.warning(f"Could not remove partial upload {path}: {err}")


# File upload
@app.post("/api/onnx/upload", response_model=FileUploadResponse)
async def upload_onnx(
    file: UploadFile = File(...),
    api_key: str = Depends(verify_api_key)
):
    """Store an uploaded ONNX/TFLite model in UPLOAD_FOLDER.

    The body is streamed to disk in chunks, so an oversized upload is rejected
    as soon as it crosses MAX_FILE_SIZE instead of being read into memory whole.

    Returns:
        FileUploadResponse whose ``file_id`` is the stored file name
        (``<uuid4>_<client file name>``), to be passed to /api/onnx/process.

    Raises:
        HTTPException: 400 for a missing or unsupported file name, 413 when the
            file exceeds MAX_FILE_SIZE, 500 for any other failure.
    """
    file_path = None
    try:
        # The file name is client-controlled: keep only its last path component
        # (treating "\" as a separator too) so it cannot address other directories.
        client_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if not client_name.lower().endswith((".onnx", ".tflite")):
            raise HTTPException(
                status_code=400,
                detail="Invalid file type. Only ONNX and TFLite files are supported."
            )

        # The uuid prefix keeps uploads with the same client name from colliding.
        filename = f"{uuid.uuid4()}_{client_name}"
        file_path = os.path.join(UPLOAD_FOLDER, filename)

        written = 0
        with open(file_path, "wb") as f:
            while True:
                chunk = await file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                written += len(chunk)
                if written > MAX_FILE_SIZE:
                    raise HTTPException(
                        status_code=413,
                        detail="File too large. Maximum size is 500MB."
                    )
                f.write(chunk)

        logger.info(f"ONNX file uploaded: {filename}")

        return FileUploadResponse(
            success=True,
            file_id=filename,
            file_path=file_path,
            message="ONNX file uploaded successfully"
        )

    except HTTPException:
        # Don't leave a truncated file behind after a rejected (e.g. 413) upload.
        _discard_partial_upload(file_path)
        raise
    except Exception as e:
        _discard_partial_upload(file_path)
        logger.exception(f"Error uploading ONNX file: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
def _is_direct_upload_child(path: str) -> bool:
    """Return True if *path* resolves to an entry directly inside UPLOAD_FOLDER."""
    base = os.path.realpath(UPLOAD_FOLDER)
    # realpath collapses "..", follows symlinks and honours absolute components,
    # so any escape from the upload folder shows up as a different parent dir.
    return os.path.dirname(os.path.realpath(path)) == base


# ONNX processing
@app.post("/api/onnx/process")
async def process_onnx(
    request: ONNXProcessRequest,
    background_tasks: BackgroundTasks,
    api_key: str = Depends(verify_api_key)
):
    """Queue an optimization task for a previously uploaded model.

    Returns:
        ``{"success": True, "task_id": ..., "message": ...}``; poll
        /api/onnx/tasks/{task_id}/status for progress.

    Raises:
        HTTPException: 404 if ``file_id`` does not name an uploaded file,
            500 for any other failure.
    """
    try:
        # file_id comes straight from the client. Values like "../../etc/passwd"
        # or an absolute path would make os.path.join escape UPLOAD_FOLDER, so
        # only regular files directly inside the upload folder are accepted.
        file_path = os.path.join(UPLOAD_FOLDER, request.file_id)
        if not _is_direct_upload_child(file_path) or not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="ONNX file not found")

        task_id = worker.submit(
            parameters=request.dict(),
            input_paths={"file_path": file_path},
        )
        # The task runs after the response has been sent.
        background_tasks.add_task(worker.run_task, task_id)

        logger.info(f"ONNX processing task submitted: {task_id}")

        return {
            "success": True,
            "task_id": task_id,
            "message": "ONNX optimization task submitted successfully"
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Error processing ONNX: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# Task status lookup
@app.get("/api/onnx/tasks/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return the status snapshot of a single ONNX task; 404 if it is unknown."""
    record = worker.get_task(task_id)
    if not record:
        raise HTTPException(status_code=404, detail="Task not found")

    # Plain fields are copied through unchanged; the two timestamps are
    # serialized with .isoformat().
    passthrough = {key: record[key] for key in ("status", "progress", "message", "result")}
    return TaskStatusResponse(
        task_id=task_id,
        **passthrough,
        created_at=record["created_at"].isoformat(),
        updated_at=record["updated_at"].isoformat(),
    )
|
|
|
|
# Task result download
@app.get("/api/onnx/tasks/{task_id}/result")
async def get_task_result(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return the result of a completed ONNX task.

    If the result dict carries a ``file_path``, that file is sent back as
    ``optimized_<task_id>.onnx``; otherwise the result itself is returned.

    Raises:
        HTTPException: 404 if the task or its result file is missing,
            400 if the task has not completed yet.
    """
    task = worker.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="Task not completed yet")

    result = task["result"]
    # result may be None (it is Optional in TaskStatusResponse); a membership
    # test on None would raise TypeError and surface as a 500.
    if result and "file_path" in result:
        result_path = result["file_path"]
        # Check up front so a removed output file yields a clean 404 rather
        # than a failure while the response is being streamed.
        if not os.path.isfile(result_path):
            raise HTTPException(status_code=404, detail="Result file not found")
        return FileResponse(
            result_path,
            media_type="application/octet-stream",
            filename=f"optimized_{task_id}.onnx"
        )

    return result
|
|
|
|
# Task listing
@app.get("/api/onnx/tasks")
async def list_tasks(api_key: str = Depends(verify_api_key)):
    """List every ONNX task known to the worker with status, progress and timestamps."""
    summaries = []
    for entry in worker.list_tasks():
        summary = {
            "task_id": entry["task_id"],
            "status": entry["status"],
            "progress": entry["progress"],
            "created_at": entry["created_at"].isoformat(),
            "updated_at": entry["updated_at"].isoformat(),
        }
        summaries.append(summary)
    return {"tasks": summaries}
|
|
|
|
# Task cancellation
@app.delete("/api/onnx/tasks/{task_id}")
async def cancel_task(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Cancel an ONNX task.

    The worker signals an unknown task with KeyError (mapped to 404) and a
    task that cannot be cancelled with ValueError (mapped to 400).
    """
    try:
        worker.cancel_task(task_id)
    except (KeyError, ValueError) as exc:
        if isinstance(exc, KeyError):
            status, detail = 404, "Task not found"
        else:
            status, detail = 400, "Task cannot be cancelled"
        raise HTTPException(status_code=status, detail=detail)

    logger.info(f"ONNX task cancelled: {task_id}")
    return {"success": True, "message": "ONNX task cancelled successfully"}
|
|
|
|
if __name__ == "__main__":
    # Direct-run entry point. ONNX_HOST / ONNX_PORT override the bind address;
    # the defaults are the previously hard-coded values.
    import uvicorn

    uvicorn.run(
        app,
        host=os.getenv("ONNX_HOST", "0.0.0.0"),
        port=int(os.getenv("ONNX_PORT", "5001")),
    )
|