fix(local-tool): Review M1-M4 + m5 修復 — flash 生命週期 + store 隔離
Review 問題修復: M1(寫已關閉 channel panic): - flash service goroutine 改成先等 driver.Flash() 返回,再寫 error 訊息,最後 close - driver.Flash 返回後保證不再寫 progressCh,消除 race condition M2(FlashTask 永不清除 memory leak): - service.go 新增 CleanupTask(taskID) 公開方法 - device_handler.go 的 goroutine 在 `for range progressCh` 結束後呼叫 CleanupTask M3(同裝置重複 flash taskID 衝突): - ProgressTracker.Create 改成:舊 task 未完成時返回 nil - StartFlash 檢查 nil → 回傳 "flash already in progress" 錯誤 M4(前端 flash store 全域不區分 deviceId): - flash-store.ts 新增 activeDeviceId 欄位 - updateProgress 改接 (deviceId, progress),比對 activeDeviceId 防止混裝 - use-flash-progress.ts 的 WebSocket callback 傳入 deviceId m5(flash_ws.go 雙重 conn.Close): - read pump goroutine 移除 defer conn.Close(),由外層 defer 統一關閉 額外修復(S4): - modelPath 為空時直接回 error 而非傳無效路徑給 driver Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
44711753ae
commit
3c6971febd
@ -44,7 +44,7 @@ export function useFlashProgress(deviceId: string) {
|
||||
const ws = createWebSocket(
|
||||
`/ws/devices/${deviceId}/flash-progress`,
|
||||
(data) => {
|
||||
updateProgress(data as FlashProgress);
|
||||
updateProgress(deviceId, data as FlashProgress);
|
||||
},
|
||||
() => {
|
||||
doResolve();
|
||||
|
||||
@ -4,26 +4,31 @@ import type { FlashProgress } from '@/types/device';
|
||||
import { showApiError } from '@/lib/toast';
|
||||
import { useActivityStore } from './activity-store';
|
||||
|
||||
// M4 fix: flash state 以 deviceId 區分,避免多裝置 UI 互相覆蓋。
|
||||
// activeDeviceId 記錄當前正在 flash 的裝置,updateProgress 會比對確保不混。
|
||||
|
||||
interface FlashState {
|
||||
activeDeviceId: string | null;
|
||||
isFlashing: boolean;
|
||||
progress: FlashProgress | null;
|
||||
error: string | null;
|
||||
lastFlashParams: { deviceId: string; modelId: string } | null;
|
||||
startFlash: (deviceId: string, modelId: string) => Promise<void>;
|
||||
updateProgress: (progress: FlashProgress) => void;
|
||||
updateProgress: (deviceId: string, progress: FlashProgress) => void;
|
||||
setError: (error: string) => void;
|
||||
retryFlash: () => Promise<void>;
|
||||
reset: () => void;
|
||||
}
|
||||
|
||||
export const useFlashStore = create<FlashState>((set, get) => ({
|
||||
activeDeviceId: null,
|
||||
isFlashing: false,
|
||||
progress: null,
|
||||
error: null,
|
||||
lastFlashParams: null,
|
||||
|
||||
startFlash: async (deviceId, modelId) => {
|
||||
set({ isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } });
|
||||
set({ activeDeviceId: deviceId, isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } });
|
||||
const res = await api.post(`/devices/${deviceId}/flash`, { modelId });
|
||||
if (!res.success) {
|
||||
const errorMsg = res.error?.message || 'Flash failed';
|
||||
@ -35,7 +40,12 @@ export const useFlashStore = create<FlashState>((set, get) => ({
|
||||
}
|
||||
},
|
||||
|
||||
updateProgress: (progress) => {
|
||||
updateProgress: (deviceId, progress) => {
|
||||
// M4 fix: 只更新對應裝置的 progress,忽略其他裝置的 WebSocket 訊息
|
||||
if (get().activeDeviceId !== null && get().activeDeviceId !== deviceId) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (progress.error) {
|
||||
set({ isFlashing: false, error: progress.error });
|
||||
useActivityStore.getState().addActivity('flash_error', `Flash failed: ${progress.error}`);
|
||||
@ -59,6 +69,6 @@ export const useFlashStore = create<FlashState>((set, get) => ({
|
||||
},
|
||||
|
||||
reset: () => {
|
||||
set({ isFlashing: false, progress: null, error: null, lastFlashParams: null });
|
||||
set({ activeDeviceId: null, isFlashing: false, progress: null, error: null, lastFlashParams: null });
|
||||
},
|
||||
}));
|
||||
|
||||
@ -133,12 +133,13 @@ func (h *DeviceHandler) FlashDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Forward progress to WebSocket
|
||||
// Forward progress to WebSocket, then cleanup task (M2 fix)
|
||||
go func() {
|
||||
room := "flash:" + id
|
||||
for progress := range progressCh {
|
||||
h.wsHub.BroadcastToRoom(room, progress)
|
||||
}
|
||||
h.flashSvc.CleanupTask(taskID)
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"success": true, "data": gin.H{"taskId": taskID}})
|
||||
|
||||
@ -20,9 +20,8 @@ func FlashProgressHandler(hub *Hub) gin.HandlerFunc {
|
||||
hub.RegisterSync(sub)
|
||||
defer hub.Unregister(sub)
|
||||
|
||||
// Read pump — drain incoming messages (ping/pong, close frames)
|
||||
// Read pump — drain incoming messages; close handled by outer defer
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
for {
|
||||
if _, _, err := conn.ReadMessage(); err != nil {
|
||||
break
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
package flash
|
||||
|
||||
import (
|
||||
"visiona-local/server/internal/driver"
|
||||
"sync"
|
||||
|
||||
"visiona-local/server/internal/driver"
|
||||
)
|
||||
|
||||
type FlashTask struct {
|
||||
@ -15,7 +16,7 @@ type FlashTask struct {
|
||||
|
||||
type ProgressTracker struct {
|
||||
tasks map[string]*FlashTask
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewProgressTracker() *ProgressTracker {
|
||||
@ -24,26 +25,35 @@ func NewProgressTracker() *ProgressTracker {
|
||||
}
|
||||
}
|
||||
|
||||
// Create 建立新 flash task。如果同 taskID 已存在且未完成,回傳 nil 表示拒絕。
|
||||
func (pt *ProgressTracker) Create(taskID, deviceID, modelID string) *FlashTask {
|
||||
pt.mu.Lock()
|
||||
defer pt.mu.Unlock()
|
||||
|
||||
// M3 fix: 防止同裝置重複 flash — 舊 task 未完成就拒絕
|
||||
if existing, ok := pt.tasks[taskID]; ok && !existing.Done {
|
||||
return nil
|
||||
}
|
||||
|
||||
task := &FlashTask{
|
||||
ID: taskID,
|
||||
DeviceID: deviceID,
|
||||
ModelID: modelID,
|
||||
ProgressCh: make(chan driver.FlashProgress, 20),
|
||||
Done: false,
|
||||
}
|
||||
pt.tasks[taskID] = task
|
||||
return task
|
||||
}
|
||||
|
||||
func (pt *ProgressTracker) Get(taskID string) (*FlashTask, bool) {
|
||||
pt.mu.RLock()
|
||||
defer pt.mu.RUnlock()
|
||||
pt.mu.Lock()
|
||||
defer pt.mu.Unlock()
|
||||
t, ok := pt.tasks[taskID]
|
||||
return t, ok
|
||||
}
|
||||
|
||||
// Remove 清除已完成的 task,釋放 map entry。
|
||||
func (pt *ProgressTracker) Remove(taskID string) {
|
||||
pt.mu.Lock()
|
||||
defer pt.mu.Unlock()
|
||||
|
||||
@ -12,9 +12,6 @@ import (
|
||||
"visiona-local/server/internal/model"
|
||||
)
|
||||
|
||||
// isCompatible checks if any of the model's supported hardware types match
|
||||
// the device type. The match is case-insensitive and also checks if the
|
||||
// device type string contains the hardware name (e.g. "kneron_kl720" contains "KL720").
|
||||
func isCompatible(modelHardware []string, deviceType string) bool {
|
||||
dt := strings.ToUpper(deviceType)
|
||||
for _, hw := range modelHardware {
|
||||
@ -25,11 +22,6 @@ func isCompatible(modelHardware []string, deviceType string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// resolveModelPath checks if a chip-specific NEF file exists for the given
|
||||
// model. For cross-platform models whose filePath points to a KL520 NEF,
|
||||
// this tries to find the equivalent KL720 NEF (and vice versa).
|
||||
//
|
||||
// Resolution: data/nef/kl520/kl520_20001_... → data/nef/kl720/kl720_20001_...
|
||||
func resolveModelPath(filePath string, deviceType string) string {
|
||||
if filePath == "" {
|
||||
return filePath
|
||||
@ -45,12 +37,10 @@ func resolveModelPath(filePath string, deviceType string) string {
|
||||
return filePath
|
||||
}
|
||||
|
||||
// Already points to the target chip directory — use as-is.
|
||||
if strings.Contains(filePath, "/"+targetChip+"/") {
|
||||
return filePath
|
||||
}
|
||||
|
||||
// Try to swap chip prefix in both directory and filename.
|
||||
dir := filepath.Dir(filePath)
|
||||
base := filepath.Base(filePath)
|
||||
|
||||
@ -87,6 +77,11 @@ func NewService(deviceMgr *device.Manager, modelRepo *model.Repository) *Service
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupTask 清除已完成的 flash task(由 handler goroutine 在讀取完 progressCh 後呼叫)。
|
||||
func (s *Service) CleanupTask(taskID string) {
|
||||
s.tracker.Remove(taskID)
|
||||
}
|
||||
|
||||
func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.FlashProgress, error) {
|
||||
session, err := s.deviceMgr.GetDevice(deviceID)
|
||||
if err != nil {
|
||||
@ -101,39 +96,47 @@ func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.Fl
|
||||
return "", nil, fmt.Errorf("model not found: %w", err)
|
||||
}
|
||||
|
||||
// Check hardware compatibility
|
||||
deviceInfo := session.Driver.Info()
|
||||
if !isCompatible(m.SupportedHardware, deviceInfo.Type) {
|
||||
return "", nil, fmt.Errorf("model not compatible with device type %s", deviceInfo.Type)
|
||||
}
|
||||
|
||||
// Use the model's .nef file path if available, otherwise fall back to modelID.
|
||||
modelPath := m.FilePath
|
||||
if modelPath == "" {
|
||||
modelPath = modelID
|
||||
return "", nil, fmt.Errorf("model %s has no .nef file path", modelID)
|
||||
}
|
||||
|
||||
// Resolve chip-specific NEF (e.g. KL520 path → KL720 equivalent).
|
||||
modelPath = resolveModelPath(modelPath, deviceInfo.Type)
|
||||
|
||||
taskID := fmt.Sprintf("flash-%s-%s", deviceID, modelID)
|
||||
|
||||
// M3 fix: 防止同裝置同模型重複 flash
|
||||
task := s.tracker.Create(taskID, deviceID, modelID)
|
||||
if task == nil {
|
||||
return "", nil, fmt.Errorf("flash already in progress for device %s model %s", deviceID, modelID)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
task.Done = true
|
||||
close(task.ProgressCh)
|
||||
}()
|
||||
// Brief pause to allow the WebSocket client to connect before
|
||||
// progress messages start flowing.
|
||||
// M1 fix: 先跑 driver.Flash,收集 error,最後才寫 error message + close channel。
|
||||
// driver.Flash 內部會多次寫入 task.ProgressCh(進度更新),我們不能在它還在寫的時候 close。
|
||||
// driver.Flash 返回時保證不會再寫入 progressCh。
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if err := session.Driver.Flash(modelPath, task.ProgressCh); err != nil {
|
||||
|
||||
flashErr := session.Driver.Flash(modelPath, task.ProgressCh)
|
||||
|
||||
// Flash 完成或失敗後,driver 不會再寫 progressCh,安全地寫 error 訊息然後 close。
|
||||
if flashErr != nil {
|
||||
task.ProgressCh <- driver.FlashProgress{
|
||||
Percent: -1,
|
||||
Stage: "error",
|
||||
Error: err.Error(),
|
||||
Error: flashErr.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
task.Done = true
|
||||
close(task.ProgressCh)
|
||||
// M2 note: 不在這裡 Remove — 讓 handler 讀完 progressCh 後呼叫 CleanupTask
|
||||
}()
|
||||
|
||||
return taskID, task.ProgressCh, nil
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user