fix(local-tool): Review M1-M4 + m5 修復 — flash 生命週期 + store 隔離
Review 問題修復: M1(寫已關閉 channel panic): - flash service goroutine 改成先等 driver.Flash() 返回,再寫 error 訊息,最後 close - driver.Flash 返回後保證不再寫 progressCh,消除 race condition M2(FlashTask 永不清除 memory leak): - service.go 新增 CleanupTask(taskID) 公開方法 - device_handler.go 的 goroutine 在 `for range progressCh` 結束後呼叫 CleanupTask M3(同裝置重複 flash taskID 衝突): - ProgressTracker.Create 改成:舊 task 未完成時返回 nil - StartFlash 檢查 nil → 回傳 "flash already in progress" 錯誤 M4(前端 flash store 全域不區分 deviceId): - flash-store.ts 新增 activeDeviceId 欄位 - updateProgress 改接 (deviceId, progress),比對 activeDeviceId 防止混裝 - use-flash-progress.ts 的 WebSocket callback 傳入 deviceId m5(flash_ws.go 雙重 conn.Close): - read pump goroutine 移除 defer conn.Close(),由外層 defer 統一關閉 額外修復(S4): - modelPath 為空時直接回 error 而非傳無效路徑給 driver Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
44711753ae
commit
3c6971febd
@ -44,7 +44,7 @@ export function useFlashProgress(deviceId: string) {
|
|||||||
const ws = createWebSocket(
|
const ws = createWebSocket(
|
||||||
`/ws/devices/${deviceId}/flash-progress`,
|
`/ws/devices/${deviceId}/flash-progress`,
|
||||||
(data) => {
|
(data) => {
|
||||||
updateProgress(data as FlashProgress);
|
updateProgress(deviceId, data as FlashProgress);
|
||||||
},
|
},
|
||||||
() => {
|
() => {
|
||||||
doResolve();
|
doResolve();
|
||||||
|
|||||||
@ -4,26 +4,31 @@ import type { FlashProgress } from '@/types/device';
|
|||||||
import { showApiError } from '@/lib/toast';
|
import { showApiError } from '@/lib/toast';
|
||||||
import { useActivityStore } from './activity-store';
|
import { useActivityStore } from './activity-store';
|
||||||
|
|
||||||
|
// M4 fix: flash state 以 deviceId 區分,避免多裝置 UI 互相覆蓋。
|
||||||
|
// activeDeviceId 記錄當前正在 flash 的裝置,updateProgress 會比對確保不混。
|
||||||
|
|
||||||
interface FlashState {
|
interface FlashState {
|
||||||
|
activeDeviceId: string | null;
|
||||||
isFlashing: boolean;
|
isFlashing: boolean;
|
||||||
progress: FlashProgress | null;
|
progress: FlashProgress | null;
|
||||||
error: string | null;
|
error: string | null;
|
||||||
lastFlashParams: { deviceId: string; modelId: string } | null;
|
lastFlashParams: { deviceId: string; modelId: string } | null;
|
||||||
startFlash: (deviceId: string, modelId: string) => Promise<void>;
|
startFlash: (deviceId: string, modelId: string) => Promise<void>;
|
||||||
updateProgress: (progress: FlashProgress) => void;
|
updateProgress: (deviceId: string, progress: FlashProgress) => void;
|
||||||
setError: (error: string) => void;
|
setError: (error: string) => void;
|
||||||
retryFlash: () => Promise<void>;
|
retryFlash: () => Promise<void>;
|
||||||
reset: () => void;
|
reset: () => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const useFlashStore = create<FlashState>((set, get) => ({
|
export const useFlashStore = create<FlashState>((set, get) => ({
|
||||||
|
activeDeviceId: null,
|
||||||
isFlashing: false,
|
isFlashing: false,
|
||||||
progress: null,
|
progress: null,
|
||||||
error: null,
|
error: null,
|
||||||
lastFlashParams: null,
|
lastFlashParams: null,
|
||||||
|
|
||||||
startFlash: async (deviceId, modelId) => {
|
startFlash: async (deviceId, modelId) => {
|
||||||
set({ isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } });
|
set({ activeDeviceId: deviceId, isFlashing: true, progress: null, error: null, lastFlashParams: { deviceId, modelId } });
|
||||||
const res = await api.post(`/devices/${deviceId}/flash`, { modelId });
|
const res = await api.post(`/devices/${deviceId}/flash`, { modelId });
|
||||||
if (!res.success) {
|
if (!res.success) {
|
||||||
const errorMsg = res.error?.message || 'Flash failed';
|
const errorMsg = res.error?.message || 'Flash failed';
|
||||||
@ -35,7 +40,12 @@ export const useFlashStore = create<FlashState>((set, get) => ({
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
updateProgress: (progress) => {
|
updateProgress: (deviceId, progress) => {
|
||||||
|
// M4 fix: 只更新對應裝置的 progress,忽略其他裝置的 WebSocket 訊息
|
||||||
|
if (get().activeDeviceId !== null && get().activeDeviceId !== deviceId) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (progress.error) {
|
if (progress.error) {
|
||||||
set({ isFlashing: false, error: progress.error });
|
set({ isFlashing: false, error: progress.error });
|
||||||
useActivityStore.getState().addActivity('flash_error', `Flash failed: ${progress.error}`);
|
useActivityStore.getState().addActivity('flash_error', `Flash failed: ${progress.error}`);
|
||||||
@ -59,6 +69,6 @@ export const useFlashStore = create<FlashState>((set, get) => ({
|
|||||||
},
|
},
|
||||||
|
|
||||||
reset: () => {
|
reset: () => {
|
||||||
set({ isFlashing: false, progress: null, error: null, lastFlashParams: null });
|
set({ activeDeviceId: null, isFlashing: false, progress: null, error: null, lastFlashParams: null });
|
||||||
},
|
},
|
||||||
}));
|
}));
|
||||||
|
|||||||
@ -133,12 +133,13 @@ func (h *DeviceHandler) FlashDevice(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forward progress to WebSocket
|
// Forward progress to WebSocket, then cleanup task (M2 fix)
|
||||||
go func() {
|
go func() {
|
||||||
room := "flash:" + id
|
room := "flash:" + id
|
||||||
for progress := range progressCh {
|
for progress := range progressCh {
|
||||||
h.wsHub.BroadcastToRoom(room, progress)
|
h.wsHub.BroadcastToRoom(room, progress)
|
||||||
}
|
}
|
||||||
|
h.flashSvc.CleanupTask(taskID)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
c.JSON(200, gin.H{"success": true, "data": gin.H{"taskId": taskID}})
|
c.JSON(200, gin.H{"success": true, "data": gin.H{"taskId": taskID}})
|
||||||
|
|||||||
@ -20,9 +20,8 @@ func FlashProgressHandler(hub *Hub) gin.HandlerFunc {
|
|||||||
hub.RegisterSync(sub)
|
hub.RegisterSync(sub)
|
||||||
defer hub.Unregister(sub)
|
defer hub.Unregister(sub)
|
||||||
|
|
||||||
// Read pump — drain incoming messages (ping/pong, close frames)
|
// Read pump — drain incoming messages; close handled by outer defer
|
||||||
go func() {
|
go func() {
|
||||||
defer conn.Close()
|
|
||||||
for {
|
for {
|
||||||
if _, _, err := conn.ReadMessage(); err != nil {
|
if _, _, err := conn.ReadMessage(); err != nil {
|
||||||
break
|
break
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
package flash
|
package flash
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"visiona-local/server/internal/driver"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"visiona-local/server/internal/driver"
|
||||||
)
|
)
|
||||||
|
|
||||||
type FlashTask struct {
|
type FlashTask struct {
|
||||||
@ -15,7 +16,7 @@ type FlashTask struct {
|
|||||||
|
|
||||||
type ProgressTracker struct {
|
type ProgressTracker struct {
|
||||||
tasks map[string]*FlashTask
|
tasks map[string]*FlashTask
|
||||||
mu sync.RWMutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProgressTracker() *ProgressTracker {
|
func NewProgressTracker() *ProgressTracker {
|
||||||
@ -24,26 +25,35 @@ func NewProgressTracker() *ProgressTracker {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create 建立新 flash task。如果同 taskID 已存在且未完成,回傳 nil 表示拒絕。
|
||||||
func (pt *ProgressTracker) Create(taskID, deviceID, modelID string) *FlashTask {
|
func (pt *ProgressTracker) Create(taskID, deviceID, modelID string) *FlashTask {
|
||||||
pt.mu.Lock()
|
pt.mu.Lock()
|
||||||
defer pt.mu.Unlock()
|
defer pt.mu.Unlock()
|
||||||
|
|
||||||
|
// M3 fix: 防止同裝置重複 flash — 舊 task 未完成就拒絕
|
||||||
|
if existing, ok := pt.tasks[taskID]; ok && !existing.Done {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
task := &FlashTask{
|
task := &FlashTask{
|
||||||
ID: taskID,
|
ID: taskID,
|
||||||
DeviceID: deviceID,
|
DeviceID: deviceID,
|
||||||
ModelID: modelID,
|
ModelID: modelID,
|
||||||
ProgressCh: make(chan driver.FlashProgress, 20),
|
ProgressCh: make(chan driver.FlashProgress, 20),
|
||||||
|
Done: false,
|
||||||
}
|
}
|
||||||
pt.tasks[taskID] = task
|
pt.tasks[taskID] = task
|
||||||
return task
|
return task
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pt *ProgressTracker) Get(taskID string) (*FlashTask, bool) {
|
func (pt *ProgressTracker) Get(taskID string) (*FlashTask, bool) {
|
||||||
pt.mu.RLock()
|
pt.mu.Lock()
|
||||||
defer pt.mu.RUnlock()
|
defer pt.mu.Unlock()
|
||||||
t, ok := pt.tasks[taskID]
|
t, ok := pt.tasks[taskID]
|
||||||
return t, ok
|
return t, ok
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove 清除已完成的 task,釋放 map entry。
|
||||||
func (pt *ProgressTracker) Remove(taskID string) {
|
func (pt *ProgressTracker) Remove(taskID string) {
|
||||||
pt.mu.Lock()
|
pt.mu.Lock()
|
||||||
defer pt.mu.Unlock()
|
defer pt.mu.Unlock()
|
||||||
|
|||||||
@ -12,9 +12,6 @@ import (
|
|||||||
"visiona-local/server/internal/model"
|
"visiona-local/server/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
// isCompatible checks if any of the model's supported hardware types match
|
|
||||||
// the device type. The match is case-insensitive and also checks if the
|
|
||||||
// device type string contains the hardware name (e.g. "kneron_kl720" contains "KL720").
|
|
||||||
func isCompatible(modelHardware []string, deviceType string) bool {
|
func isCompatible(modelHardware []string, deviceType string) bool {
|
||||||
dt := strings.ToUpper(deviceType)
|
dt := strings.ToUpper(deviceType)
|
||||||
for _, hw := range modelHardware {
|
for _, hw := range modelHardware {
|
||||||
@ -25,11 +22,6 @@ func isCompatible(modelHardware []string, deviceType string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// resolveModelPath checks if a chip-specific NEF file exists for the given
|
|
||||||
// model. For cross-platform models whose filePath points to a KL520 NEF,
|
|
||||||
// this tries to find the equivalent KL720 NEF (and vice versa).
|
|
||||||
//
|
|
||||||
// Resolution: data/nef/kl520/kl520_20001_... → data/nef/kl720/kl720_20001_...
|
|
||||||
func resolveModelPath(filePath string, deviceType string) string {
|
func resolveModelPath(filePath string, deviceType string) string {
|
||||||
if filePath == "" {
|
if filePath == "" {
|
||||||
return filePath
|
return filePath
|
||||||
@ -45,12 +37,10 @@ func resolveModelPath(filePath string, deviceType string) string {
|
|||||||
return filePath
|
return filePath
|
||||||
}
|
}
|
||||||
|
|
||||||
// Already points to the target chip directory — use as-is.
|
|
||||||
if strings.Contains(filePath, "/"+targetChip+"/") {
|
if strings.Contains(filePath, "/"+targetChip+"/") {
|
||||||
return filePath
|
return filePath
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to swap chip prefix in both directory and filename.
|
|
||||||
dir := filepath.Dir(filePath)
|
dir := filepath.Dir(filePath)
|
||||||
base := filepath.Base(filePath)
|
base := filepath.Base(filePath)
|
||||||
|
|
||||||
@ -87,6 +77,11 @@ func NewService(deviceMgr *device.Manager, modelRepo *model.Repository) *Service
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CleanupTask 清除已完成的 flash task(由 handler goroutine 在讀取完 progressCh 後呼叫)。
|
||||||
|
func (s *Service) CleanupTask(taskID string) {
|
||||||
|
s.tracker.Remove(taskID)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.FlashProgress, error) {
|
func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.FlashProgress, error) {
|
||||||
session, err := s.deviceMgr.GetDevice(deviceID)
|
session, err := s.deviceMgr.GetDevice(deviceID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -101,39 +96,47 @@ func (s *Service) StartFlash(deviceID, modelID string) (string, <-chan driver.Fl
|
|||||||
return "", nil, fmt.Errorf("model not found: %w", err)
|
return "", nil, fmt.Errorf("model not found: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check hardware compatibility
|
|
||||||
deviceInfo := session.Driver.Info()
|
deviceInfo := session.Driver.Info()
|
||||||
if !isCompatible(m.SupportedHardware, deviceInfo.Type) {
|
if !isCompatible(m.SupportedHardware, deviceInfo.Type) {
|
||||||
return "", nil, fmt.Errorf("model not compatible with device type %s", deviceInfo.Type)
|
return "", nil, fmt.Errorf("model not compatible with device type %s", deviceInfo.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Use the model's .nef file path if available, otherwise fall back to modelID.
|
|
||||||
modelPath := m.FilePath
|
modelPath := m.FilePath
|
||||||
if modelPath == "" {
|
if modelPath == "" {
|
||||||
modelPath = modelID
|
return "", nil, fmt.Errorf("model %s has no .nef file path", modelID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve chip-specific NEF (e.g. KL520 path → KL720 equivalent).
|
|
||||||
modelPath = resolveModelPath(modelPath, deviceInfo.Type)
|
modelPath = resolveModelPath(modelPath, deviceInfo.Type)
|
||||||
|
|
||||||
taskID := fmt.Sprintf("flash-%s-%s", deviceID, modelID)
|
taskID := fmt.Sprintf("flash-%s-%s", deviceID, modelID)
|
||||||
|
|
||||||
|
// M3 fix: 防止同裝置同模型重複 flash
|
||||||
task := s.tracker.Create(taskID, deviceID, modelID)
|
task := s.tracker.Create(taskID, deviceID, modelID)
|
||||||
|
if task == nil {
|
||||||
|
return "", nil, fmt.Errorf("flash already in progress for device %s model %s", deviceID, modelID)
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer func() {
|
// M1 fix: 先跑 driver.Flash,收集 error,最後才寫 error message + close channel。
|
||||||
task.Done = true
|
// driver.Flash 內部會多次寫入 task.ProgressCh(進度更新),我們不能在它還在寫的時候 close。
|
||||||
close(task.ProgressCh)
|
// driver.Flash 返回時保證不會再寫入 progressCh。
|
||||||
}()
|
|
||||||
// Brief pause to allow the WebSocket client to connect before
|
|
||||||
// progress messages start flowing.
|
|
||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
if err := session.Driver.Flash(modelPath, task.ProgressCh); err != nil {
|
|
||||||
|
flashErr := session.Driver.Flash(modelPath, task.ProgressCh)
|
||||||
|
|
||||||
|
// Flash 完成或失敗後,driver 不會再寫 progressCh,安全地寫 error 訊息然後 close。
|
||||||
|
if flashErr != nil {
|
||||||
task.ProgressCh <- driver.FlashProgress{
|
task.ProgressCh <- driver.FlashProgress{
|
||||||
Percent: -1,
|
Percent: -1,
|
||||||
Stage: "error",
|
Stage: "error",
|
||||||
Error: err.Error(),
|
Error: flashErr.Error(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
task.Done = true
|
||||||
|
close(task.ProgressCh)
|
||||||
|
// M2 note: 不在這裡 Remove — 讓 handler 讀完 progressCh 後呼叫 CleanupTask
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return taskID, task.ProgressCh, nil
|
return taskID, task.ProgressCh, nil
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user