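"""
BIE analysis service.

FastAPI application that accepts BIE analysis jobs, runs them in the background
through a shared ``WorkerFacade``, and exposes endpoints to poll, list, download
and cancel those jobs.
"""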
|
|
|
|
import logging
import os
import secrets
from datetime import datetime
from typing import Optional, Dict, Any

from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field

from core import process_bie_core
from facade import WorkerFacade
|
|
|
|
# Logging setup: this module's named logger, plus INFO-level configuration of the
# root logger (basicConfig is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
|
|
|
|
# The FastAPI application; title/description/version populate the generated OpenAPI docs.
app = FastAPI(title="BIE Analysis Service", description="BIE???????????", version="1.0.0")
|
|
|
|
# CORS policy: browsers may call this service only from http://localhost:4000,
# with credentials, using the HTTP verbs the routes in this module expose.
_CORS_SETTINGS = {
    "allow_origins": ["http://localhost:4000"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "DELETE"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
|
|
|
|
# Runtime configuration, overridable through environment variables.
UPLOAD_FOLDER = os.getenv("BIE_UPLOAD_FOLDER", "/tmp/bie_uploads")
OUTPUT_FOLDER = os.getenv("BIE_OUTPUT_FOLDER", "/tmp/bie_outputs")

# The fallback key is kept only for backward compatibility. It is published in
# the source code, so a deployment that relies on it is effectively unauthenticated.
_DEFAULT_API_KEY = "bie-secret-key"
API_KEY = os.getenv("BIE_API_KEY", _DEFAULT_API_KEY)
if not API_KEY or API_KEY == _DEFAULT_API_KEY:
    logger.warning(
        "BIE_API_KEY is unset, empty or equal to the built-in default; "
        "set it to a private value before exposing this service."
    )

# Create the working directories up front so request handlers can rely on them.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
|
|
|
|
# Single shared worker that runs process_bie_core for submitted tasks.
# NOTE(review): the "analysis_" prefix / ".bie" extension presumably define the
# output file naming inside WorkerFacade -- confirm in facade.py.
worker = WorkerFacade(
    step_name="bie",
    process_core=process_bie_core,
    logger=logger,
    output_folder=OUTPUT_FOLDER,
    output_prefix="analysis_", output_extension=".bie",
)
|
|
|
|
# Pydantic models for request and response payloads.
class HealthResponse(BaseModel):
    """Response body of ``GET /health``."""

    service: str  # service name, e.g. "bie-analysis"
    status: str  # overall status string, e.g. "healthy"
    timestamp: str  # ISO-8601 timestamp string
    active_tasks: int  # number of tasks the worker reports as active
|
|
|
|
class BIEProcessRequest(BaseModel):
    """Request body of ``POST /api/bie/process``.

    The string fields carry ``max_length`` in addition to their anchored regexes.
    Pydantic v1 applies ``regex=`` with ``re.match``, and ``$`` also matches just
    before a trailing newline. Without the length cap, values with a trailing
    newline (e.g. a version of four hex digits followed by a line break) would
    pass validation.
    """

    # File name inside UPLOAD_FOLDER; must be non-empty.
    onnx_file_id: str = Field(..., min_length=1)
    # Model identifier in the inclusive range 1..65535.
    model_id: int = Field(..., ge=1, le=65535)
    # Exactly four hexadecimal digits, e.g. "0001".
    # NOTE: ``regex=`` is Pydantic v1 syntax (renamed ``pattern=`` in Pydantic v2).
    version: str = Field(..., regex=r'^[0-9a-fA-F]{4}$', max_length=4)
    # Platform identifier: one of the listed three-digit values.
    platform: str = Field(..., regex=r'^(520|720|530|630|730)$', max_length=3)
    # Path of the input data directory on the server; must be non-empty.
    data_dir: str = Field(..., min_length=1)
|
|
|
|
class TaskStatusResponse(BaseModel):
    """Response body of ``GET /api/bie/tasks/{task_id}/status``."""

    task_id: str
    status: str  # task state as reported by the worker, e.g. "completed"
    progress: float  # NOTE(review): scale (0-1 vs 0-100) is set by WorkerFacade -- confirm
    message: str  # status message reported by the worker
    result: Optional[Dict[str, Any]] = None  # worker result; may be None
    created_at: str  # ISO-8601 timestamp string
    updated_at: str  # ISO-8601 timestamp string
|
|
|
|
# Authentication dependency shared by every /api/bie/* route.
def verify_api_key(
    x_api_key: Optional[str] = Header(None),
    legacy_api_key: Optional[str] = Query(None, alias="x_api_key"),
):
    """Validate the caller's API key against ``API_KEY``.

    The key is read from the ``X-API-Key`` request header. Without ``Header``,
    FastAPI treats a plain ``x_api_key`` parameter as a *query* parameter and
    ignores the header. That older ``?x_api_key=...`` form is still accepted
    (via ``legacy_api_key``) so existing clients keep working.

    Args:
        x_api_key: Value of the ``X-API-Key`` header, if sent.
        legacy_api_key: Value of the ``x_api_key`` query parameter, if sent.

    Returns:
        True when a non-empty key matching ``API_KEY`` was supplied.

    Raises:
        HTTPException: 401 if no non-empty key was supplied or it does not match.
    """
    # Prefer the header and fall back to the query parameter. The isinstance()
    # check also ignores FastAPI's marker defaults when the function is called directly.
    supplied = next(
        (key for key in (x_api_key, legacy_api_key) if isinstance(key, str) and key),
        None,
    )
    # compare_digest runs in constant time, so response timing does not reveal
    # how much of the key matched; bytes avoid its ASCII-only restriction on str.
    if supplied is None or not secrets.compare_digest(
        supplied.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    return True
|
|
|
|
# Liveness probe; unlike the /api/bie/* routes it requires no API key.
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report that the service is up, with the current time and worker load."""
    running = worker.active_tasks_count()
    payload = {
        "service": "bie-analysis",
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "active_tasks": running,
    }
    return HealthResponse(**payload)
|
|
|
|
# Submit a new BIE analysis job.
def _resolve_upload_path(file_id: str) -> Optional[str]:
    """Return the path of *file_id* inside UPLOAD_FOLDER, or None if unsafe or missing.

    ``os.path.join`` discards UPLOAD_FOLDER when *file_id* is absolute and follows
    ``..`` segments. The joined path is therefore resolved and must still lie inside
    the (resolved) upload folder and be a regular file before it is trusted. The
    unresolved joined path is returned so valid requests pass the same path to
    the worker as before.
    """
    candidate = os.path.join(UPLOAD_FOLDER, file_id)
    try:
        upload_root = os.path.realpath(UPLOAD_FOLDER)
        resolved = os.path.realpath(candidate)
        inside = os.path.commonpath([upload_root, resolved]) == upload_root
    except (OSError, ValueError):
        # e.g. embedded NUL bytes, or paths on different drives on Windows.
        return None
    if inside and os.path.isfile(resolved):
        return candidate
    return None


@app.post("/api/bie/process")
async def process_bie(
    request: BIEProcessRequest,
    background_tasks: BackgroundTasks,
    api_key: str = Depends(verify_api_key)
):
    """Validate a BIE analysis request and queue it on the worker.

    Returns:
        ``{"success": True, "task_id": ..., "message": ...}`` once the task is queued.

    Raises:
        HTTPException: 404 if the ONNX file is not a regular file inside
            UPLOAD_FOLDER, or ``data_dir`` is not an existing directory;
            500 for any unexpected error while submitting.
    """
    try:
        onnx_file_path = _resolve_upload_path(request.onnx_file_id)
        if onnx_file_path is None:
            raise HTTPException(status_code=404, detail="ONNX file not found")

        # NOTE(review): data_dir is an arbitrary server path chosen by the client;
        # confirm whether it should also be restricted to a known root.
        if not os.path.isdir(request.data_dir):
            raise HTTPException(status_code=404, detail="Data directory not found")

        task_id = worker.submit(
            parameters=request.dict(),
            input_paths={"onnx_file_path": onnx_file_path, "data_dir": request.data_dir},
        )
        # The task runs after the response has been sent.
        background_tasks.add_task(worker.run_task, task_id)

        logger.info(f"BIE analysis task submitted: {task_id}")

        return {
            "success": True,
            "task_id": task_id,
            "message": "BIE analysis task submitted successfully"
        }

    except HTTPException:
        raise
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Error processing BIE: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
# Poll the state of a single task.
@app.get("/api/bie/tasks/{task_id}/status", response_model=TaskStatusResponse)
async def get_task_status(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return status, progress, message, result and timestamps for *task_id*.

    Raises:
        HTTPException: 404 when the worker does not know the task.
    """
    record = worker.get_task(task_id)
    if not record:
        raise HTTPException(status_code=404, detail="Task not found")

    fields = {name: record[name] for name in ("status", "progress", "message", "result")}
    # Timestamps are stored as datetime objects and exposed as ISO-8601 strings.
    timestamps = {name: record[name].isoformat() for name in ("created_at", "updated_at")}
    return TaskStatusResponse(task_id=task_id, **fields, **timestamps)
|
|
|
|
# Download (or return) the output of a finished task.
@app.get("/api/bie/tasks/{task_id}/result")
async def get_task_result(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Return the result of a completed BIE task.

    If the task result contains ``file_path``, that file is streamed back as
    ``analysis_<task_id>.bie``. Otherwise the result itself is returned as JSON
    (``null`` when the worker stored no result).

    Raises:
        HTTPException: 404 if the task is unknown or its result file is missing
            on disk; 400 if the task has not completed yet.
    """
    task = worker.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")

    if task["status"] != "completed":
        raise HTTPException(status_code=400, detail="Task not completed yet")

    result = task["result"]
    # The result may be None; `"file_path" in None` would raise TypeError (a 500).
    if result is not None and "file_path" in result:
        file_path = result["file_path"]
        # Starlette's FileResponse checks the file only while sending and then fails
        # with a 500; report a missing artefact explicitly instead.
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail="Result file not found")
        return FileResponse(
            file_path,
            media_type="application/octet-stream",
            filename=f"analysis_{task_id}.bie"
        )
    return result
|
|
|
|
# Overview of every task the worker knows about.
def _summarize_task(entry):
    """Project a worker task record onto the fields exposed by ``list_tasks``."""
    return {
        "task_id": entry["task_id"],
        "status": entry["status"],
        "progress": entry["progress"],
        "created_at": entry["created_at"].isoformat(),
        "updated_at": entry["updated_at"].isoformat(),
    }


@app.get("/api/bie/tasks")
async def list_tasks(api_key: str = Depends(verify_api_key)):
    """Summarise all BIE tasks: id, status, progress and ISO-8601 timestamps."""
    summaries = [_summarize_task(entry) for entry in worker.list_tasks()]
    return {"tasks": summaries}
|
|
|
|
# Ask the worker to cancel a task.
@app.delete("/api/bie/tasks/{task_id}")
async def cancel_task(
    task_id: str,
    api_key: str = Depends(verify_api_key)
):
    """Cancel *task_id* through the worker.

    Raises:
        HTTPException: 404 if the worker raises KeyError (unknown task),
            400 if it raises ValueError (task cannot be cancelled).
    """
    try:
        worker.cancel_task(task_id)
    except (KeyError, ValueError) as exc:
        # KeyError takes precedence, matching a KeyError-first handler order.
        if isinstance(exc, KeyError):
            raise HTTPException(status_code=404, detail="Task not found")
        raise HTTPException(status_code=400, detail="Task cannot be cancelled")

    logger.info(f"BIE task cancelled: {task_id}")
    return {"success": True, "message": "BIE task cancelled successfully"}
|
|
|
|
if __name__ == "__main__":
    import uvicorn

    # Development entry point. Host and port can be overridden like the other
    # BIE_* settings; the defaults match the previous hard-coded values.
    uvicorn.run(
        app,
        host=os.getenv("BIE_HOST", "0.0.0.0"),
        port=int(os.getenv("BIE_PORT", "5002")),
    )
|