// PostgresSessionTokenStore 是 SessionTokenStore 的 PostgreSQL 持久層實作(DB 接入塊 3)。 // // 與 InMemorySessionTokenStore 實作相同的 SessionTokenStore interface,讓 main.go 在 // dbPool != nil 時無痛切換、呼叫端(internal/api/pairing.go 的 Create / Revoke)一行都不需改。 // // 對齊: // - database.md §2.4(SessionToken struct)、§4(session_tokens 表 schema) // - migrations/0003_create_token_tables.up.sql(session_tokens 表) // // ── 關鍵改動:plaintext → token_hash 當 PK(同 PostgresPairingStore)── // // Get / Revoke 接收 plaintext,內部先 HashToken(plaintext) 再以 hash 查詢。 // 呼叫端統一傳 plaintext(已 grep 確認:pairing.go 的 SessionTokenStore.Create 用回傳的 // plaintext、Revoke(ctx, plaintext) 傳 plaintext;目前無其他 production Get 呼叫端)。 // DB 永不存明文 token(security.md §1.3)。 // // 語意對齊 in-memory(見 session_token.go): // - SessionToken 無 used_at(非一次性)、無 kind。 // - Get 狀態優先序:revoked → expired(與 in-memory 一致);不存在回 ErrInvalidToken。 // - Revoke 冪等:未撤銷 → 寫 revoked_at;已撤銷 → no-op nil;不存在 → ErrInvalidToken。 // - CleanupExpired:DELETE 所有 expires_at < now 的列,回刪除數。 // - parent_token_hash 為稽核鏈欄位(升級來源 pairing token 的 hash),原樣存取。 package auth import ( "context" "errors" "fmt" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "visiona-backend/internal/db" ) // PostgresSessionTokenStore 是 session token 的 PostgreSQL 持久層實作。 type PostgresSessionTokenStore struct { pool *pgxpool.Pool } // NewPostgresSessionTokenStore 建立一個以 pgxpool 為後端的 SessionTokenStore。 func NewPostgresSessionTokenStore(pool *pgxpool.Pool) *PostgresSessionTokenStore { return &PostgresSessionTokenStore{pool: pool} } // 編譯時檢查:確保 PostgresSessionTokenStore 實作 SessionTokenStore。 var _ SessionTokenStore = (*PostgresSessionTokenStore)(nil) // sessionColumns 是 SELECT 共用欄位清單(順序必須與 scanSessionToken 對齊)。 const sessionColumns = `token_hash, user_id, device_id, parent_token_hash, created_at, expires_at, revoked_at` // Create 產生並保存一個新 session token。 // // ttl <= 0 時 ExpiresAt 保持 NULL(永不過期)。parentTokenHash 可為空(雛形 caller)。 // 回傳的 info.Plaintext 保留原文供 caller 一次性使用(DB 不存)。 func (s *PostgresSessionTokenStore) Create( ctx context.Context, userID, deviceID, parentTokenHash string, ttl time.Duration, ) (string, *SessionToken, error) { plaintext, err := GenerateSessionToken() if err != nil { return "", nil, err } now := time.Now().UTC() info := &SessionToken{ Plaintext: plaintext, TokenHash: HashToken(plaintext), UserID: userID, DeviceID: deviceID, ParentTokenHash: parentTokenHash, CreatedAt: now, } var expiresAt any // nil → DB NULL if ttl > 0 { expires := now.Add(ttl) info.ExpiresAt = &expires expiresAt = expires } // device_id NOT NULL:空字串無法寫進 UUID 欄位,會在此回 DB error(符合 schema 約束)。 // parent_token_hash nullable:空字串轉 NULL。 var parentArg any if parentTokenHash != "" { parentArg = parentTokenHash } const q = `INSERT INTO session_tokens (token_hash, user_id, device_id, parent_token_hash, created_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6)` if _, err := s.pool.Exec(ctx, q, info.TokenHash, info.UserID, info.DeviceID, parentArg, info.CreatedAt, expiresAt, ); err != nil { return "", nil, fmt.Errorf("auth: pg session Create: %w", err) } return plaintext, info, nil } // Get 依 plaintext 查詢 session token;內部 HashToken 後查。 // // 狀態優先序對齊 in-memory:revoked → expired。不存在回 ErrInvalidToken。 func (s *PostgresSessionTokenStore) Get(ctx context.Context, plaintext string) (*SessionToken, error) { hash := HashToken(plaintext) const q = `SELECT ` + sessionColumns + ` FROM session_tokens WHERE token_hash = $1` row := s.pool.QueryRow(ctx, q, hash) info, err := scanSessionToken(row) if errors.Is(err, pgx.ErrNoRows) { return nil, ErrInvalidToken } if err != nil { return nil, fmt.Errorf("auth: pg session Get: %w", err) } if info.RevokedAt != nil { return nil, ErrTokenRevoked } if info.ExpiresAt != nil && time.Now().UTC().After(*info.ExpiresAt) { return nil, ErrTokenExpired } return info, nil } // Revoke 撤銷 session token(之後 Get 回 ErrTokenRevoked)。 // // 冪等:已撤銷 → no-op nil;不存在 → ErrInvalidToken。 func (s *PostgresSessionTokenStore) Revoke(ctx context.Context, plaintext string) error { hash := HashToken(plaintext) const q = `UPDATE session_tokens SET revoked_at = now() WHERE token_hash = $1 AND revoked_at IS NULL` tag, err := s.pool.Exec(ctx, q, hash) if err != nil { return fmt.Errorf("auth: pg session Revoke: %w", err) } if tag.RowsAffected() == 1 { return nil } var exists bool if err := s.pool.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM session_tokens WHERE token_hash = $1)`, hash, ).Scan(&exists); err != nil { return fmt.Errorf("auth: pg session Revoke exists check: %w", err) } if !exists { return ErrInvalidToken } return nil // 已撤銷 → 冪等 no-op } // CleanupExpired 移除所有已過 expires_at 的 token;回傳移除數量。 // // expires_at IS NULL(永不過期)不會被刪。 func (s *PostgresSessionTokenStore) CleanupExpired(ctx context.Context, now time.Time) (int, error) { const q = `DELETE FROM session_tokens WHERE expires_at IS NOT NULL AND expires_at < $1` tag, err := s.pool.Exec(ctx, q, now.UTC()) if err != nil { return 0, fmt.Errorf("auth: pg session CleanupExpired: %w", err) } return int(tag.RowsAffected()), nil } // RevokeByDeviceTx 撤銷某 device 名下所有「尚未撤銷」的 session token(cascade 撤銷,塊 5.2)。 // // 在傳入的 Querier(pool 或 tx)上跑 `UPDATE ... SET revoked_at = now() WHERE device_id = $1 // AND revoked_at IS NULL`(database.md §6)。回傳實際撤銷的列數(觀測用,無對象回 0、不報錯)。 // // session_tokens.device_id 為 NOT NULL(每個 session token 必綁 device),故同一 device 的所有 // 未撤銷 session token 都會被撤——這正是「刪 device → 該 device 不能再被任何長效 token 觸達」的目的。 func (s *PostgresSessionTokenStore) RevokeByDeviceTx(ctx context.Context, q db.Querier, deviceID string) (int, error) { if deviceID == "" { return 0, nil } const sql = `UPDATE session_tokens SET revoked_at = now() WHERE device_id = $1 AND revoked_at IS NULL` tag, err := q.Exec(ctx, sql, deviceID) if err != nil { return 0, fmt.Errorf("auth: pg session RevokeByDevice: %w", err) } return int(tag.RowsAffected()), nil } // scanSessionToken 從一列掃出 *SessionToken。欄位順序必須與 sessionColumns 對齊。 // // parent_token_hash nullable → 以 *string 接、NULL 掃成空字串(對齊 in-memory zero value)。 // 時間欄位正規化為 UTC。Plaintext 留空(DB 不存)。 func scanSessionToken(row rowScanner) (*SessionToken, error) { var ( t SessionToken parentHash *string ) err := row.Scan( &t.TokenHash, &t.UserID, &t.DeviceID, &parentHash, &t.CreatedAt, &t.ExpiresAt, &t.RevokedAt, ) if err != nil { return nil, err } if parentHash != nil { t.ParentTokenHash = *parentHash } t.CreatedAt = t.CreatedAt.UTC() t.ExpiresAt = utcPtr(t.ExpiresAt) t.RevokedAt = utcPtr(t.RevokedAt) return &t, nil }