diff --git a/local-tool/frontend/src/hooks/use-flash-progress.ts b/local-tool/frontend/src/hooks/use-flash-progress.ts index 2c24270..f6da4d7 100644 --- a/local-tool/frontend/src/hooks/use-flash-progress.ts +++ b/local-tool/frontend/src/hooks/use-flash-progress.ts @@ -44,7 +44,7 @@ export function useFlashProgress(deviceId: string) { const ws = createWebSocket( `/ws/devices/${deviceId}/flash-progress`, (data) => { - updateProgress(data as FlashProgress); + updateProgress(deviceId, data as FlashProgress); }, () => { doResolve(); diff --git a/local-tool/frontend/src/stores/flash-store.ts b/local-tool/frontend/src/stores/flash-store.ts index 227300d..96e71aa 100644 --- a/local-tool/frontend/src/stores/flash-store.ts +++ b/local-tool/frontend/src/stores/flash-store.ts @@ -4,26 +4,31 @@ import type { FlashProgress } from '@/types/device'; import { showApiError } from '@/lib/toast'; import { useActivityStore } from './activity-store'; +// M4 fix: flash state 以 deviceId 區分,避免多裝置 UI 互相覆蓋。 +// activeDeviceId 記錄當前正在 flash 的裝置,updateProgress 會比對確保不混。 + interface FlashState { + activeDeviceId: string | null; isFlashing: boolean; progress: FlashProgress | null; error: string | null; lastFlashParams: { deviceId: string; modelId: string } | null; startFlash: (deviceId: string, modelId: string) => Promise; - updateProgress: (progress: FlashProgress) => void; + updateProgress: (deviceId: string, progress: FlashProgress) => void; setError: (error: string) => void; retryFlash: () => Promise; reset: () => void; } export const useFlashStore = create((set, get) => ({ + activeDeviceId: null, isFlashing: false, progress: null, error: null, lastFlashParams: null, startFlash: async (deviceId, modelId) => { - set({ isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } }); + set({ activeDeviceId: deviceId, isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } }); const res = await api.post(`/devices/${deviceId}/flash`, { modelId }); if (!res.success) { const errorMsg = res.error?.message || 'Flash failed'; @@ -35,7 +40,12 @@ export const useFlashStore = create((set, get) => ({ } }, - updateProgress: (progress) => { + updateProgress: (deviceId, progress) => { + // M4 fix: 只更新對應裝置的 progress,忽略其他裝置的 WebSocket 訊息 + if (get().activeDeviceId !== null && get().activeDeviceId !== deviceId) { + return; + } + if (progress.error) { set({ isFlashing: false, error: progress.error }); useActivityStore.getState().addActivity('flash_error', `Flash failed: ${progress.error}`); @@ -59,6 +69,6 @@ export const useFlashStore = create((set, get) => ({ }, reset: () => { - set({ isFlashing: false, progress: null, error: null, lastFlashParams: null }); + set({ activeDeviceId: null, isFlashing: false, progress: null, error: null, lastFlashParams: null }); }, })); diff --git a/local-tool/server/internal/api/handlers/device_handler.go b/local-tool/server/internal/api/handlers/device_handler.go index 24253cc..afd1a95 100644 --- a/local-tool/server/internal/api/handlers/device_handler.go +++ b/local-tool/server/internal/api/handlers/device_handler.go @@ -133,12 +133,13 @@ func (h *DeviceHandler) FlashDevice(c *gin.Context) { return } - // Forward progress to WebSocket + // Forward progress to WebSocket, then cleanup task (M2 fix) go func() { room := "flash:" + id for progress := range progressCh { h.wsHub.BroadcastToRoom(room, progress) } + h.flashSvc.CleanupTask(taskID) }() c.JSON(200, gin.H{"success": true, "data": gin.H{"taskId": taskID}}) diff --git a/local-tool/server/internal/api/ws/flash_ws.go b/local-tool/server/internal/api/ws/flash_ws.go index 712aafb..e4e7bad 100644 --- a/local-tool/server/internal/api/ws/flash_ws.go +++ b/local-tool/server/internal/api/ws/flash_ws.go @@ -20,9 +20,8 @@ func FlashProgressHandler(hub *Hub) gin.HandlerFunc { hub.RegisterSync(sub) defer hub.Unregister(sub) - // Read pump — drain incoming messages (ping/pong, close frames) + // Read pump — drain incoming messages; close handled by outer defer go func() { - defer conn.Close() for { if _, _, err := conn.ReadMessage(); err != nil { break diff --git a/local-tool/server/internal/flash/progress.go b/local-tool/server/internal/flash/progress.go index bd9e3c1..646ba80 100644 --- a/local-tool/server/internal/flash/progress.go +++ b/local-tool/server/internal/flash/progress.go @@ -1,8 +1,9 @@ package flash import ( - "visiona-local/server/internal/driver" "sync" + + "visiona-local/server/internal/driver" ) type FlashTask struct { @@ -15,7 +16,7 @@ type FlashTask struct { type ProgressTracker struct { tasks map[string]*FlashTask - mu sync.RWMutex + mu sync.Mutex } func NewProgressTracker() *ProgressTracker { @@ -24,26 +25,35 @@ func NewProgressTracker() *ProgressTracker { } } +// Create 建立新 flash task。如果同 taskID 已存在且未完成,回傳 nil 表示拒絕。 func (pt *ProgressTracker) Create(taskID, deviceID, modelID string) *FlashTask { pt.mu.Lock() defer pt.mu.Unlock() + + // M3 fix: 防止同裝置重複 flash — 舊 task 未完成就拒絕 + if existing, ok := pt.tasks[taskID]; ok && !existing.Done { + return nil + } + task := &FlashTask{ ID: taskID, DeviceID: deviceID, ModelID: modelID, ProgressCh: make(chan driver.FlashProgress, 20), + Done: false, } pt.tasks[taskID] = task return task } func (pt *ProgressTracker) Get(taskID string) (*FlashTask, bool) { - pt.mu.RLock() - defer pt.mu.RUnlock() + pt.mu.Lock() + defer pt.mu.Unlock() t, ok := pt.tasks[taskID] return t, ok } +// Remove 清除已完成的 task,釋放 map entry。 func (pt *ProgressTracker) Remove(taskID string) { pt.mu.Lock() defer pt.mu.Unlock() diff --git a/local-tool/server/internal/flash/service.go b/local-tool/server/internal/flash/service.go index 4826bbe..80fa41b 100644 --- a/local-tool/server/internal/flash/service.go +++ b/local-tool/server/internal/flash/service.go @@ -12,9 +12,6 @@ import ( "visiona-local/server/internal/model" ) -// isCompatible checks if any of the model's supported hardware types match -// the device type. The match is case-insensitive and also checks if the -// device type string contains the hardware name (e.g. "kneron_kl720" contains "KL720"). func isCompatible(modelHardware []string, deviceType string) bool { dt := strings.ToUpper(deviceType) for _, hw := range modelHardware { @@ -25,11 +22,6 @@ func isCompatible(modelHardware []string, deviceType string) bool { return false } -// resolveModelPath checks if a chip-specific NEF file exists for the given -// model. For cross-platform models whose filePath points to a KL520 NEF, -// this tries to find the equivalent KL720 NEF (and vice versa). -// -// Resolution: data/nef/kl520/kl520_20001_... → data/nef/kl720/kl720_20001_... func resolveModelPath(filePath string, deviceType string) string { if filePath == "" { return filePath @@ -45,12 +37,10 @@ func resolveModelPath(filePath string, deviceType string) string { return filePath } - // Already points to the target chip directory — use as-is. if strings.Contains(filePath, "/"+targetChip+"/") { return filePath } - // Try to swap chip prefix in both directory and filename. dir := filepath.Dir(filePath) base := filepath.Base(filePath) @@ -87,6 +77,11 @@ func NewService(deviceMgr *device.Manager, modelRepo *model.Repository) *Service } } +// CleanupTask 清除已完成的 flash task(由 handler goroutine 在讀取完 progressCh 後呼叫)。 +func (s *Service) CleanupTask(taskID string) { + s.tracker.Remove(taskID) +} + func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.FlashProgress, error) { session, err := s.deviceMgr.GetDevice(deviceID) if err != nil { @@ -101,39 +96,47 @@ func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.Fl return "", nil, fmt.Errorf("model not found: %w", err) } - // Check hardware compatibility deviceInfo := session.Driver.Info() if !isCompatible(m.SupportedHardware, deviceInfo.Type) { return "", nil, fmt.Errorf("model not compatible with device type %s", deviceInfo.Type) } - // Use the model's .nef file path if available, otherwise fall back to modelID. modelPath := m.FilePath if modelPath == "" { - modelPath = modelID + return "", nil, fmt.Errorf("model %s has no .nef file path", modelID) } - // Resolve chip-specific NEF (e.g. KL520 path → KL720 equivalent). modelPath = resolveModelPath(modelPath, deviceInfo.Type) taskID := fmt.Sprintf("flash-%s-%s", deviceID, modelID) + + // M3 fix: 防止同裝置同模型重複 flash task := s.tracker.Create(taskID, deviceID, modelID) + if task == nil { + return "", nil, fmt.Errorf("flash already in progress for device %s model %s", deviceID, modelID) + } go func() { - defer func() { - task.Done = true - close(task.ProgressCh) - }() - // Brief pause to allow the WebSocket client to connect before - // progress messages start flowing. + // M1 fix: 先跑 driver.Flash,收集 error,最後才寫 error message + close channel。 + // driver.Flash 內部會多次寫入 task.ProgressCh(進度更新),我們不能在它還在寫的時候 close。 + // driver.Flash 返回時保證不會再寫入 progressCh。 + time.Sleep(500 * time.Millisecond) - if err := session.Driver.Flash(modelPath, task.ProgressCh); err != nil { + + flashErr := session.Driver.Flash(modelPath, task.ProgressCh) + + // Flash 完成或失敗後,driver 不會再寫 progressCh,安全地寫 error 訊息然後 close。 + if flashErr != nil { task.ProgressCh <- driver.FlashProgress{ Percent: -1, Stage: "error", - Error: err.Error(), + Error: flashErr.Error(), } } + + task.Done = true + close(task.ProgressCh) + // M2 note: 不在這裡 Remove — 讓 handler 讀完 progressCh 後呼叫 CleanupTask }() return taskID, task.ProgressCh, nil