fix(local-tool): Review M1-M4 + m5 修復 — flash 生命週期 + store 隔離

Review 問題修復:

M1(寫已關閉 channel panic):
- flash service goroutine 改成先等 driver.Flash() 返回,再寫 error 訊息,最後 close
- driver.Flash 返回後保證不再寫 progressCh,消除 race condition

M2(FlashTask 永不清除 memory leak):
- service.go 新增 CleanupTask(taskID) 公開方法
- device_handler.go 的 goroutine 在 `for range progressCh` 結束後呼叫 CleanupTask

M3(同裝置重複 flash taskID 衝突):
- ProgressTracker.Create 改成:舊 task 未完成時返回 nil
- StartFlash 檢查 nil → 回傳 "flash already in progress" 錯誤

M4(前端 flash store 全域不區分 deviceId):
- flash-store.ts 新增 activeDeviceId 欄位
- updateProgress 改接 (deviceId, progress),比對 activeDeviceId 防止混裝
- use-flash-progress.ts 的 WebSocket callback 傳入 deviceId

m5(flash_ws.go 雙重 conn.Close):
- read pump goroutine 移除 defer conn.Close(),由外層 defer 統一關閉

額外修復(S4):
- modelPath 為空時直接回 error 而非傳無效路徑給 driver

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
jim800121chen 2026-04-12 20:16:24 +08:00
parent 44711753ae
commit 3c6971febd
6 changed files with 57 additions and 34 deletions

View File

