feat(local-tool): 推論功能完整搬入 — flash 模組 + workspace 推論介面
## 後端(Phase 1) 新增 flash 模組(從 edge-ai-platform 搬入): - server/internal/flash/service.go:StartFlash + 模型相容性檢查 + 晶片 NEF 解析 - server/internal/flash/progress.go:Flash 進度追蹤器 - server/internal/api/ws/flash_ws.go:WebSocket 推送 flash 進度 - device_handler.go:新增 FlashDevice method + flashSvc 欄位 - router.go:新增 POST /api/devices/:id/flash + WS /ws/devices/:id/flash-progress - main.go:初始化 flash.NewService 並傳入 router 推論/攝影機/MJPEG/inference WebSocket 之前 M1 已搬好,不需改動。 Python bridge (kneron_bridge.py) 與 edge-ai-platform 完全相同,不需改動。 ## 前端 store + hooks(Phase 2) - stores/flash-store.ts(新):Zustand store — startFlash / updateProgress / retryFlash / reset - hooks/use-flash-progress.ts(新):WebSocket hook 接收 flash 進度 inference-store / camera-store / inference types / use-inference-stream / use-websocket 之前 M1 已搬好,不需改動。 ## 前端 UI 元件(Phase 3) - components/devices/flash-dialog.tsx(新):模型載入對話框 + 硬體相容性檢查 - components/devices/flash-progress.tsx(新):Flash 進度條 + 錯誤重試 camera-inference-view / camera-feed / camera-overlay / source-selector / inference-panel / performance-metrics / classification-result / confidence-slider / video-progress / batch-image-thumbnails 之前 M1 已搬好。 ## 前端頁面整合(Phase 4) - workspace/page.tsx:繁中硬編碼、顯示已載入模型名稱 - workspace/[deviceId]/workspace-client.tsx:加入 FlashDialog 按鈕 + 繁中硬編碼 - devices/[id]/device-detail-client.tsx:加入 FlashDialog + 「進入工作區」按鈕(模型已載入才顯示) - device-card.tsx:已連線 + 模型已載入時顯示「工作區」快捷按鈕 ## 使用者操作流程 裝置列表 → 連線 → 管理 → 載入模型 → 進入工作區 → 選攝影機/圖片/影片 → 開始推論 → 看 bounding box / FPS / latency 或:裝置列表 → 工作區(已有模型)→ 直接推論 ## 不搬的東西 - cluster/* 全部不搬(已砍 cluster 功能) - relay / tunnel 相關不搬 Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
819885c85d
commit
44711753ae
@ -7,6 +7,7 @@ import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { DeviceStatusBadge } from '@/components/devices/device-status';
|
||||
import { FlashDialog } from '@/components/devices/flash-dialog';
|
||||
import { DeviceHealthCard } from '@/components/devices/device-health-card';
|
||||
import { DeviceConnectionLog } from '@/components/devices/device-connection-log';
|
||||
import { DeviceSettingsCard } from '@/components/devices/device-settings-card';
|
||||
@ -57,9 +58,12 @@ export default function DeviceDetailClient() {
|
||||
<div className="flex gap-2">
|
||||
{isConnected ? (
|
||||
<>
|
||||
<FlashDialog deviceId={id} />
|
||||
{selectedDevice.flashedModel && (
|
||||
<Link href={`/workspace/${id}`}>
|
||||
<Button variant="outline" data-tour-id="open-workspace-btn">{t('devices.detail.openWorkspace')}</Button>
|
||||
</Link>
|
||||
)}
|
||||
<Button variant="ghost" onClick={() => disconnectDevice(id)}>
|
||||
{t('common.disconnect')}
|
||||
</Button>
|
||||
|
||||
@ -5,16 +5,15 @@ import Link from 'next/link';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { CameraInferenceView } from '@/components/camera/camera-inference-view';
|
||||
import { InferencePanel } from '@/components/inference/inference-panel';
|
||||
import { FlashDialog } from '@/components/devices/flash-dialog';
|
||||
import { useDeviceStore } from '@/stores/device-store';
|
||||
import { useInferenceStore } from '@/stores/inference-store';
|
||||
import { useInferenceStream } from '@/hooks/use-inference-stream';
|
||||
import { useCameraStore } from '@/stores/camera-store';
|
||||
import { useResolvedParams } from '@/hooks/use-resolved-params';
|
||||
import { api } from '@/lib/api';
|
||||
import { useTranslation } from '@/lib/i18n';
|
||||
|
||||
export default function WorkspaceClient() {
|
||||
const { t } = useTranslation();
|
||||
const { deviceId } = useResolvedParams();
|
||||
const { selectedDevice, fetchDevice } = useDeviceStore();
|
||||
const { isRunning, setRunning, reset } = useInferenceStore();
|
||||
@ -60,27 +59,30 @@ export default function WorkspaceClient() {
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center gap-2">
|
||||
<Link href={`/devices/${deviceId}`}>
|
||||
<Button variant="ghost" size="sm">{'← ' + t('common.back')}</Button>
|
||||
<Button variant="ghost" size="sm">{'← 返回'}</Button>
|
||||
</Link>
|
||||
<h1 className="text-xl font-bold">
|
||||
{t('inference.workspace') + ':'} {selectedDevice?.name || deviceId}
|
||||
{'工作區:'} {selectedDevice?.name || deviceId}
|
||||
</h1>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<FlashDialog deviceId={deviceId} />
|
||||
{/* Only show manual inference controls in camera mode */}
|
||||
{!isMediaMode && (
|
||||
<div className="flex gap-2">
|
||||
<>
|
||||
{isRunning ? (
|
||||
<Button variant="destructive" onClick={handleStopInference}>
|
||||
{t('inference.stopInference')}
|
||||
停止推論
|
||||
</Button>
|
||||
) : (
|
||||
<Button onClick={handleStartInference} disabled={!isStreaming} data-tour-id="start-inference-btn">
|
||||
{t('inference.startInference')}
|
||||
開始推論
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex gap-6">
|
||||
<div className="flex-1">
|
||||
|
||||
@ -1,15 +1,12 @@
|
||||
'use client';
|
||||
|
||||
// TODO: M2 redesign workspace landing (device picker, empty state)
|
||||
import Link from 'next/link';
|
||||
import { useEffect } from 'react';
|
||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { useDeviceStore } from '@/stores/device-store';
|
||||
import { useTranslation } from '@/lib/i18n';
|
||||
|
||||
export default function WorkspaceIndexPage() {
|
||||
const { t } = useTranslation();
|
||||
const { devices, fetchDevices } = useDeviceStore();
|
||||
|
||||
useEffect(() => {
|
||||
@ -23,21 +20,21 @@ export default function WorkspaceIndexPage() {
|
||||
return (
|
||||
<div className="space-y-6">
|
||||
<div>
|
||||
<h1 className="text-2xl font-bold">{t('workspace.title')}</h1>
|
||||
<p className="text-muted-foreground">{t('workspace.subtitle')}</p>
|
||||
<h1 className="text-2xl font-bold">工作區</h1>
|
||||
<p className="text-muted-foreground">選擇已連線的裝置開始推論</p>
|
||||
</div>
|
||||
|
||||
{connected.length === 0 ? (
|
||||
<Card>
|
||||
<CardHeader>
|
||||
<CardTitle className="text-base">{t('workspace.noConnectedDevice')}</CardTitle>
|
||||
<CardTitle className="text-base">沒有已連線的裝置</CardTitle>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-sm text-muted-foreground mb-4">
|
||||
{t('workspace.noConnectedDeviceDesc')}
|
||||
請先前往裝置頁面連接裝置,再回到工作區開始推論。
|
||||
</p>
|
||||
<Link href="/devices">
|
||||
<Button>{t('workspace.goToDevices')}</Button>
|
||||
<Button>前往裝置管理</Button>
|
||||
</Link>
|
||||
</CardContent>
|
||||
</Card>
|
||||
@ -51,6 +48,9 @@ export default function WorkspaceIndexPage() {
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
<p className="text-sm text-muted-foreground">{d.type}</p>
|
||||
{d.flashedModel && (
|
||||
<p className="text-sm font-medium mt-1">{d.flashedModel}</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
</Link>
|
||||
|
||||
@ -63,6 +63,13 @@ export function DeviceCard({ device, isFirstCard }: DeviceCardProps) {
|
||||
{t('common.manage')}
|
||||
</Button>
|
||||
</Link>
|
||||
{device.flashedModel && (
|
||||
<Link href={`/workspace/${device.id}`}>
|
||||
<Button size="sm" variant="default" disabled={isBusy}>
|
||||
工作區
|
||||
</Button>
|
||||
</Link>
|
||||
)}
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
|
||||
140
local-tool/frontend/src/components/devices/flash-dialog.tsx
Normal file
140
local-tool/frontend/src/components/devices/flash-dialog.tsx
Normal file
@ -0,0 +1,140 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useMemo, useState } from 'react';
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from '@/components/ui/dialog';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from '@/components/ui/select';
|
||||
import { TriangleAlertIcon } from 'lucide-react';
|
||||
import { FlashProgress } from './flash-progress';
|
||||
import { useModelStore } from '@/stores/model-store';
|
||||
import { useFlashStore } from '@/stores/flash-store';
|
||||
import { useDeviceStore } from '@/stores/device-store';
|
||||
import { useFlashProgress } from '@/hooks/use-flash-progress';
|
||||
import { isModelCompatible, getHardwareType } from '@/lib/hardware-compat';
|
||||
|
||||
interface FlashDialogProps {
|
||||
deviceId: string;
|
||||
}
|
||||
|
||||
export function FlashDialog({ deviceId }: FlashDialogProps) {
|
||||
const [open, setOpen] = useState(false);
|
||||
const [selectedModelId, setSelectedModelId] = useState('');
|
||||
const { models, fetchModels } = useModelStore();
|
||||
const { isFlashing, progress, error, startFlash, retryFlash, reset } = useFlashStore();
|
||||
const { fetchDevice, devices } = useDeviceStore();
|
||||
const { connectAndWait, disconnect } = useFlashProgress(deviceId);
|
||||
|
||||
const device = devices.find((d) => d.id === deviceId);
|
||||
const selectedModel = models.find((m) => m.id === selectedModelId);
|
||||
|
||||
const compatible = useMemo(() => {
|
||||
if (!selectedModel || !device) return true;
|
||||
return isModelCompatible(selectedModel.supportedHardware, device.type);
|
||||
}, [selectedModel, device]);
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
fetchModels();
|
||||
reset();
|
||||
setSelectedModelId('');
|
||||
} else {
|
||||
disconnect();
|
||||
}
|
||||
}, [open, fetchModels, reset, disconnect]);
|
||||
|
||||
const handleFlash = async () => {
|
||||
if (!selectedModelId) return;
|
||||
// 1. Create WebSocket and wait for it to open
|
||||
await connectAndWait();
|
||||
// 2. Then start flash (POST) — now WS is listening
|
||||
await startFlash(deviceId, selectedModelId);
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={(v) => {
|
||||
if (!v && isFlashing && !error) return;
|
||||
setOpen(v);
|
||||
}}>
|
||||
<DialogTrigger asChild>
|
||||
<Button data-tour-id="flash-model-btn">Flash Model</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>Flash Model to Device</DialogTitle>
|
||||
</DialogHeader>
|
||||
<div className="space-y-4">
|
||||
{!isFlashing && !progress && !error ? (
|
||||
<>
|
||||
<Select value={selectedModelId} onValueChange={setSelectedModelId}>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select a model" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{models.map((m) => (
|
||||
<SelectItem key={m.id} value={m.id}>
|
||||
{m.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
|
||||
{selectedModelId && !compatible && (
|
||||
<div className="rounded-md bg-yellow-50 p-3 border border-yellow-200">
|
||||
<div className="flex gap-2 items-start">
|
||||
<TriangleAlertIcon className="h-5 w-5 text-yellow-600 mt-0.5 shrink-0" />
|
||||
<div className="text-sm">
|
||||
<p className="font-medium text-yellow-800">Hardware Incompatible</p>
|
||||
<p className="text-yellow-700">
|
||||
This model may not be compatible with {device ? getHardwareType(device.type) : 'this device'}.
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Button
|
||||
onClick={handleFlash}
|
||||
disabled={!selectedModelId || !compatible}
|
||||
className="w-full"
|
||||
>
|
||||
{!selectedModelId
|
||||
? 'Select a model'
|
||||
: !compatible
|
||||
? 'Incompatible — Cannot Flash'
|
||||
: 'Start Flash'}
|
||||
</Button>
|
||||
</>
|
||||
) : (
|
||||
<FlashProgress progress={progress} error={error} onRetry={retryFlash} />
|
||||
)}
|
||||
{((progress && progress.percent >= 100) || error) && (
|
||||
<Button
|
||||
variant="outline"
|
||||
className="w-full"
|
||||
onClick={() => {
|
||||
if (!error) fetchDevice(deviceId);
|
||||
disconnect();
|
||||
reset();
|
||||
setOpen(false);
|
||||
}}
|
||||
>
|
||||
{error ? 'Close' : 'Done'}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@ -0,0 +1,60 @@
|
||||
'use client';
|
||||
|
||||
import { Progress } from '@/components/ui/progress';
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { XCircle } from 'lucide-react';
|
||||
import type { FlashProgress as FlashProgressType } from '@/types/device';
|
||||
|
||||
interface FlashProgressProps {
|
||||
progress: FlashProgressType | null;
|
||||
error?: string | null;
|
||||
onRetry?: () => void;
|
||||
}
|
||||
|
||||
export function FlashProgress({ progress, error, onRetry }: FlashProgressProps) {
|
||||
if (error) {
|
||||
return (
|
||||
<div className="space-y-3">
|
||||
<div className="rounded-md bg-red-50 p-4 border border-red-200">
|
||||
<div className="flex gap-2 items-start">
|
||||
<XCircle className="h-5 w-5 text-red-600 mt-0.5 shrink-0" />
|
||||
<div className="flex-1">
|
||||
<p className="font-medium text-red-800">Flash Failed</p>
|
||||
<p className="text-sm text-red-700 mt-1">{error}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{onRetry && (
|
||||
<Button onClick={onRetry} className="w-full" variant="outline">
|
||||
Retry
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!progress) {
|
||||
return (
|
||||
<div className="space-y-2 text-center">
|
||||
<div className="animate-pulse text-sm text-muted-foreground">
|
||||
Preparing flash...
|
||||
</div>
|
||||
<Progress value={0} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-3">
|
||||
<div className="flex items-center justify-between text-sm">
|
||||
<span className="font-medium">{progress.stage}</span>
|
||||
<span className="text-muted-foreground">{progress.percent}%</span>
|
||||
</div>
|
||||
<Progress value={progress.percent} />
|
||||
<p className="text-sm text-muted-foreground">{progress.message}</p>
|
||||
{progress.percent >= 100 && (
|
||||
<p className="text-sm font-medium text-green-600">Flash Complete!</p>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
68
local-tool/frontend/src/hooks/use-flash-progress.ts
Normal file
68
local-tool/frontend/src/hooks/use-flash-progress.ts
Normal file
@ -0,0 +1,68 @@
|
||||
'use client';
|
||||
|
||||
import { useEffect, useRef, useCallback } from 'react';
|
||||
import { createWebSocket } from '@/lib/ws';
|
||||
import { useFlashStore } from '@/stores/flash-store';
|
||||
import type { FlashProgress } from '@/types/device';
|
||||
|
||||
/**
|
||||
* Manages flash progress WebSocket.
|
||||
* Returns a `connectAndWait` callback that creates the WebSocket and
|
||||
* returns a promise that resolves once the WS is open.
|
||||
*/
|
||||
export function useFlashProgress(deviceId: string) {
|
||||
const updateProgress = useFlashStore((s) => s.updateProgress);
|
||||
const wsRef = useRef<ReturnType<typeof createWebSocket> | null>(null);
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
wsRef.current?.close();
|
||||
wsRef.current = null;
|
||||
};
|
||||
}, [deviceId]);
|
||||
|
||||
/**
|
||||
* Creates the WebSocket connection and returns a promise that resolves
|
||||
* once the connection is open. This is called imperatively (not via
|
||||
* useEffect) to avoid React render-cycle timing issues.
|
||||
*/
|
||||
const connectAndWait = useCallback(
|
||||
() =>
|
||||
new Promise<void>((resolve) => {
|
||||
// Close any existing connection
|
||||
wsRef.current?.close();
|
||||
|
||||
let resolved = false;
|
||||
const doResolve = () => {
|
||||
if (!resolved) {
|
||||
resolved = true;
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
|
||||
const ws = createWebSocket(
|
||||
`/ws/devices/${deviceId}/flash-progress`,
|
||||
(data) => {
|
||||
updateProgress(data as FlashProgress);
|
||||
},
|
||||
() => {
|
||||
doResolve();
|
||||
},
|
||||
);
|
||||
wsRef.current = ws;
|
||||
|
||||
// Safety timeout — don't block forever
|
||||
setTimeout(doResolve, 3000);
|
||||
}),
|
||||
[deviceId, updateProgress],
|
||||
);
|
||||
|
||||
/** Close the WebSocket connection. */
|
||||
const disconnect = useCallback(() => {
|
||||
wsRef.current?.close();
|
||||
wsRef.current = null;
|
||||
}, []);
|
||||
|
||||
return { connectAndWait, disconnect };
|
||||
}
|
||||
64
local-tool/frontend/src/stores/flash-store.ts
Normal file
64
local-tool/frontend/src/stores/flash-store.ts
Normal file
@ -0,0 +1,64 @@
|
||||
import { create } from 'zustand';
|
||||
import { api } from '@/lib/api';
|
||||
import type { FlashProgress } from '@/types/device';
|
||||
import { showApiError } from '@/lib/toast';
|
||||
import { useActivityStore } from './activity-store';
|
||||
|
||||
interface FlashState {
|
||||
isFlashing: boolean;
|
||||
progress: FlashProgress | null;
|
||||
error: string | null;
|
||||
lastFlashParams: { deviceId: string; modelId: string } | null;
|
||||
startFlash: (deviceId: string, modelId: string) => Promise<void>;
|
||||
updateProgress: (progress: FlashProgress) => void;
|
||||
setError: (error: string) => void;
|
||||
retryFlash: () => Promise<void>;
|
||||
reset: () => void;
|
||||
}
|
||||
|
||||
export const useFlashStore = create<FlashState>((set, get) => ({
|
||||
isFlashing: false,
|
||||
progress: null,
|
||||
error: null,
|
||||
lastFlashParams: null,
|
||||
|
||||
startFlash: async (deviceId, modelId) => {
|
||||
set({ isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } });
|
||||
const res = await api.post(`/devices/${deviceId}/flash`, { modelId });
|
||||
if (!res.success) {
|
||||
const errorMsg = res.error?.message || 'Flash failed';
|
||||
showApiError(res.error);
|
||||
set({ isFlashing: false, error: errorMsg });
|
||||
useActivityStore.getState().addActivity('flash_error', `Flash failed: ${errorMsg}`);
|
||||
} else {
|
||||
useActivityStore.getState().addActivity('flash_start', 'Flash started');
|
||||
}
|
||||
},
|
||||
|
||||
updateProgress: (progress) => {
|
||||
if (progress.error) {
|
||||
set({ isFlashing: false, error: progress.error });
|
||||
useActivityStore.getState().addActivity('flash_error', `Flash failed: ${progress.error}`);
|
||||
return;
|
||||
}
|
||||
set({ progress });
|
||||
if (progress.percent >= 100) {
|
||||
set({ isFlashing: false });
|
||||
useActivityStore.getState().addActivity('flash_complete', 'Flash completed');
|
||||
}
|
||||
},
|
||||
|
||||
setError: (error) => {
|
||||
set({ error, isFlashing: false });
|
||||
},
|
||||
|
||||
retryFlash: async () => {
|
||||
const { lastFlashParams } = get();
|
||||
if (!lastFlashParams) return;
|
||||
await get().startFlash(lastFlashParams.deviceId, lastFlashParams.modelId);
|
||||
},
|
||||
|
||||
reset: () => {
|
||||
set({ isFlashing: false, progress: null, error: null, lastFlashParams: null });
|
||||
},
|
||||
}));
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"visiona-local/server/internal/api/ws"
|
||||
"visiona-local/server/internal/device"
|
||||
"visiona-local/server/internal/driver"
|
||||
"visiona-local/server/internal/flash"
|
||||
"visiona-local/server/internal/inference"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@ -15,17 +16,20 @@ import (
|
||||
|
||||
type DeviceHandler struct {
|
||||
deviceMgr *device.Manager
|
||||
flashSvc *flash.Service
|
||||
inferenceSvc *inference.Service
|
||||
wsHub *ws.Hub
|
||||
}
|
||||
|
||||
func NewDeviceHandler(
|
||||
deviceMgr *device.Manager,
|
||||
flashSvc *flash.Service,
|
||||
inferenceSvc *inference.Service,
|
||||
wsHub *ws.Hub,
|
||||
) *DeviceHandler {
|
||||
return &DeviceHandler{
|
||||
deviceMgr: deviceMgr,
|
||||
flashSvc: flashSvc,
|
||||
inferenceSvc: inferenceSvc,
|
||||
wsHub: wsHub,
|
||||
}
|
||||
@ -107,6 +111,39 @@ func (h *DeviceHandler) DisconnectDevice(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"success": true})
|
||||
}
|
||||
|
||||
func (h *DeviceHandler) FlashDevice(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
var req struct {
|
||||
ModelID string `json:"modelId"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"error": gin.H{"code": "BAD_REQUEST", "message": "modelId is required"},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
taskID, progressCh, err := h.flashSvc.StartFlash(id, req.ModelID)
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{
|
||||
"success": false,
|
||||
"error": gin.H{"code": "FLASH_FAILED", "message": err.Error()},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Forward progress to WebSocket
|
||||
go func() {
|
||||
room := "flash:" + id
|
||||
for progress := range progressCh {
|
||||
h.wsHub.BroadcastToRoom(room, progress)
|
||||
}
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"success": true, "data": gin.H{"taskId": taskID}})
|
||||
}
|
||||
|
||||
func (h *DeviceHandler) StartInference(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
resultCh := make(chan *driver.InferenceResult, 10)
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"visiona-local/server/internal/api/ws"
|
||||
"visiona-local/server/internal/camera"
|
||||
"visiona-local/server/internal/device"
|
||||
"visiona-local/server/internal/flash"
|
||||
"visiona-local/server/internal/inference"
|
||||
"visiona-local/server/internal/model"
|
||||
"visiona-local/server/pkg/logger"
|
||||
@ -23,6 +24,7 @@ func NewRouter(
|
||||
modelStore *model.ModelStore,
|
||||
deviceMgr *device.Manager,
|
||||
cameraMgr *camera.Manager,
|
||||
flashSvc *flash.Service,
|
||||
inferenceSvc *inference.Service,
|
||||
wsHub *ws.Hub,
|
||||
staticFS http.FileSystem,
|
||||
@ -38,7 +40,7 @@ func NewRouter(
|
||||
|
||||
modelHandler := handlers.NewModelHandler(modelRepo)
|
||||
modelUploadHandler := handlers.NewModelUploadHandler(modelRepo, modelStore)
|
||||
deviceHandler := handlers.NewDeviceHandler(deviceMgr, inferenceSvc, wsHub)
|
||||
deviceHandler := handlers.NewDeviceHandler(deviceMgr, flashSvc, inferenceSvc, wsHub)
|
||||
cameraHandler := handlers.NewCameraHandler(cameraMgr, deviceMgr, inferenceSvc, wsHub)
|
||||
|
||||
api := r.Group("/api")
|
||||
@ -63,6 +65,7 @@ func NewRouter(
|
||||
api.GET("/devices/:id", deviceHandler.GetDevice)
|
||||
api.POST("/devices/:id/connect", deviceHandler.ConnectDevice)
|
||||
api.POST("/devices/:id/disconnect", deviceHandler.DisconnectDevice)
|
||||
api.POST("/devices/:id/flash", deviceHandler.FlashDevice)
|
||||
api.POST("/devices/:id/inference/start", deviceHandler.StartInference)
|
||||
api.POST("/devices/:id/inference/stop", deviceHandler.StopInference)
|
||||
|
||||
@ -83,6 +86,7 @@ func NewRouter(
|
||||
|
||||
// WebSocket
|
||||
r.GET("/ws/devices/events", ws.DeviceEventsHandler(wsHub, deviceMgr))
|
||||
r.GET("/ws/devices/:id/flash-progress", ws.FlashProgressHandler(wsHub))
|
||||
r.GET("/ws/devices/:id/inference", ws.InferenceHandler(wsHub, inferenceSvc))
|
||||
r.GET("/ws/server-logs", ws.ServerLogsHandler(wsHub, logBroadcaster))
|
||||
|
||||
|
||||
39
local-tool/server/internal/api/ws/flash_ws.go
Normal file
39
local-tool/server/internal/api/ws/flash_ws.go
Normal file
@ -0,0 +1,39 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
func FlashProgressHandler(hub *Hub) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
deviceID := c.Param("id")
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := &Client{Conn: conn, Send: make(chan []byte, 20)}
|
||||
room := "flash:" + deviceID
|
||||
sub := &Subscription{Client: client, Room: room}
|
||||
hub.RegisterSync(sub)
|
||||
defer hub.Unregister(sub)
|
||||
|
||||
// Read pump — drain incoming messages (ping/pong, close frames)
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for msg := range client.Send {
|
||||
if err := conn.WriteMessage(websocket.TextMessage, msg); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
51
local-tool/server/internal/flash/progress.go
Normal file
51
local-tool/server/internal/flash/progress.go
Normal file
@ -0,0 +1,51 @@
|
||||
package flash
|
||||
|
||||
import (
|
||||
"visiona-local/server/internal/driver"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type FlashTask struct {
|
||||
ID string
|
||||
DeviceID string
|
||||
ModelID string
|
||||
ProgressCh chan driver.FlashProgress
|
||||
Done bool
|
||||
}
|
||||
|
||||
type ProgressTracker struct {
|
||||
tasks map[string]*FlashTask
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewProgressTracker() *ProgressTracker {
|
||||
return &ProgressTracker{
|
||||
tasks: make(map[string]*FlashTask),
|
||||
}
|
||||
}
|
||||
|
||||
func (pt *ProgressTracker) Create(taskID, deviceID, modelID string) *FlashTask {
|
||||
pt.mu.Lock()
|
||||
defer pt.mu.Unlock()
|
||||
task := &FlashTask{
|
||||
ID: taskID,
|
||||
DeviceID: deviceID,
|
||||
ModelID: modelID,
|
||||
ProgressCh: make(chan driver.FlashProgress, 20),
|
||||
}
|
||||
pt.tasks[taskID] = task
|
||||
return task
|
||||
}
|
||||
|
||||
func (pt *ProgressTracker) Get(taskID string) (*FlashTask, bool) {
|
||||
pt.mu.RLock()
|
||||
defer pt.mu.RUnlock()
|
||||
t, ok := pt.tasks[taskID]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
func (pt *ProgressTracker) Remove(taskID string) {
|
||||
pt.mu.Lock()
|
||||
defer pt.mu.Unlock()
|
||||
delete(pt.tasks, taskID)
|
||||
}
|
||||
140
local-tool/server/internal/flash/service.go
Normal file
140
local-tool/server/internal/flash/service.go
Normal file
@ -0,0 +1,140 @@
|
||||
package flash
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"visiona-local/server/internal/device"
|
||||
"visiona-local/server/internal/driver"
|
||||
"visiona-local/server/internal/model"
|
||||
)
|
||||
|
||||
// isCompatible checks if any of the model's supported hardware types match
|
||||
// the device type. The match is case-insensitive and also checks if the
|
||||
// device type string contains the hardware name (e.g. "kneron_kl720" contains "KL720").
|
||||
func isCompatible(modelHardware []string, deviceType string) bool {
|
||||
dt := strings.ToUpper(deviceType)
|
||||
for _, hw := range modelHardware {
|
||||
if strings.ToUpper(hw) == dt || strings.Contains(dt, strings.ToUpper(hw)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveModelPath checks if a chip-specific NEF file exists for the given
|
||||
// model. For cross-platform models whose filePath points to a KL520 NEF,
|
||||
// this tries to find the equivalent KL720 NEF (and vice versa).
|
||||
//
|
||||
// Resolution: data/nef/kl520/kl520_20001_... → data/nef/kl720/kl720_20001_...
|
||||
func resolveModelPath(filePath string, deviceType string) string {
|
||||
if filePath == "" {
|
||||
return filePath
|
||||
}
|
||||
|
||||
targetChip := ""
|
||||
if strings.Contains(strings.ToLower(deviceType), "kl720") {
|
||||
targetChip = "kl720"
|
||||
} else if strings.Contains(strings.ToLower(deviceType), "kl520") {
|
||||
targetChip = "kl520"
|
||||
}
|
||||
if targetChip == "" {
|
||||
return filePath
|
||||
}
|
||||
|
||||
// Already points to the target chip directory — use as-is.
|
||||
if strings.Contains(filePath, "/"+targetChip+"/") {
|
||||
return filePath
|
||||
}
|
||||
|
||||
// Try to swap chip prefix in both directory and filename.
|
||||
dir := filepath.Dir(filePath)
|
||||
base := filepath.Base(filePath)
|
||||
|
||||
sourceChip := ""
|
||||
if strings.Contains(dir, "kl520") {
|
||||
sourceChip = "kl520"
|
||||
} else if strings.Contains(dir, "kl720") {
|
||||
sourceChip = "kl720"
|
||||
}
|
||||
|
||||
if sourceChip != "" && sourceChip != targetChip {
|
||||
newDir := strings.Replace(dir, sourceChip, targetChip, 1)
|
||||
newBase := strings.Replace(base, sourceChip, targetChip, 1)
|
||||
candidate := filepath.Join(newDir, newBase)
|
||||
if _, err := os.Stat(candidate); err == nil {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
|
||||
return filePath
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
deviceMgr *device.Manager
|
||||
modelRepo *model.Repository
|
||||
tracker *ProgressTracker
|
||||
}
|
||||
|
||||
func NewService(deviceMgr *device.Manager, modelRepo *model.Repository) *Service {
|
||||
return &Service{
|
||||
deviceMgr: deviceMgr,
|
||||
modelRepo: modelRepo,
|
||||
tracker: NewProgressTracker(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.FlashProgress, error) {
|
||||
session, err := s.deviceMgr.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("device not found: %w", err)
|
||||
}
|
||||
if !session.Driver.IsConnected() {
|
||||
return "", nil, fmt.Errorf("device not connected")
|
||||
}
|
||||
|
||||
m, err := s.modelRepo.GetByID(modelID)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
// Check hardware compatibility
|
||||
deviceInfo := session.Driver.Info()
|
||||
if !isCompatible(m.SupportedHardware, deviceInfo.Type) {
|
||||
return "", nil, fmt.Errorf("model not compatible with device type %s", deviceInfo.Type)
|
||||
}
|
||||
|
||||
// Use the model's .nef file path if available, otherwise fall back to modelID.
|
||||
modelPath := m.FilePath
|
||||
if modelPath == "" {
|
||||
modelPath = modelID
|
||||
}
|
||||
|
||||
// Resolve chip-specific NEF (e.g. KL520 path → KL720 equivalent).
|
||||
modelPath = resolveModelPath(modelPath, deviceInfo.Type)
|
||||
|
||||
taskID := fmt.Sprintf("flash-%s-%s", deviceID, modelID)
|
||||
task := s.tracker.Create(taskID, deviceID, modelID)
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
task.Done = true
|
||||
close(task.ProgressCh)
|
||||
}()
|
||||
// Brief pause to allow the WebSocket client to connect before
|
||||
// progress messages start flowing.
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := session.Driver.Flash(modelPath, task.ProgressCh); err != nil {
|
||||
task.ProgressCh <- driver.FlashProgress{
|
||||
Percent: -1,
|
||||
Stage: "error",
|
||||
Error: err.Error(),
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return taskID, task.ProgressCh, nil
|
||||
}
|
||||
@ -23,6 +23,7 @@ import (
|
||||
"visiona-local/server/internal/config"
|
||||
"visiona-local/server/internal/deps"
|
||||
"visiona-local/server/internal/device"
|
||||
"visiona-local/server/internal/flash"
|
||||
"visiona-local/server/internal/inference"
|
||||
"visiona-local/server/internal/model"
|
||||
pkglogger "visiona-local/server/pkg/logger"
|
||||
@ -138,6 +139,7 @@ func main() {
|
||||
cameraMgr := camera.NewManager(cfg.MockCamera)
|
||||
|
||||
// Initialize services
|
||||
flashSvc := flash.NewService(deviceMgr, modelRepo)
|
||||
inferenceSvc := inference.NewService(deviceMgr)
|
||||
|
||||
// Determine static file system for embedded frontend
|
||||
@ -183,7 +185,7 @@ func main() {
|
||||
systemHandler := handlers.NewSystemHandler(Version, BuildTime, pythonBinForSystem, restartFn)
|
||||
|
||||
// Create router
|
||||
r := api.NewRouter(modelRepo, modelStore, deviceMgr, cameraMgr, inferenceSvc, wsHub, staticFS, logBroadcaster, systemHandler)
|
||||
r := api.NewRouter(modelRepo, modelStore, deviceMgr, cameraMgr, flashSvc, inferenceSvc, wsHub, staticFS, logBroadcaster, systemHandler)
|
||||
|
||||
// Configure HTTP server (bind to localhost only)
|
||||
addr := cfg.Addr()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user