package session import ( "context" "errors" "net" "sync" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // fakeHandle 是測試用 Handle 實作,不涉及真實網路。 // // 為配合 B2 Review M1 修補,fakeHandle 以 mutex 保護 summary 的 // LastHeartbeat 欄位(Summary() 回傳快照、RecordHeartbeat 在 lock 下寫入)。 type fakeHandle struct { mu sync.Mutex summary Summary closed atomic.Bool closeErr error } func newFakeHandle(token, userID, deviceID string) *fakeHandle { now := time.Now().UTC() return &fakeHandle{ summary: Summary{ Token: token, UserID: userID, DeviceID: deviceID, ConnectedAt: now, LastHeartbeat: now, }, } } func (h *fakeHandle) OpenStream(ctx context.Context) (net.Conn, error) { if h.closed.Load() { return nil, ErrSessionClosed } return nil, errors.New("fakeHandle: OpenStream not implemented for tests") } func (h *fakeHandle) Close() error { h.closed.Store(true) return h.closeErr } func (h *fakeHandle) IsClosed() bool { return h.closed.Load() } func (h *fakeHandle) Summary() *Summary { h.mu.Lock() defer h.mu.Unlock() cp := h.summary return &cp } func (h *fakeHandle) RecordHeartbeat(t time.Time) { h.mu.Lock() defer h.mu.Unlock() h.summary.LastHeartbeat = t } // setLastHeartbeatForTest 僅供測試直接覆寫 LastHeartbeat(CleanupExpired 測試用)。 func (h *fakeHandle) setLastHeartbeatForTest(t time.Time) { h.mu.Lock() defer h.mu.Unlock() h.summary.LastHeartbeat = t } // ---------------------------------------------------------------------- // Tests // ---------------------------------------------------------------------- func TestInMemoryStore_RegisterAndLookup(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() h := newFakeHandle("tok-1", "user-1", "dev-1") require.NoError(t, s.Register(ctx, "tok-1", h)) got, err := s.Lookup(ctx, "tok-1") require.NoError(t, err) assert.Equal(t, h, got) } func TestInMemoryStore_Lookup_NotFound(t *testing.T) { s := NewInMemoryStore() _, err := s.Lookup(context.Background(), "tok-unknown") assert.ErrorIs(t, err, ErrSessionNotFound) } func TestInMemoryStore_Register_OverwritesAndClosesOld(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() old := newFakeHandle("tok-1", "user-1", "dev-1") require.NoError(t, s.Register(ctx, "tok-1", old)) // 後連覆蓋前連(Q5) newHandle := newFakeHandle("tok-1", "user-1", "dev-1") require.NoError(t, s.Register(ctx, "tok-1", newHandle)) // 舊 handle 應被 Close assert.True(t, old.IsClosed(), "舊 handle 應該被 Close") assert.False(t, newHandle.IsClosed(), "新 handle 不應被 Close") // Lookup 回傳新的 got, err := s.Lookup(ctx, "tok-1") require.NoError(t, err) assert.Equal(t, newHandle, got) } func TestInMemoryStore_Exists(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() ok, err := s.Exists(ctx, "tok-1") require.NoError(t, err) assert.False(t, ok) require.NoError(t, s.Register(ctx, "tok-1", newFakeHandle("tok-1", "u", "d"))) ok, err = s.Exists(ctx, "tok-1") require.NoError(t, err) assert.True(t, ok) } func TestInMemoryStore_Unregister(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() require.NoError(t, s.Register(ctx, "tok-1", newFakeHandle("tok-1", "u", "d"))) require.NoError(t, s.Unregister(ctx, "tok-1")) ok, _ := s.Exists(ctx, "tok-1") assert.False(t, ok) // 不存在的 token 不應回錯 assert.NoError(t, s.Unregister(ctx, "tok-unknown")) } func TestInMemoryStore_Heartbeat_UpdatesLastHeartbeat(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() h := newFakeHandle("tok-1", "u", "d") start := h.Summary().LastHeartbeat require.NoError(t, s.Register(ctx, "tok-1", h)) // 確保時間差 time.Sleep(2 * time.Millisecond) require.NoError(t, s.Heartbeat(ctx, "tok-1")) after := h.Summary().LastHeartbeat assert.True(t, after.After(start), "LastHeartbeat 應該被更新:%v > %v", after, start) // (修 B2 M1)Heartbeat 走 RecordHeartbeat;race detector 必須通過。 } func TestInMemoryStore_Heartbeat_NotFound(t *testing.T) { s := NewInMemoryStore() err := s.Heartbeat(context.Background(), "tok-unknown") assert.ErrorIs(t, err, ErrSessionNotFound) } func TestInMemoryStore_List(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() require.NoError(t, s.Register(ctx, "a", newFakeHandle("a", "u1", "d1"))) require.NoError(t, s.Register(ctx, "b", newFakeHandle("b", "u2", "d2"))) summaries, err := s.List(ctx) require.NoError(t, err) assert.Len(t, summaries, 2) tokens := map[string]bool{} for _, sum := range summaries { tokens[sum.Token] = true } assert.True(t, tokens["a"]) assert.True(t, tokens["b"]) } func TestInMemoryStore_CleanupExpired(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() // 手動設定 LastHeartbeat 為過去時間 old := newFakeHandle("expired", "u", "d") old.setLastHeartbeatForTest(time.Now().UTC().Add(-1 * time.Minute)) fresh := newFakeHandle("fresh", "u", "d") // fresh.summary.LastHeartbeat 已在 newFakeHandle 設為 now require.NoError(t, s.Register(ctx, "expired", old)) require.NoError(t, s.Register(ctx, "fresh", fresh)) // 以 30s 為 expireAfter,expired 超過 60s 應被清 removed, err := s.CleanupExpired(ctx, 30*time.Second) require.NoError(t, err) assert.Equal(t, 1, removed) assert.True(t, old.IsClosed(), "逾時的 handle 應該被 Close") assert.False(t, fresh.IsClosed()) ok, _ := s.Exists(ctx, "expired") assert.False(t, ok) ok, _ = s.Exists(ctx, "fresh") assert.True(t, ok) } func TestInMemoryStore_CleanupExpired_Empty(t *testing.T) { s := NewInMemoryStore() removed, err := s.CleanupExpired(context.Background(), 30*time.Second) require.NoError(t, err) assert.Equal(t, 0, removed) } // TestInMemoryStore_Heartbeat_CleanupExpired_NoRace 驗證 B2 Review M1 修補: // 並發執行 Heartbeat(寫)與 CleanupExpired / List(讀)時 race detector 不應捕捉到衝突。 // 本測試應在 `go test -race` 下通過。 func TestInMemoryStore_Heartbeat_CleanupExpired_NoRace(t *testing.T) { ctx := context.Background() s := NewInMemoryStore() // 註冊 20 個 session const n = 20 for i := 0; i < n; i++ { tok := "tok-" + string(rune('a'+i)) require.NoError(t, s.Register(ctx, tok, newFakeHandle(tok, "u", "d"))) } stop := make(chan struct{}) var wg sync.WaitGroup // 並發跑 Heartbeat wg.Add(1) go func() { defer wg.Done() for { select { case <-stop: return default: } for i := 0; i < n; i++ { tok := "tok-" + string(rune('a'+i)) _ = s.Heartbeat(ctx, tok) } } }() // 並發跑 CleanupExpired(不真的清掉,因為 expireAfter 很大) wg.Add(1) go func() { defer wg.Done() for { select { case <-stop: return default: } _, _ = s.CleanupExpired(ctx, 1*time.Hour) } }() // 並發跑 List wg.Add(1) go func() { defer wg.Done() for { select { case <-stop: return default: } _, _ = s.List(ctx) } }() // 跑 100ms 讓 race detector 有足夠機會採樣 time.Sleep(100 * time.Millisecond) close(stop) wg.Wait() }