@ -44,7 +44,7 @@ export function useFlashProgress(deviceId: string) {
const ws = createWebSocket(
`/ws/devices/${deviceId}/flash-progress`,
(data) => {
updateProgress(data as FlashProgress);
updateProgress(deviceId, data as FlashProgress);
},
() => {
doResolve();

View File

@ -4,26 +4,31 @@ import type { FlashProgress } from '@/types/device';
import { showApiError } from '@/lib/toast';
import { useActivityStore } from './activity-store';
// M4 fix: flash state 以 deviceId 區分,避免多裝置 UI 互相覆蓋。
// activeDeviceId 記錄當前正在 flash 的裝置,updateProgress 會比對確保不混。
interface FlashState {
activeDeviceId: string | null;
isFlashing: boolean;
progress: FlashProgress | null;
error: string | null;
lastFlashParams: { deviceId: string; modelId: string } | null;
startFlash: (deviceId: string, modelId: string) => Promise<void>;
updateProgress: (progress: FlashProgress) => void;
updateProgress: (deviceId: string, progress: FlashProgress) => void;
setError: (error: string) => void;
retryFlash: () => Promise<void>;
reset: () => void;
}
export const useFlashStore = create<FlashState>((set, get) => ({
activeDeviceId: null,
isFlashing: false,
progress: null,
error: null,
lastFlashParams: null,
startFlash: async (deviceId, modelId) => {
set({ isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } });
set({ activeDeviceId: deviceId, isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } });
const res = await api.post(`/devices/${deviceId}/flash`, { modelId });
if (!res.success) {
const errorMsg = res.error?.message || 'Flash failed';
@ -35,7 +40,12 @@ export const useFlashStore = create<FlashState>((set, get) => ({
}
},
updateProgress: (progress) => {
updateProgress: (deviceId, progress) => {
// M4 fix: 只更新對應裝置的 progress,忽略其他裝置的 WebSocket 訊息
if (get().activeDeviceId !== null && get().activeDeviceId !== deviceId) {
return;
}
if (progress.error) {
set({ isFlashing: false, error: progress.error });
useActivityStore.getState().addActivity('flash_error', `Flash failed: ${progress.error}`);
@ -59,6 +69,6 @@ export const useFlashStore = create<FlashState>((set, get) => ({
},
reset: () => {
set({ isFlashing: false, progress: null, error: null, lastFlashParams: null });
set({ activeDeviceId: null, isFlashing: false, progress: null, error: null, lastFlashParams: null });
},
}));

View File

@ -133,12 +133,13 @@ func (h *DeviceHandler) FlashDevice(c *gin.Context) {
return
}
// Forward progress to WebSocket
// Forward progress to WebSocket, then cleanup task (M2 fix)
go func() {
room := "flash:" + id
for progress := range progressCh {
h.wsHub.BroadcastToRoom(room, progress)
}
h.flashSvc.CleanupTask(taskID)
}()
c.JSON(200, gin.H{"success": true, "data": gin.H{"taskId": taskID}})

View File

@ -20,9 +20,8 @@ func FlashProgressHandler(hub *Hub) gin.HandlerFunc {
hub.RegisterSync(sub)
defer hub.Unregister(sub)
// Read pump — drain incoming messages (ping/pong, close frames)
// Read pump — drain incoming messages; close handled by outer defer
go func() {
defer conn.Close()
for {
if _, _, err := conn.ReadMessage(); err != nil {
break

View File

@ -1,8 +1,9 @@
package flash
import (
"visiona-local/server/internal/driver"
"sync"
"visiona-local/server/internal/driver"
)
type FlashTask struct {
@ -15,7 +16,7 @@ type FlashTask struct {
type ProgressTracker struct {
tasks map[string]*FlashTask
mu sync.RWMutex
mu sync.Mutex
}
func NewProgressTracker() *ProgressTracker {
@ -24,26 +25,35 @@ func NewProgressTracker() *ProgressTracker {
}
}
// Create registers a new flash task under taskID and returns it.
//
// It returns nil (request rejected) when a task with the same taskID is
// already registered and has not been marked Done. An existing entry that is
// Done is replaced by the new task. The new task's progress channel is
// buffered with a capacity of 20.
func (pt *ProgressTracker) Create(taskID, deviceID, modelID string) *FlashTask {
	pt.mu.Lock()
	defer pt.mu.Unlock()

	// M3 fix: refuse a duplicate flash while the previous task registered
	// under this ID has not finished yet.
	prev, found := pt.tasks[taskID]
	if found && !prev.Done {
		return nil
	}

	// Done is left at its zero value (false) for a freshly created task.
	created := &FlashTask{
		ID:         taskID,
		DeviceID:   deviceID,
		ModelID:    modelID,
		ProgressCh: make(chan driver.FlashProgress, 20),
	}
	pt.tasks[taskID] = created
	return created
}
func (pt *ProgressTracker) Get(taskID string) (*FlashTask, bool) {
pt.mu.RLock()
defer pt.mu.RUnlock()
pt.mu.Lock()
defer pt.mu.Unlock()
t, ok := pt.tasks[taskID]
return t, ok
}
// Remove 清除已完成的 task,釋放 map entry。
func (pt *ProgressTracker) Remove(taskID string) {
pt.mu.Lock()
defer pt.mu.Unlock()

View File

@ -12,9 +12,6 @@ import (
"visiona-local/server/internal/model"
)
// isCompatible checks if any of the model's supported hardware types match
// the device type. The match is case-insensitive and also checks if the
// device type string contains the hardware name (e.g. "kneron_kl720" contains "KL720").
func isCompatible(modelHardware []string, deviceType string) bool {
dt := strings.ToUpper(deviceType)
for _, hw := range modelHardware {
@ -25,11 +22,6 @@ func isCompatible(modelHardware []string, deviceType string) bool {
return false
}
// resolveModelPath checks if a chip-specific NEF file exists for the given
// model. For cross-platform models whose filePath points to a KL520 NEF,
// this tries to find the equivalent KL720 NEF (and vice versa).
//
// Resolution: data/nef/kl520/kl520_20001_... → data/nef/kl720/kl720_20001_...
func resolveModelPath(filePath string, deviceType string) string {
if filePath == "" {
return filePath
@ -45,12 +37,10 @@ func resolveModelPath(filePath string, deviceType string) string {
return filePath
}
// Already points to the target chip directory — use as-is.
if strings.Contains(filePath, "/"+targetChip+"/") {
return filePath
}
// Try to swap chip prefix in both directory and filename.
dir := filepath.Dir(filePath)
base := filepath.Base(filePath)
@ -87,6 +77,11 @@ func NewService(deviceMgr *device.Manager, modelRepo *model.Repository) *Service
}
}
// CleanupTask removes the flash task identified by taskID by delegating to the
// service tracker's Remove method. It is meant to be called by the handler
// goroutine after it has finished reading the task's progress channel.
func (s *Service) CleanupTask(taskID string) {
s.tracker.Remove(taskID)
}
func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.FlashProgress, error) {
session, err := s.deviceMgr.GetDevice(deviceID)
if err != nil {
@ -101,39 +96,47 @@ func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.Fl
return "", nil, fmt.Errorf("model not found: %w", err)
}
// Check hardware compatibility
deviceInfo := session.Driver.Info()
if !isCompatible(m.SupportedHardware, deviceInfo.Type) {
return "", nil, fmt.Errorf("model not compatible with device type %s", deviceInfo.Type)
}
// Use the model's .nef file path if available, otherwise fall back to modelID.
modelPath := m.FilePath
if modelPath == "" {
modelPath = modelID
return "", nil, fmt.Errorf("model %s has no .nef file path", modelID)
}
// Resolve chip-specific NEF (e.g. KL520 path → KL720 equivalent).
modelPath = resolveModelPath(modelPath, deviceInfo.Type)
taskID := fmt.Sprintf("flash-%s-%s", deviceID, modelID)
// M3 fix: 防止同裝置同模型重複 flash
task := s.tracker.Create(taskID, deviceID, modelID)
if task == nil {
return "", nil, fmt.Errorf("flash already in progress for device %s model %s", deviceID, modelID)
}
go func() {
defer func() {
task.Done = true
close(task.ProgressCh)
}()
// Brief pause to allow the WebSocket client to connect before
// progress messages start flowing.
// M1 fix: 先跑 driver.Flash,收集 error,最後才寫 error message + close channel。
// driver.Flash 內部會多次寫入 task.ProgressCh(進度更新),我們不能在它還在寫的時候 close。
// driver.Flash 返回時保證不會再寫入 progressCh。
time.Sleep(500 * time.Millisecond)
if err := session.Driver.Flash(modelPath, task.ProgressCh); err != nil {
flashErr := session.Driver.Flash(modelPath, task.ProgressCh)
// Flash 完成或失敗後,driver 不會再寫 progressCh,安全地寫 error 訊息然後 close。
if flashErr != nil {
task.ProgressCh <- driver.FlashProgress{
Percent: -1,
Stage: "error",
Error: err.Error(),
Error: flashErr.Error(),
}
}
task.Done = true
close(task.ProgressCh)
// M2 note: 不在這裡 Remove — 讓 handler 讀完 progressCh 後呼叫 CleanupTask
}()
return taskID, task.ProgressCh, nil