feat(visionA-backend): Phase 0 → 0.7 雲端後端(雙 binary + OIDC BFF + stage 部署)
從 edge-ai-platform POC 轉為正式產品的雲端後端,含以下整合階段:
- Phase 0:雛形骨架 — `cmd/api-server` (REST :3721) + `cmd/remote-proxy`
(tunnel :3800 / internal :3801) 雙 binary 共用 internal/,沿用 POC 的
WebSocket+yamux tunnel 協定但解耦 relay 與 API
- Phase 0.6:OIDC BFF 接 Innovedus Member Center
- internal/oidc package(coreos/go-oidc + PKCE S256 + state + nonce)
- internal/usersession package(HMAC-SHA256 cookie + RotateSessionID
防 session fixation, OWASP ASVS V3.2.1)
- 4 個 OIDC handler(/api/auth/login|callback|me|logout)+ AuthMiddleware
- 完全拔除 StaticAuthProvider,OIDC 是唯一認證路徑
- 9 個 ADR(含 ADR-010 BFF / ADR-011 取代 static auth /
ADR-012 pending session shared cookie / ADR-013 PKCE-only public client)
- Phase 0.7:A1 改造 + security audit 修復
- OIDC ClientSecret 變選填,支援 stage MC 的 public PKCE-only client
(AuthStyleInParams 強制 token endpoint 不送 client_secret)
- 預留 ServiceClient* 欄位給未來 client_credentials grant
- 移除 13+ 處 resolveUserID(uc, StaticUserID) fallback 改 strict mode
(Audit C1:multi-tenant 隔離破口)
- Pairing exchange MarkUsed 失敗 abort + revoke session token(Audit M3)
- 新增 all_endpoints_require_auth_test 整合測試(51 endpoint × 401)
驗證:go test -race -count=3 ./... 17 packages 全綠 / go vet 0 warning
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
b71ff4cd3c
commit
22f0837ba8
157
visionA-backend/.env.example
Normal file
157
visionA-backend/.env.example
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
# visionA-backend 環境變數範本
|
||||||
|
#
|
||||||
|
# 使用方式:
|
||||||
|
# cp .env.example .env
|
||||||
|
# # 視情況修改 .env 內的值(尤其 VISIONA_STORAGE_SIGNING_SECRET 與 VISIONA_PAIRING_TOKEN)
|
||||||
|
#
|
||||||
|
# ⚠️ 不要把 .env commit 進 git(已在 .gitignore 中排除)
|
||||||
|
# 相關文件:
|
||||||
|
# - .autoflow/04-architecture/build-deploy.md §9(變數對照表)
|
||||||
|
# - internal/config/config.go(每個欄位的定義)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 共用
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 日誌等級:debug / info / warn / error
|
||||||
|
VISIONA_LOG_LEVEL=info
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# api-server
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 對前端的 REST / WebSocket port(對齊 local-tool 的 base URL 預設)
|
||||||
|
VISIONA_API_PORT=3721
|
||||||
|
|
||||||
|
# api-server 連 remote-proxy 的 internal HTTP base URL
|
||||||
|
# 本機 go run 時用 localhost;docker-compose 內部會被 compose 覆寫為 http://remote-proxy:3801
|
||||||
|
VISIONA_PROXY_INTERNAL_URL=http://localhost:3801
|
||||||
|
|
||||||
|
# Static user — Phase 0.7 security audit 後僅供 dev seed(VISIONA_SEED_DEMO_DATA=true)
|
||||||
|
# 與 unit test fixture 用;不再注入 api.Deps、stage/prod 留空無影響。
|
||||||
|
# 見 .autoflow/05-implementation/review/phase-0.7-security-audit.md C1。
|
||||||
|
VISIONA_STATIC_USER_ID=demo-user
|
||||||
|
|
||||||
|
# 啟動時 seed 示範資料(device + model + pairing token),方便前端 demo
|
||||||
|
VISIONA_SEED_DEMO_DATA=true
|
||||||
|
|
||||||
|
# CORS 白名單(逗號分隔)— 預設允許 frontend dev server(http://localhost:3000)
|
||||||
|
VISIONA_CORS_ALLOWED_ORIGINS=http://localhost:3000
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# OIDC(必填 — OB5 起 OIDC 是唯一認證路徑;A1 起支援 public PKCE-only client)
|
||||||
|
# ============================================================
|
||||||
|
# 必填欄位缺任何一項,main.go 啟動時會 fatal log 退出。
|
||||||
|
#
|
||||||
|
# 對應 Innovedus Member Center 的 OIDC client 設定:
|
||||||
|
# - 在 Member Center 註冊一個 OAuth client(confidential 或 public 皆可)
|
||||||
|
# - 取得 client_id(public client 沒有 client_secret)
|
||||||
|
# - 將 RedirectURL 加入 Member Center 的白名單
|
||||||
|
|
||||||
|
# Member Center 的 issuer(不帶結尾斜線;MC 的 issuer 末尾斜線必要時請保留)
|
||||||
|
# dev: http://localhost:5050
|
||||||
|
# stage: https://stage-9527.innovedus.com:7850/
|
||||||
|
# prod: https://members.innovedus.com
|
||||||
|
VISIONA_OIDC_ISSUER_URL=http://localhost:5050
|
||||||
|
|
||||||
|
# 在 Member Center 註冊的 OAuth client_id
|
||||||
|
VISIONA_OIDC_CLIENT_ID=visiona-cloud
|
||||||
|
|
||||||
|
# Client secret(A1:選填 — public PKCE-only client 留空)
|
||||||
|
# - 有值 → confidential client mode(client_secret + PKCE 雙保險)
|
||||||
|
# - 留空 → public PKCE-only client mode(依靠 PKCE 防 code interception)
|
||||||
|
# ⚠️ 不可 commit;prod 用 Secrets Manager。Stage MC 配的 client `b8093fea...` 是 public,留空。
|
||||||
|
VISIONA_OIDC_CLIENT_SECRET=
|
||||||
|
|
||||||
|
# Backend callback URL — 必須與 Member Center 註冊值完全一致
|
||||||
|
# dev: http://localhost:3721/api/auth/callback
|
||||||
|
# stage: https://stage-9527.innovedus.com:9527/api/auth/callback
|
||||||
|
# prod: https://api.visiona.cloud/api/auth/callback
|
||||||
|
VISIONA_OIDC_REDIRECT_URL=http://localhost:3721/api/auth/callback
|
||||||
|
|
||||||
|
# Frontend base URL — callback 完成後 302 redirect 的目的地
|
||||||
|
# dev: http://localhost:3000
|
||||||
|
# stage: https://stage-9527.innovedus.com:9527
|
||||||
|
# prod: https://app.visiona.cloud
|
||||||
|
VISIONA_FRONTEND_URL=http://localhost:3000
|
||||||
|
|
||||||
|
# Service client(client_credentials grant)— A1 預留欄位,**目前不啟用**。
|
||||||
|
# 將來 visionA-backend 需以服務身份呼叫 MC API 時(例如查詢使用者組織、推送通知)
|
||||||
|
# 才會接這條路。留空代表「不啟用」,main.go 不會 wire。
|
||||||
|
# 對應 Stage 的 service client:<see stage .env.stage>
|
||||||
|
VISIONA_OIDC_SERVICE_CLIENT_ID=
|
||||||
|
VISIONA_OIDC_SERVICE_CLIENT_SECRET=
|
||||||
|
|
||||||
|
# Cookie HMAC 簽章 secret(⚠️ 至少 32 byte 隨機字串;prod 用 openssl rand -hex 32)
|
||||||
|
VISIONA_SESSION_SECRET=CHANGE_ME_TO_RANDOM_64_BYTES_in_production
|
||||||
|
|
||||||
|
# Cookie 設定(dev 預設 host-only / non-secure;prod 改 .visiona.cloud + Secure=true)
|
||||||
|
VISIONA_SESSION_COOKIE_NAME=visiona_session
|
||||||
|
VISIONA_SESSION_COOKIE_DOMAIN=
|
||||||
|
VISIONA_SESSION_COOKIE_SECURE=false
|
||||||
|
|
||||||
|
# Session TTL — 預設 7 天 absolute / 24h idle
|
||||||
|
VISIONA_SESSION_ABSOLUTE_TTL=168h
|
||||||
|
VISIONA_SESSION_IDLE_TTL=24h
|
||||||
|
|
||||||
|
# Relay 對外可達 URL(agent tunnel 用)— POST /api/pairing/exchange 會回給 agent。
|
||||||
|
# 雛形為空時會 fallback 到 wss://relay.visionA.cloud(placeholder)。
|
||||||
|
# 實機請設為實際可達的 WSS URL,例:wss://relay.visionA.cloud
|
||||||
|
VISIONA_RELAY_PUBLIC_URL=
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# remote-proxy
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 對 local agent 的 WebSocket tunnel port
|
||||||
|
VISIONA_TUNNEL_PORT=3800
|
||||||
|
|
||||||
|
# 對 api-server 的 internal HTTP port(不對外暴露)
|
||||||
|
VISIONA_PROXY_INTERNAL_PORT=3801
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Tunnel 心跳 / 掉線判定(對齊 tunnel.md §4.2)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
VISIONA_TUNNEL_HEARTBEAT_INTERVAL=10s
|
||||||
|
VISIONA_TUNNEL_IDLE_TIMEOUT=30s
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Storage(LocalFS — Phase 0 雛形;Phase 1 會改 S3)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 儲存根目錄(容器內;docker-compose 已 mount 成 volume)
|
||||||
|
VISIONA_STORAGE_BACKEND=localfs
|
||||||
|
VISIONA_STORAGE_LOCALFS_ROOT=./data/storage
|
||||||
|
|
||||||
|
# 瀏覽器 / 上傳 client 看到的 presigned URL base
|
||||||
|
# 本機開發:http://localhost:3721/storage
|
||||||
|
# docker-compose demo:同上(透過 host port mapping)
|
||||||
|
VISIONA_STORAGE_BASE_URL=http://localhost:3721/storage
|
||||||
|
VISIONA_STORAGE_LOCALFS_BASE_URL=http://localhost:3721/storage
|
||||||
|
|
||||||
|
# HMAC 簽章 secret — 用於 LocalFS presigned URL 與(Phase 1)pairing token hash
|
||||||
|
# ⚠️ 生產環境必改(openssl rand -hex 32 產生 64 字元 hex)
|
||||||
|
VISIONA_STORAGE_SIGNING_SECRET=CHANGE_ME_IN_PRODUCTION_use_openssl_rand_hex_32
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Model 上傳限制
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 單檔上限(MB)— Phase 0 規範 100 MB(PRD §8.4)
|
||||||
|
VISIONA_MODEL_MAX_SIZE_MB=100
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Pairing(local agent ↔ remote-proxy 配對)
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# 格式:vAc_ + 32 hex(見 security.md §1.3)
|
||||||
|
# 建議用:vAc_$(openssl rand -hex 16)
|
||||||
|
# 留空代表雛形 InMemoryPairingStore 會動態配發(前端呼叫 POST /api/pairing/token)
|
||||||
|
VISIONA_PAIRING_TOKEN=
|
||||||
43
visionA-backend/.gitignore
vendored
Normal file
43
visionA-backend/.gitignore
vendored
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# ---- Go 標準 -------------------------------------------------------------
|
||||||
|
# Binaries
|
||||||
|
*.exe
|
||||||
|
*.exe~
|
||||||
|
*.dll
|
||||||
|
*.so
|
||||||
|
*.dylib
|
||||||
|
|
||||||
|
# Test binary / coverage
|
||||||
|
*.test
|
||||||
|
*.out
|
||||||
|
coverage.txt
|
||||||
|
coverage.html
|
||||||
|
|
||||||
|
# Go workspace(本專案不使用 multi-module workspace)
|
||||||
|
go.work
|
||||||
|
go.work.sum
|
||||||
|
|
||||||
|
# ---- Build 產物 ----------------------------------------------------------
|
||||||
|
bin/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
|
||||||
|
# ---- 環境變數 / 密鑰 -----------------------------------------------------
|
||||||
|
.env
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# ---- IDE / Editor -------------------------------------------------------
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# ---- OS ------------------------------------------------------------------
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# ---- 本機開發資料 --------------------------------------------------------
|
||||||
|
# 雛形 LocalFS storage backend 的預設根目錄
|
||||||
|
data/
|
||||||
|
tmp/
|
||||||
131
visionA-backend/Makefile
Normal file
131
visionA-backend/Makefile
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
# visionA-backend Makefile
|
||||||
|
#
|
||||||
|
# 雙 binary 專案:api-server(REST/WS)+ remote-proxy(tunnel server)
|
||||||
|
# 對應文件:
|
||||||
|
# - .autoflow/04-architecture/build-deploy.md §1
|
||||||
|
# - 每個 target 都有 help 註解,`make help` 可看到清單
|
||||||
|
#
|
||||||
|
# 常用:
|
||||||
|
# make dev # 本機開發:平行跑兩個 binary
|
||||||
|
# make test # go test -race ./...
|
||||||
|
# make docker-build # 建 api-server + remote-proxy images
|
||||||
|
# make docker-compose-up # 啟動 docker-compose(api + proxy)
|
||||||
|
|
||||||
|
# ---- 變數 ----------------------------------------------------------------
|
||||||
|
BIN_DIR := bin
|
||||||
|
API_BIN := $(BIN_DIR)/api-server
|
||||||
|
PROXY_BIN := $(BIN_DIR)/remote-proxy
|
||||||
|
GO ?= go
|
||||||
|
GOFLAGS ?=
|
||||||
|
DOCKER ?= docker
|
||||||
|
COMPOSE_FILE := docker/docker-compose.yml
|
||||||
|
|
||||||
|
# VERSION 用於 docker image tag;預設 dev,或由 git describe 推斷
|
||||||
|
VERSION ?= $(shell git describe --tags --always --dirty 2>/dev/null || echo "dev")
|
||||||
|
|
||||||
|
# ---- 預設 target ---------------------------------------------------------
|
||||||
|
.DEFAULT_GOAL := help
|
||||||
|
|
||||||
|
# ---- Build ---------------------------------------------------------------
|
||||||
|
.PHONY: build build-api build-proxy
|
||||||
|
|
||||||
|
build: build-api build-proxy ## 建置所有 binary(api-server + remote-proxy)
|
||||||
|
|
||||||
|
build-api: ## 建置 api-server
|
||||||
|
@mkdir -p $(BIN_DIR)
|
||||||
|
$(GO) build $(GOFLAGS) -o $(API_BIN) ./cmd/api-server
|
||||||
|
|
||||||
|
build-proxy: ## 建置 remote-proxy
|
||||||
|
@mkdir -p $(BIN_DIR)
|
||||||
|
$(GO) build $(GOFLAGS) -o $(PROXY_BIN) ./cmd/remote-proxy
|
||||||
|
|
||||||
|
# ---- Run -----------------------------------------------------------------
|
||||||
|
.PHONY: run-api run-proxy dev
|
||||||
|
|
||||||
|
run-api: ## 執行 api-server(本機開發)
|
||||||
|
$(GO) run ./cmd/api-server
|
||||||
|
|
||||||
|
run-proxy: ## 執行 remote-proxy(本機開發)
|
||||||
|
$(GO) run ./cmd/remote-proxy
|
||||||
|
|
||||||
|
# dev:純便利 target(非交付物,見 design-doc.md §1.9 N10)。
|
||||||
|
# 平行跑 remote-proxy(先起,因為 api-server 開機時會去打它)+ api-server。
|
||||||
|
# 任一 process 結束時,trap 會把另一個一起殺掉,避免殘留 zombie。
|
||||||
|
dev: build ## 本機開發:平行跑 remote-proxy + api-server(非交付物)
|
||||||
|
@echo "啟動 remote-proxy + api-server(Ctrl+C 結束)..."
|
||||||
|
@trap 'echo; echo "shutting down..."; kill 0' INT TERM EXIT; \
|
||||||
|
$(PROXY_BIN) & \
|
||||||
|
sleep 1; \
|
||||||
|
$(API_BIN) & \
|
||||||
|
wait
|
||||||
|
|
||||||
|
# ---- Test / Lint ---------------------------------------------------------
|
||||||
|
.PHONY: test test-race fmt vet lint
|
||||||
|
|
||||||
|
test: ## 執行單元測試(詳細輸出)
|
||||||
|
$(GO) test ./... -v
|
||||||
|
|
||||||
|
test-race: ## 執行單元 + 整合測試(race detector + coverage)
|
||||||
|
$(GO) test -race -coverprofile=coverage.out ./...
|
||||||
|
|
||||||
|
fmt: ## gofmt 格式化
|
||||||
|
$(GO) fmt ./...
|
||||||
|
|
||||||
|
vet: ## go vet 靜態分析
|
||||||
|
$(GO) vet ./...
|
||||||
|
|
||||||
|
lint: ## 靜態分析(優先 golangci-lint,若未安裝則 fallback 到 go vet)
|
||||||
|
@if command -v golangci-lint >/dev/null 2>&1; then \
|
||||||
|
golangci-lint run ./...; \
|
||||||
|
else \
|
||||||
|
echo "golangci-lint 未安裝,fallback 到 go vet"; \
|
||||||
|
$(GO) vet ./...; \
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ---- Docker --------------------------------------------------------------
|
||||||
|
.PHONY: docker-build docker-build-api docker-build-proxy \
|
||||||
|
docker-compose-up docker-compose-down docker-compose-logs docker-compose-ps
|
||||||
|
|
||||||
|
docker-build: docker-build-api docker-build-proxy ## 建置兩個 Docker images(api-server + remote-proxy)
|
||||||
|
|
||||||
|
docker-build-api: ## 建置 api-server image → visiona/api-server:$(VERSION)
|
||||||
|
$(DOCKER) build -f docker/Dockerfile.api-server \
|
||||||
|
-t visiona/api-server:$(VERSION) \
|
||||||
|
-t visiona/api-server:dev \
|
||||||
|
.
|
||||||
|
|
||||||
|
docker-build-proxy: ## 建置 remote-proxy image → visiona/remote-proxy:$(VERSION)
|
||||||
|
$(DOCKER) build -f docker/Dockerfile.remote-proxy \
|
||||||
|
-t visiona/remote-proxy:$(VERSION) \
|
||||||
|
-t visiona/remote-proxy:dev \
|
||||||
|
.
|
||||||
|
|
||||||
|
docker-compose-up: ## 啟動 docker-compose 環境(detach 模式)
|
||||||
|
@test -f .env || (echo "⚠️ 找不到 .env — 請先執行:cp .env.example .env" && exit 1)
|
||||||
|
$(DOCKER) compose -f $(COMPOSE_FILE) up -d --build
|
||||||
|
|
||||||
|
docker-compose-down: ## 停止並移除 docker-compose 容器
|
||||||
|
$(DOCKER) compose -f $(COMPOSE_FILE) down
|
||||||
|
|
||||||
|
docker-compose-logs: ## 跟蹤 docker-compose logs(Ctrl+C 離開)
|
||||||
|
$(DOCKER) compose -f $(COMPOSE_FILE) logs -f
|
||||||
|
|
||||||
|
docker-compose-ps: ## 顯示 docker-compose 服務狀態
|
||||||
|
$(DOCKER) compose -f $(COMPOSE_FILE) ps
|
||||||
|
|
||||||
|
# ---- Utility -------------------------------------------------------------
|
||||||
|
.PHONY: clean help tidy
|
||||||
|
|
||||||
|
tidy: ## 整理 go.mod / go.sum
|
||||||
|
$(GO) mod tidy
|
||||||
|
|
||||||
|
clean: ## 清除 build 產物
|
||||||
|
@rm -rf $(BIN_DIR) dist/ coverage.out
|
||||||
|
@echo "已清除 $(BIN_DIR)/ 與 dist/ 與 coverage.out"
|
||||||
|
|
||||||
|
help: ## 顯示本 Makefile 的所有 target
|
||||||
|
@awk 'BEGIN {FS = ":.*?## "; printf "\nvisionA-backend — Make targets\n\n"} \
|
||||||
|
/^[a-zA-Z_-]+:.*?## / { printf " \033[36m%-22s\033[0m %s\n", $$1, $$2 }' $(MAKEFILE_LIST)
|
||||||
|
@echo ""
|
||||||
|
@echo "常用:make dev / make test-race / make docker-compose-up"
|
||||||
|
@echo ""
|
||||||
393
visionA-backend/README.md
Normal file
393
visionA-backend/README.md
Normal file
@ -0,0 +1,393 @@
|
|||||||
|
# visionA-backend
|
||||||
|
|
||||||
|
> visionA Cloud 的後端服務。由 **`api-server`**(無狀態 REST/WS API)與 **`remote-proxy`**(有狀態 tunnel server)兩個 binary 組成。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## ⚠️ Phase 0 雛形警告
|
||||||
|
|
||||||
|
**這是雛形(prototype)版本,不是生產交付物。** 主要限制:
|
||||||
|
|
||||||
|
- 單一 user(永遠回 `demo-user`,無真正認證)
|
||||||
|
- 所有狀態 in-memory(重啟即消失,無 DB / Redis)
|
||||||
|
- Storage 走 LocalFS(無 S3)
|
||||||
|
- WebSocket proxy 尚未實作(所有 `/ws/*` 皆回 501)
|
||||||
|
- 單一 instance(無水平擴展)
|
||||||
|
|
||||||
|
完整限制見下方 [雛形範圍與限制](#雛形範圍與限制)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 架構總覽
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────┐
|
||||||
|
│ Browser / curl │
|
||||||
|
└──────────┬──────────┘
|
||||||
|
│ REST / WS (3721)
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────┐
|
||||||
|
│ api-server │
|
||||||
|
│ (cmd/api-server) │
|
||||||
|
│ │
|
||||||
|
│ - REST + WS handler │
|
||||||
|
│ - Auth middleware(static)│
|
||||||
|
│ - ProxyClientStore │
|
||||||
|
│ (查詢 session) │
|
||||||
|
│ - Forwarder │
|
||||||
|
│ (轉發 HTTP 到 tunnel) │
|
||||||
|
│ - LocalFS storage (/storage)│
|
||||||
|
│ 無狀態,可水平擴展 │
|
||||||
|
└──────────┬───────────────────┘
|
||||||
|
│ internal HTTP (3801)
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────┐
|
||||||
|
│ remote-proxy │
|
||||||
|
│ (cmd/remote-proxy) │
|
||||||
|
│ │
|
||||||
|
│ - Tunnel server (ws://3800) │
|
||||||
|
│ - yamux session store │
|
||||||
|
│ - /internal/forward/raw │
|
||||||
|
│ - /internal/session/:token │
|
||||||
|
│ │
|
||||||
|
│ ⚠️ 有狀態(單 instance) │
|
||||||
|
└──────────┬───────────────────┘
|
||||||
|
│ WebSocket + yamux (3800)
|
||||||
|
▼
|
||||||
|
┌──────────────────────────────┐
|
||||||
|
│ Local Agent │
|
||||||
|
│ (在客戶端機器上跑) │
|
||||||
|
│ │
|
||||||
|
│ 目前 demo 用 POC 的 │
|
||||||
|
│ edge-ai-server 當 client │
|
||||||
|
└──────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
詳細設計見:
|
||||||
|
- [`.autoflow/04-architecture/design-doc.md`](../.autoflow/04-architecture/design-doc.md)(§7 部署)
|
||||||
|
- [`.autoflow/04-architecture/TDD.md`](../.autoflow/04-architecture/TDD.md)
|
||||||
|
- [`.autoflow/04-architecture/api/api-spec.md`](../.autoflow/04-architecture/api/api-spec.md)(前端 REST API)
|
||||||
|
- [`.autoflow/04-architecture/api/api-internal.md`](../.autoflow/04-architecture/api/api-internal.md)(api-server ↔ remote-proxy)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 技術堆疊
|
||||||
|
|
||||||
|
| 層級 | 技術 | 備註 |
|
||||||
|
|------|------|------|
|
||||||
|
| 語言 | Go 1.26 | `go.mod` 鎖定 |
|
||||||
|
| HTTP framework | [Gin](https://github.com/gin-gonic/gin) + `gin-contrib/cors` | B4 導入 |
|
||||||
|
| Tunnel 傳輸 | `gorilla/websocket` + `hashicorp/yamux` | 沿用 POC `edge-ai-platform` |
|
||||||
|
| Logging | `log/slog`(stdlib) | JSON handler,結構化輸出 |
|
||||||
|
| ID 生成 | `google/uuid` | request-id / demo seed |
|
||||||
|
| 單元測試 | `stretchr/testify` | B2 導入 |
|
||||||
|
| 配置 | 12-Factor App | 全走 env,不寫死 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 快速啟動(10 分鐘起步)
|
||||||
|
|
||||||
|
### 前置
|
||||||
|
|
||||||
|
- Go 1.26+(本機 run)或 Docker 27+(容器 run)
|
||||||
|
- macOS / Linux(Windows 未測試)
|
||||||
|
|
||||||
|
### 方式 A:本機 `go run`(最快,適合開發)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd visionA-backend
|
||||||
|
|
||||||
|
# 1) 一鍵跑 remote-proxy + api-server(任一 Ctrl+C 兩個都會停)
|
||||||
|
make dev
|
||||||
|
|
||||||
|
# 另開 terminal 驗證:
|
||||||
|
curl http://localhost:3721/healthz
|
||||||
|
# {"status":"ok"}
|
||||||
|
|
||||||
|
curl -X POST http://localhost:3721/api/auth/login \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"email":"demo@visionA.local","password":"any"}'
|
||||||
|
# {"success":true,"data":{"user":{"id":"demo-user",...},"access_token":"demo-access-token",...}}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 方式 B:Docker Compose(接近生產拓撲)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd visionA-backend
|
||||||
|
|
||||||
|
# 1) 複製環境變數範本
|
||||||
|
cp .env.example .env
|
||||||
|
# (視情況編輯 .env — 通常預設就能跑)
|
||||||
|
|
||||||
|
# 2) 建 image + 啟動
|
||||||
|
make docker-compose-up
|
||||||
|
|
||||||
|
# 3) 驗證
|
||||||
|
curl http://localhost:3721/healthz
|
||||||
|
curl -X POST http://localhost:3721/api/auth/login \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"email":"demo@visionA.local","password":"any"}'
|
||||||
|
|
||||||
|
# 4) 跟 logs
|
||||||
|
make docker-compose-logs
|
||||||
|
|
||||||
|
# 5) 停
|
||||||
|
make docker-compose-down
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ports:**
|
||||||
|
- `3721`:api-server REST + WS(對前端)
|
||||||
|
- `3800`:remote-proxy tunnel WS(對 local agent)
|
||||||
|
- `3801`:remote-proxy internal HTTP(compose 內部,不對外)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 如何用 POC `edge-ai-server` 驗證 tunnel
|
||||||
|
|
||||||
|
雛形不包含 local agent(Q3 決策:local agent 模組 Phase 1 才做)。要驗證 tunnel 整條鏈路,用 POC `edge-ai-platform/edge-ai-server` 當 tunnel client。
|
||||||
|
|
||||||
|
### 步驟
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Terminal 1 — 起 visionA-backend
|
||||||
|
cd visionA-backend
|
||||||
|
make dev
|
||||||
|
# 或 docker compose up
|
||||||
|
|
||||||
|
# Terminal 2 — 申請 pairing token(雛形:POST 就會配發一個)
|
||||||
|
curl -X POST http://localhost:3721/api/pairing/token \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"name":"demo"}'
|
||||||
|
# { "success": true, "data": { "token": "vAc_...", "expires_at": "..." } }
|
||||||
|
|
||||||
|
# Terminal 3 — POC edge-ai-server 當 tunnel client 接上來
|
||||||
|
cd /path/to/edge-ai-platform
|
||||||
|
./dist/edge-ai-server \
|
||||||
|
--relay-url=ws://localhost:3800/tunnel/connect \
|
||||||
|
--relay-token=vAc_<貼上上一步的 token>
|
||||||
|
|
||||||
|
# Terminal 2 — 驗證 tunnel 已連上
|
||||||
|
curl -H "Authorization: Bearer demo-access-token" \
|
||||||
|
http://localhost:3721/api/pairing/status
|
||||||
|
# { "success": true, "data": { "connected": true, ... } }
|
||||||
|
|
||||||
|
# 打到 local agent(會被 forward 過 tunnel)
|
||||||
|
curl -H "Authorization: Bearer demo-access-token" \
|
||||||
|
http://localhost:3721/api/devices/scan
|
||||||
|
```
|
||||||
|
|
||||||
|
### 或:用 B3 的 fake tunnel client 寫小 demo
|
||||||
|
|
||||||
|
若不想起 POC,`cmd/api-server/b5_integration_test.go` 裡的 `startFakeTunnelClient` 是現成的 60 行範例,直接複製到 `cmd/tunnel-demo/main.go` 就能跑。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 目錄結構
|
||||||
|
|
||||||
|
```
|
||||||
|
visionA-backend/
|
||||||
|
├── cmd/
|
||||||
|
│ ├── api-server/ # REST/WS API server(無狀態)
|
||||||
|
│ │ ├── main.go
|
||||||
|
│ │ ├── seed.go # --seed-demo-data 用的示範資料
|
||||||
|
│ │ ├── integration_test.go # B4 端到端測試
|
||||||
|
│ │ └── b5_integration_test.go # B5 端到端測試(含 tunnel forward + model upload)
|
||||||
|
│ └── remote-proxy/ # tunnel server(有狀態,持有 session in-memory)
|
||||||
|
│ └── main.go
|
||||||
|
├── internal/
|
||||||
|
│ ├── api/ # API handlers + Gin router + middleware(B4 + B5)
|
||||||
|
│ ├── auth/ # AuthService / AuthProvider / PairingStore(雛形 Static + InMemory)
|
||||||
|
│ ├── session/ # Store / Handle / Forwarder / ProxyClient
|
||||||
|
│ ├── device/ # Device domain + InMemoryRepository
|
||||||
|
│ ├── model/ # Model domain + InMemoryRepository
|
||||||
|
│ ├── cluster/ # Cluster domain(POC 複製,dispatcher 留 TODO)
|
||||||
|
│ ├── relay/ # tunnel server + internal forward API(POC 改造)
|
||||||
|
│ ├── wsconn/ # WebSocket ↔ net.Conn adapter(POC 複製)
|
||||||
|
│ ├── converter/ # StubClient(Phase 2 才實作)
|
||||||
|
│ ├── storage/ # Store interface + LocalFSStore(HMAC presigned URL)
|
||||||
|
│ ├── config/ # Config + Load()(12-Factor)
|
||||||
|
│ └── logger/ # slog JSON logger wrapper
|
||||||
|
├── docker/
|
||||||
|
│ ├── Dockerfile.api-server # multi-stage,non-root,healthcheck
|
||||||
|
│ ├── Dockerfile.remote-proxy
|
||||||
|
│ └── docker-compose.yml
|
||||||
|
├── .env.example # 環境變數範本(commit)
|
||||||
|
├── .gitignore # 已排除 .env / bin/ / data/
|
||||||
|
├── Makefile # build / dev / test / docker-* 等 targets
|
||||||
|
├── go.mod / go.sum
|
||||||
|
└── README.md # 本檔
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API 端點摘要
|
||||||
|
|
||||||
|
完整規格見 [`.autoflow/04-architecture/api/api-spec.md`](../.autoflow/04-architecture/api/api-spec.md)。
|
||||||
|
|
||||||
|
| 群組 | 端點 | 說明 |
|
||||||
|
|------|------|------|
|
||||||
|
| System | `GET /healthz` | liveness/readiness(無需認證) |
|
||||||
|
| System | `GET /api/system/health` | tunnel / agent 連線狀態 |
|
||||||
|
| System | `GET /api/system/info` | 版本 + build 資訊 |
|
||||||
|
| Auth | `POST /api/auth/login` | 雛形永遠回 `demo-user` |
|
||||||
|
| Auth | `POST /api/auth/register` | 雛形 501 |
|
||||||
|
| Auth | `GET /api/auth/me` | 當前 user 資訊 |
|
||||||
|
| Pairing | `POST /api/pairing/token` | 申請 pairing token(vAc_ + 32 hex) |
|
||||||
|
| Pairing | `GET /api/pairing/status` | 目前 user 的 tunnel 狀態 |
|
||||||
|
| Pairing | `GET /api/pairing/tokens` | 列出已簽發的 token |
|
||||||
|
| Pairing | `DELETE /api/pairing/tokens/:id` | revoke token |
|
||||||
|
| Devices | `GET /api/devices` | 列出裝置 |
|
||||||
|
| Devices | `POST /api/devices/scan` | 觸發 local agent 掃 USB(透過 tunnel) |
|
||||||
|
| Devices | `DELETE /api/devices/:id` | unpair(同時關 tunnel) |
|
||||||
|
| Models | `POST /api/models/init` | 兩階段上傳 step 1(拿 presigned PUT URL) |
|
||||||
|
| Models | `PUT /storage/:signed` | 實際上傳檔案(HMAC 驗簽) |
|
||||||
|
| Models | `POST /api/models/:id/finalize` | 兩階段上傳 step 2(marked ready) |
|
||||||
|
| Models | `GET /api/models` | 列出模型 |
|
||||||
|
| Clusters | `GET /api/clusters` | 列出 cluster(骨架) |
|
||||||
|
| Storage | `GET /storage/:signed` | presigned download |
|
||||||
|
| Camera / Inference | `/api/cameras/*`、`/api/inference/*` | proxy 到 local agent |
|
||||||
|
| WebSocket | `/ws/*` | **雛形 501**(B7 之後補) |
|
||||||
|
|
||||||
|
### 錯誤格式(統一)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": false,
|
||||||
|
"error": {
|
||||||
|
"code": "TUNNEL_DISCONNECTED",
|
||||||
|
"message": "local agent 未連線或 tunnel 斷開",
|
||||||
|
"request_id": "req_abc123"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
錯誤碼清單見 [`internal/api/errors.go`](internal/api/errors.go)。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 環境變數
|
||||||
|
|
||||||
|
詳見 [`.env.example`](.env.example)。常用:
|
||||||
|
|
||||||
|
| 變數 | 預設 | 說明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `VISIONA_API_PORT` | `3721` | api-server listen port |
|
||||||
|
| `VISIONA_TUNNEL_PORT` | `3800` | remote-proxy 對 local agent 的 WS port |
|
||||||
|
| `VISIONA_PROXY_INTERNAL_PORT` | `3801` | remote-proxy 對 api-server 的內部 HTTP port |
|
||||||
|
| `VISIONA_PROXY_INTERNAL_URL` | `http://localhost:3801` | api-server 連 remote-proxy 用(docker compose 會覆為 `http://remote-proxy:3801`) |
|
||||||
|
| `VISIONA_SEED_DEMO_DATA` | `false` | 啟動時塞示範資料(device + model + pairing) |
|
||||||
|
| `VISIONA_STORAGE_SIGNING_SECRET` | `dev-signing-secret-...` | presigned URL HMAC secret(**生產必改**) |
|
||||||
|
| `VISIONA_STATIC_USER_ID` | `demo-user` | 雛形 static auth 的 user id |
|
||||||
|
| `VISIONA_MODEL_MAX_SIZE_MB` | `100` | 模型上傳大小上限 |
|
||||||
|
| `VISIONA_CORS_ALLOWED_ORIGINS` | `http://localhost:3000` | CORS 白名單(逗號分隔) |
|
||||||
|
| `VISIONA_LOG_LEVEL` | `info` | debug / info / warn / error |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 測試
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 所有單元測試 + integration test + race detector
|
||||||
|
make test-race
|
||||||
|
|
||||||
|
# 僅 go vet / gofmt check
|
||||||
|
make lint
|
||||||
|
|
||||||
|
# 詳細輸出
|
||||||
|
make test
|
||||||
|
```
|
||||||
|
|
||||||
|
覆蓋面:
|
||||||
|
- 單元測試:`internal/{auth,session,device,model,config,storage,api,relay,wsconn,logger,converter}`
|
||||||
|
- Integration:
|
||||||
|
- `cmd/api-server/integration_test.go`(B4:api-server → remote-proxy → fake tunnel)
|
||||||
|
- `cmd/api-server/b5_integration_test.go`(B5:完整端到端 — login / scan / model upload / tunnel disconnect)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 開發者指南
|
||||||
|
|
||||||
|
| 我想做的事 | 看哪裡 |
|
||||||
|
|-----------|--------|
|
||||||
|
| 新增一個 REST endpoint | `internal/api/` 找類似的 handler 複製;proxy 類用 `newProxyHandler` |
|
||||||
|
| 改動 API 規格 | 先改 `.autoflow/04-architecture/api/api-spec.md` 再改 code |
|
||||||
|
| 加環境變數 | `internal/config/config.go` + `load.go` + `.env.example` + 本 README |
|
||||||
|
| 改 tunnel 協定 | `.autoflow/04-architecture/tunnel.md` + `internal/relay/` + `internal/session/forwarder.go` |
|
||||||
|
| 追蹤某個 Review 問題 | `.autoflow/05-implementation/review/` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Known Issues
|
||||||
|
|
||||||
|
### Tunnel client 不在本 repo
|
||||||
|
|
||||||
|
visionA-backend 只實作 tunnel **server** 端(`internal/relay/` + `internal/wsconn/`);tunnel **client** 由 visionA Agent 實作(從 POC `edge-ai-platform/server/internal/tunnel/client.go` 複製)。本 repo 過去曾保留一份 `internal/tunnel/` 副本,但因從未被 import 且會造成「兩處需要同步修補」的維護負擔,已於 2026-04-21 刪除(決策見 [`.autoflow/04-architecture/adr/adr-008-tunnel-client-reuse.md`](../.autoflow/04-architecture/adr/adr-008-tunnel-client-reuse.md))。
|
||||||
|
|
||||||
|
POC `client.go` 的 `backoff()` 有單位 mix bug(`attempt >= 1` 時永遠回 30 秒),visionA Agent 建立後需在自己的 repo 修復;本 repo 不再追蹤此 issue。
|
||||||
|
|
||||||
|
### WebSocket proxy 未實作
|
||||||
|
|
||||||
|
所有 `/ws/*` endpoint 目前回 `501 Not Implemented`。原因:`Forwarder.ForwardWebSocket` 需要 `http.Hijacker` + 雙向 `io.Copy` 架構,B7 範圍外。
|
||||||
|
|
||||||
|
前端呼叫 `/ws/*` 時會收到 JSON body `{ "success": false, "error": { "code": "NOT_IMPLEMENTED", ... } }`,瀏覽器 WebSocket 會 fail upgrade。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 雛形範圍與限制
|
||||||
|
|
||||||
|
### 是什麼
|
||||||
|
|
||||||
|
- 雙 binary 架構驗證(api-server 無狀態 + remote-proxy in-memory session)
|
||||||
|
- REST + Tunnel 完整鏈路:browser → api-server → internal HTTP → remote-proxy → yamux → local agent
|
||||||
|
- 兩階段模型上傳(init → PUT presigned → finalize)
|
||||||
|
- Docker image + docker-compose 可跑
|
||||||
|
|
||||||
|
### 不是什麼
|
||||||
|
|
||||||
|
| 項目 | 雛形 | Phase 1+ |
|
||||||
|
|------|------|---------|
|
||||||
|
| Auth | static,一律 `demo-user` | OIDC / Clerk |
|
||||||
|
| 資料庫 | 無(全 in-memory) | PostgreSQL |
|
||||||
|
| Session 存放 | `remote-proxy` 進程內 | Redis(支援水平擴展) |
|
||||||
|
| 檔案儲存 | LocalFS(`./data/storage`) | S3 |
|
||||||
|
| 支援多 user | ❌ | ✅ |
|
||||||
|
| Rate limiting | ❌ | ✅ |
|
||||||
|
| Audit log | ❌ | ✅ |
|
||||||
|
| WebSocket proxy | 501 stub | ✅ |
|
||||||
|
| TLS | ❌(http only) | ✅(ALB / NLB termination) |
|
||||||
|
| 水平擴展 | ❌ | ✅(api-server 可;remote-proxy 需加 shared session store) |
|
||||||
|
|
||||||
|
### 重啟即消失
|
||||||
|
|
||||||
|
`make docker-compose-down` 或 restart 後:
|
||||||
|
- 所有 pairing token 消失
|
||||||
|
- 所有 device / model 紀錄消失(除非 `VISIONA_SEED_DEMO_DATA=true`)
|
||||||
|
- Storage 檔案保留(`./data/storage` 被 mount 為 volume)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phase 1 路線圖
|
||||||
|
|
||||||
|
- [ ] 真 auth(OIDC via Clerk / Auth0)
|
||||||
|
- [ ] PostgreSQL + Redis(參考 `.autoflow/04-architecture/database.md`)
|
||||||
|
- [ ] S3 / R2 storage backend(替換 LocalFSStore)
|
||||||
|
- [ ] WebSocket proxy(hijack + 雙向 io.Copy)
|
||||||
|
- [ ] 多裝置 / 多 cluster 支援
|
||||||
|
- [ ] Rate limiting + audit log
|
||||||
|
- [ ] K8s / ECS deployment(參考 `.autoflow/04-architecture/build-deploy.md` §7)
|
||||||
|
- [ ] CI/CD pipeline(GitHub Actions)
|
||||||
|
- [ ] Local agent 模組(獨立 binary,取代 POC edge-ai-server 當 tunnel client)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 目前實作進度
|
||||||
|
|
||||||
|
- [x] **B1** 專案初始化(go.mod、目錄骨架、Makefile、.gitignore)
|
||||||
|
- [x] **B2** 共用 `internal/` 模組(core interface + in-memory 實作 + 單元測試)
|
||||||
|
- [x] **B3** `cmd/remote-proxy` + relay / tunnel / wsconn / cluster
|
||||||
|
- [x] **B4** `cmd/api-server` + `internal/api` 骨架 + Forwarder + ProxyClient
|
||||||
|
- [x] **B5** API handlers 雛形(20+ endpoint 實作、兩階段上傳、tunnel forward)
|
||||||
|
- [x] **B6** Docker image + docker-compose(multi-stage + non-root + healthcheck)
|
||||||
|
- [x] **B7** README + `.env.example` + Makefile 補完
|
||||||
|
|
||||||
|
完整任務紀錄見 [`.autoflow/progress.md`](../.autoflow/progress.md)。
|
||||||
0
visionA-backend/cmd/api-server/.gitkeep
Normal file
0
visionA-backend/cmd/api-server/.gitkeep
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
// all_endpoints_require_auth_test.go — Phase 0.7 security regression test.
|
||||||
|
//
|
||||||
|
// 對齊 .autoflow/05-implementation/review/phase-0.7-security-audit.md s1。
|
||||||
|
//
|
||||||
|
// 目的:對所有 protected endpoint 發無 cookie request,必須回 401。
|
||||||
|
//
|
||||||
|
// 為什麼需要:
|
||||||
|
//
|
||||||
|
// Phase 0.7 audit 發現 visionA-backend 13+ 處 handler 用 resolveUserID 寬鬆 fallback
|
||||||
|
// 到 demo-user。即使 Backend 完成 Fix #1-#5(移除 fallback、改 strict mode)後,
|
||||||
|
// 未來任何一條漏套 AuthMiddleware 的新 endpoint 都會立刻打破 multi-tenant 隔離。
|
||||||
|
//
|
||||||
|
// 這條測試是 C1 fallback 的長期防呆 — 任何未來 endpoint 漏保護都會立刻被 fail:
|
||||||
|
//
|
||||||
|
// - 新 endpoint 註冊時忘了套 middleware
|
||||||
|
// - middleware 順序錯誤
|
||||||
|
// - router group 套錯(例如把保護的 path 加在 r 而非 apiGroup)
|
||||||
|
//
|
||||||
|
// 能力與限制:
|
||||||
|
//
|
||||||
|
// ✅ 能驗:每個 protected endpoint 都有套 AuthMiddleware(沒帶 cookie → 401)
|
||||||
|
// ✅ 能驗:未來新增 endpoint 漏套會立刻 fail
|
||||||
|
// ❌ 不能驗:middleware 內部邏輯(middleware_test.go 在做)
|
||||||
|
// ❌ 不能驗:跨用戶 authorization(user A 不能存取 B 的資源 — 另一條測試)
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// publicPaths 是「不需要認證」的明確 endpoint 清單(method + path 精確匹配)。
|
||||||
|
//
|
||||||
|
// 對齊 audit 報告 endpoint × middleware 對照表:這些 endpoint 故意設計成 public —
|
||||||
|
//
|
||||||
|
// - /healthz → K8s liveness/readiness 用,不能要 cookie
|
||||||
|
// - /api/auth/login → 起始 OIDC 登入流程,user 還沒登入
|
||||||
|
// - /api/auth/callback → OIDC IdP 302 回來,user 還沒登入
|
||||||
|
// - /api/pairing/exchange → agent 還沒 session token,用 pairing token 換
|
||||||
|
//
|
||||||
|
// 任何往這份清單裡新加 endpoint 的 PR 都該特別 review — 你正在繞過 OIDC 保護。
|
||||||
|
var publicPaths = map[string]bool{
|
||||||
|
"GET /healthz": true,
|
||||||
|
"GET /api/auth/login": true,
|
||||||
|
"GET /api/auth/callback": true,
|
||||||
|
"POST /api/pairing/exchange": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// publicPrefixes 是「整個 path prefix 都不走 OIDC AuthMiddleware」的清單。
|
||||||
|
//
|
||||||
|
// - /storage/* — 用 HMAC presigned URL 驗簽(api-spec.md §10),不是 cookie
|
||||||
|
// - /ws/* — 雛形 stub 一律 501,註冊在 r 而非 apiGroup(stubs.go:70-85)。
|
||||||
|
// 目前無認證 → 501;**未來補實作 WebSocket proxy 時必須套 auth**,
|
||||||
|
// 屆時應從這份清單移除。TODO(B7): 移到 protected。
|
||||||
|
var publicPrefixes = []string{
|
||||||
|
"/storage/",
|
||||||
|
"/ws/",
|
||||||
|
}
|
||||||
|
|
||||||
|
// pathParamReplacements 把 gin route 的 path param(:id / :token / *filepath)
|
||||||
|
// 換成具體值,讓 router 能 match 到實際 handler。
|
||||||
|
//
|
||||||
|
// 注意:這裡的具體值不需要是「資料庫真的存在的 ID」 — 我們只在乎 router 路由正確
|
||||||
|
// (middleware 是 router-level 的,沒到 handler 之前就會被 401 擋下)。
|
||||||
|
func replacePathParams(path string) string {
|
||||||
|
// 順序很重要:*filepath 用 catch-all,要先處理;其他 :param 用 simple replace。
|
||||||
|
if idx := strings.Index(path, "*"); idx >= 0 {
|
||||||
|
// 例:/storage/*filepath → /storage/anything
|
||||||
|
return path[:idx] + "anything"
|
||||||
|
}
|
||||||
|
|
||||||
|
// :param → "test-value"(任意非空字串就行)
|
||||||
|
parts := strings.Split(path, "/")
|
||||||
|
for i, p := range parts {
|
||||||
|
if strings.HasPrefix(p, ":") {
|
||||||
|
parts[i] = "test-value"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAllAPIEndpointsRequire401WithoutCookie 對所有 protected endpoint 發無 cookie
|
||||||
|
// request,必須回 401。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security regression test (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md s1)。
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// 1. 從 fixture 拿 *gin.Engine,列出所有註冊的 routes
|
||||||
|
// 2. 對每條 route:跳過 publicPaths / publicPrefixes 名單;其他全部要求 401
|
||||||
|
// 3. 發 request 時不帶 cookie、不帶 Authorization header
|
||||||
|
// 4. 驗 status code == 401 UNAUTHORIZED
|
||||||
|
//
|
||||||
|
// 失敗時的解讀:
|
||||||
|
// - 某個 protected endpoint 回 200/500/501 而非 401 → 該路徑沒套 AuthMiddleware;
|
||||||
|
// 檢查是不是註冊在 r 而非 apiGroup,或 middleware 順序錯誤。
|
||||||
|
// - 某個 endpoint 回 200 帶 demo-user 資料 → C1 fallback 還在沒移除(Backend Fix #1-#5
|
||||||
|
// 未完成)。應在 Backend 修完後再跑,或先標 t.Skip。
|
||||||
|
func TestAllAPIEndpointsRequire401WithoutCookie(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, f.router, "fixture.router 必須非 nil — 確認 setupFixture 有 set router")
|
||||||
|
|
||||||
|
routes := f.router.Routes()
|
||||||
|
require.NotEmpty(t, routes, "router 應至少註冊一條 route")
|
||||||
|
|
||||||
|
var (
|
||||||
|
coveredCount int // 真正測 401 的數量
|
||||||
|
skippedCount int // 跳過(public)的數量
|
||||||
|
)
|
||||||
|
|
||||||
|
for _, route := range routes {
|
||||||
|
key := route.Method + " " + route.Path
|
||||||
|
|
||||||
|
// 跳過明確 public 的 endpoint
|
||||||
|
if publicPaths[key] {
|
||||||
|
skippedCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
skipByPrefix := false
|
||||||
|
for _, p := range publicPrefixes {
|
||||||
|
if strings.HasPrefix(route.Path, p) {
|
||||||
|
skipByPrefix = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if skipByPrefix {
|
||||||
|
skippedCount++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
coveredCount++
|
||||||
|
|
||||||
|
// 用 t.Run 讓每個 endpoint 是獨立 subtest — 失敗時看得到具體哪條
|
||||||
|
t.Run(key, func(t *testing.T) {
|
||||||
|
actualPath := replacePathParams(route.Path)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(route.Method, actualPath, nil)
|
||||||
|
// **故意不**設 cookie、不設 Authorization header — 模擬完全沒認證
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
f.router.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// 必須是 401。不可以是:
|
||||||
|
// - 200 / 2xx → 通過 middleware,handler 拿 fallback userID 回了資料(C1 latent break)
|
||||||
|
// - 500 → middleware 沒套或內部 panic(更糟)
|
||||||
|
// - 501 → handler 是 stub 但 middleware 沒擋下(middleware 沒套)
|
||||||
|
// - 404 → router path mismatch(測試 setup bug)
|
||||||
|
// - 502 → proxy handler 沒被 middleware 擋下(middleware 沒套)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code,
|
||||||
|
"%s 應回 401(沒帶 cookie),實際 %d;body=%s\n"+
|
||||||
|
"可能原因:\n"+
|
||||||
|
" 1. 該路徑沒套 AuthMiddleware(檢查 NewRouter 是註冊在 r 還是 apiGroup)\n"+
|
||||||
|
" 2. middleware 註冊順序錯誤(AuthMiddleware 必須在 handler 之前)\n"+
|
||||||
|
" 3. C1 fallback 還在 — Backend Fix #1-#5 未完成(看 audit 報告)",
|
||||||
|
key, w.Code, w.Body.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 確保我們真的有測到 endpoint,不是測試 setup 出錯導致全部 skip
|
||||||
|
require.Greater(t, coveredCount, 10,
|
||||||
|
"預期至少 10+ 個 protected endpoint 被 cover;實際 covered=%d, skipped=%d。"+
|
||||||
|
"若異常偏低代表 fixture 或路由註冊出問題", coveredCount, skippedCount)
|
||||||
|
|
||||||
|
t.Logf("covered %d protected endpoints, skipped %d public endpoints",
|
||||||
|
coveredCount, skippedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPublicEndpointsListIsExhaustive 防呆:確認 publicPaths 與 publicPrefixes
|
||||||
|
// 真的對應到實際 router 上有的 endpoint,而不是過期清單。
|
||||||
|
//
|
||||||
|
// 為什麼需要:如果未來把某個 public endpoint 改名(例 /api/auth/login → /api/auth/oidc/login)
|
||||||
|
// 但忘了更新 publicPaths,主測試會把新的 path 當 protected 然後驗 401。雖然該驗
|
||||||
|
// 也是對的(新 path 就是 protected),但會讓人誤以為「主測試覆蓋的 public 已經包含 login」。
|
||||||
|
//
|
||||||
|
// 這條測試讓「publicPaths 列了卻沒對應實際路由的 entry」變成 fail。
|
||||||
|
func TestPublicEndpointsListMatchesRouter(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
require.NotNil(t, f.router)
|
||||||
|
routes := f.router.Routes()
|
||||||
|
|
||||||
|
registered := make(map[string]bool, len(routes))
|
||||||
|
for _, r := range routes {
|
||||||
|
registered[r.Method+" "+r.Path] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for key := range publicPaths {
|
||||||
|
assert.True(t, registered[key],
|
||||||
|
"publicPaths 列了 %q 但 router 沒有這條 route — 是否已重新命名?", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
364
visionA-backend/cmd/api-server/b5_integration_test.go
Normal file
364
visionA-backend/cmd/api-server/b5_integration_test.go
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
// b5_integration_test.go — B5 各 handler 的 end-to-end integration tests。
|
||||||
|
//
|
||||||
|
// 和 integration_test.go 使用同一個 testFixture / startFakeTunnelClient,只是
|
||||||
|
// 驗證的端點不同。這個檔案聚焦 B5 新增的 handler:
|
||||||
|
// - /api/auth/login + /api/auth/me(OIDC 流程跑通)
|
||||||
|
// - /api/pairing/tokens(list)
|
||||||
|
// - /api/devices 列表(驗證合併雲端 repo + session 狀態)
|
||||||
|
// - /api/devices/scan 走 tunnel(proxy 到 fake local agent)
|
||||||
|
// - /api/models/init + PUT /storage/... + /api/models/:id/finalize(完整上傳流程)
|
||||||
|
// - /api/clusters 回空陣列
|
||||||
|
//
|
||||||
|
// 命名刻意加 `B5_` 前綴便於從失敗輸出快速定位到本檔。
|
||||||
|
//
|
||||||
|
// OB5 起:所有打 /api/* 的 test 都改用 fixture.AuthenticatedClient(t, userID, email)
|
||||||
|
// 走完整 OIDC login flow 拿 cookie 再呼叫,不再使用 StaticAuthService。
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestB5_AuthLoginAndMe 驗證 OIDC login flow + /auth/me 能跑通。
|
||||||
|
//
|
||||||
|
// OB5 起 POST /api/auth/login 一律 410 Gone(指向 GET /api/auth/login redirect flow);
|
||||||
|
// 真正的登入是 AuthenticatedClient 內部執行的 GET /api/auth/login → callback → cookie。
|
||||||
|
func TestB5_AuthLoginAndMe(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
const wantSub = "user-b5-login"
|
||||||
|
const wantEmail = "b5-login@visiona.local"
|
||||||
|
client := f.AuthenticatedClient(t, wantSub, wantEmail)
|
||||||
|
|
||||||
|
// 1. POST /api/auth/login → 410 Gone(OIDC mode)
|
||||||
|
loginBody := map[string]string{"email": "foo@bar.local", "password": "whatever"}
|
||||||
|
bodyBytes, _ := json.Marshal(loginBody)
|
||||||
|
resp, err := client.Post(f.apiServer.URL+"/api/auth/login", "application/json", bytes.NewReader(bodyBytes))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusGone, resp.StatusCode, "OIDC mode 下 POST /api/auth/login 應回 410")
|
||||||
|
|
||||||
|
// 2. GET /api/auth/me — 應該回 OIDC sub
|
||||||
|
resp2, err := client.Get(f.apiServer.URL + "/api/auth/me")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp2.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp2.StatusCode)
|
||||||
|
|
||||||
|
var meResp map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp2.Body).Decode(&meResp))
|
||||||
|
meData := meResp["data"].(map[string]any)
|
||||||
|
assert.Equal(t, wantSub, meData["user_id"])
|
||||||
|
assert.Equal(t, wantEmail, meData["email"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_AuthRegisterReturns501 驗證雛形不實作註冊。
|
||||||
|
//
|
||||||
|
// 註:此 endpoint 走 AuthMiddleware(需 cookie),雛形語意上「已登入也不能註冊」。
|
||||||
|
func TestB5_AuthRegisterReturns501(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
resp, err := client.Post(f.apiServer.URL+"/api/auth/register", "application/json", strings.NewReader(`{}`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusNotImplemented, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_PairingTokensListAndRevoke 驗證 list / revoke 端對端流程。
|
||||||
|
func TestB5_PairingTokensListAndRevoke(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 建 2 個 token
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
resp, err := client.Post(f.apiServer.URL+"/api/pairing/token", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp.Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GET list — 應該看到 2 筆
|
||||||
|
resp, err := client.Get(f.apiServer.URL + "/api/pairing/tokens")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var listBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&listBody))
|
||||||
|
arr := listBody["data"].([]any)
|
||||||
|
assert.Len(t, arr, 2, "應有 2 個 pairing token")
|
||||||
|
|
||||||
|
// 取其中一個 token 的 prefix(雛形 path 傳 plaintext)— 這個 test 改走 create 拿 plaintext
|
||||||
|
resp2, err := client.Post(f.apiServer.URL+"/api/pairing/token", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp2.Body.Close()
|
||||||
|
var newTok map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp2.Body).Decode(&newTok))
|
||||||
|
plaintext := newTok["data"].(map[string]any)["token"].(string)
|
||||||
|
|
||||||
|
// DELETE revoke
|
||||||
|
req, _ := http.NewRequest(http.MethodDelete, f.apiServer.URL+"/api/pairing/tokens/"+plaintext, nil)
|
||||||
|
revResp, err := client.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer revResp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusNoContent, revResp.StatusCode)
|
||||||
|
|
||||||
|
// 再 revoke 不存在的 token → 404
|
||||||
|
req2, _ := http.NewRequest(http.MethodDelete, f.apiServer.URL+"/api/pairing/tokens/vAc_00000000000000000000000000000000", nil)
|
||||||
|
notFoundResp, err := client.Do(req2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer notFoundResp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusNotFound, notFoundResp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_DevicesList 驗證 GET /api/devices 讀 repo + 合併 session 狀態。
|
||||||
|
func TestB5_DevicesList(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 塞一台 device 到 repo(模擬使用者已配對)
|
||||||
|
ctx := context.Background()
|
||||||
|
// 注意:setupFixture 的 router 內部 repo 是新的,不能從外部取到;
|
||||||
|
// 這個 test 只能走「先 pairing token → tunnel 連上 → session.List 有東西」
|
||||||
|
// 的間接驗證,但雲端 repo 內沒有 device。因此預期回空陣列 + 200。
|
||||||
|
resp, err := client.Get(f.apiServer.URL + "/api/devices")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.Equal(t, true, body["success"])
|
||||||
|
arr, ok := body["data"].([]any)
|
||||||
|
require.True(t, ok, "data 應為 array")
|
||||||
|
assert.Empty(t, arr, "預設沒有 device")
|
||||||
|
|
||||||
|
_ = ctx
|
||||||
|
// device.ErrNotFound 在這裡不會出現;留 import 避免 lint
|
||||||
|
_ = device.ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_DevicesScan_ViaTunnel 驗證 scan 端點走 tunnel proxy 到 fake local agent。
|
||||||
|
func TestB5_DevicesScan_ViaTunnel(t *testing.T) {
|
||||||
|
// fake local agent:對 /api/devices/scan 回一段 JSON
|
||||||
|
localHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
require.Equal(t, "/api/devices/scan", r.URL.Path)
|
||||||
|
require.Equal(t, http.MethodPost, r.Method)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"scanned": 1,
|
||||||
|
"devices": []map[string]any{
|
||||||
|
{"id": "kl520-abc", "type": "kl520"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
f := setupFixture(t, localHandler)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 建立 tunnel
|
||||||
|
const token = "vAc_ccccccccccccccccccccccccccccccc1"
|
||||||
|
stop := startFakeTunnelClient(t, f.tunnelSrv.URL, token,
|
||||||
|
strings.TrimPrefix(f.localBackend.URL, "http://"))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := f.store.Exists(context.Background(), token)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
// POST /api/devices/scan
|
||||||
|
resp, err := client.Post(f.apiServer.URL+"/api/devices/scan", "application/json", nil)
|
||||||
|
require.NoError(t, err, "scan 應該走 proxy 並成功")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.EqualValues(t, 1, body["scanned"])
|
||||||
|
assert.NotEmpty(t, body["devices"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_DevicesScan_TunnelDisconnected 驗證 tunnel 不存在時回 502 + TUNNEL_DISCONNECTED。
|
||||||
|
func TestB5_DevicesScan_TunnelDisconnected(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 不起 tunnel → 直接打 scan
|
||||||
|
resp, err := client.Post(f.apiServer.URL+"/api/devices/scan", "application/json", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
errObj := body["error"].(map[string]any)
|
||||||
|
assert.Equal(t, "TUNNEL_DISCONNECTED", errObj["code"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_ModelUploadFlow 驗證完整的模型上傳流程:init → PUT → finalize。
|
||||||
|
//
|
||||||
|
// 這個是 B5 最硬的 integration test — 涵蓋 storage presigned URL、model repo、
|
||||||
|
// 兩階段上傳的 handshake。
|
||||||
|
func TestB5_ModelUploadFlow(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 1. POST /api/models/init
|
||||||
|
initBody := map[string]any{
|
||||||
|
"name": "YOLOv5 Test",
|
||||||
|
"file_size": 11,
|
||||||
|
"target_chip": "kl520",
|
||||||
|
}
|
||||||
|
initBytes, _ := json.Marshal(initBody)
|
||||||
|
initResp, err := client.Post(f.apiServer.URL+"/api/models/init", "application/json", bytes.NewReader(initBytes))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, initResp.StatusCode, "init 應成功")
|
||||||
|
|
||||||
|
var initRespBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(initResp.Body).Decode(&initRespBody))
|
||||||
|
initResp.Body.Close()
|
||||||
|
initData := initRespBody["data"].(map[string]any)
|
||||||
|
modelID := initData["model_id"].(string)
|
||||||
|
uploadURL := initData["upload_url"].(string)
|
||||||
|
require.NotEmpty(t, modelID)
|
||||||
|
require.NotEmpty(t, uploadURL)
|
||||||
|
|
||||||
|
// 2. PUT 上傳檔案 — 直接用 init 回來的 upload_url(setupFixture 已把 storage.baseURL
|
||||||
|
// 指向 apiServer.URL+"/storage",所以 upload_url 已是可用的完整 URL)。
|
||||||
|
// PUT /storage/* 不走 AuthMiddleware(HMAC 簽章),用 default client 即可。
|
||||||
|
_ = initData["storage_key"] // 保留變數供未來驗證
|
||||||
|
payload := []byte("hello world") // 11 bytes 對上 file_size
|
||||||
|
|
||||||
|
putReq, err := http.NewRequest(http.MethodPut, uploadURL, bytes.NewReader(payload))
|
||||||
|
require.NoError(t, err)
|
||||||
|
putReq.ContentLength = int64(len(payload))
|
||||||
|
putResp, err := http.DefaultClient.Do(putReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer putResp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusNoContent, putResp.StatusCode, "PUT 應該 204")
|
||||||
|
|
||||||
|
// 3. POST /api/models/:id/finalize
|
||||||
|
finalizeResp, err := client.Post(f.apiServer.URL+"/api/models/"+modelID+"/finalize", "application/json", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer finalizeResp.Body.Close()
|
||||||
|
// 注意:不要在 require 的錯誤訊息中 readBody() — 那會消耗 body 導致後面 Decode EOF
|
||||||
|
require.Equal(t, http.StatusOK, finalizeResp.StatusCode, "finalize 應成功")
|
||||||
|
|
||||||
|
var fbody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(finalizeResp.Body).Decode(&fbody))
|
||||||
|
fdata := fbody["data"].(map[string]any)
|
||||||
|
assert.Equal(t, "ready", fdata["status"])
|
||||||
|
assert.Equal(t, modelID, fdata["id"])
|
||||||
|
|
||||||
|
// 4. GET /api/models — 應該看到我們上傳的
|
||||||
|
listResp, err := client.Get(f.apiServer.URL + "/api/models")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer listResp.Body.Close()
|
||||||
|
var lbody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(listResp.Body).Decode(&lbody))
|
||||||
|
arr := lbody["data"].([]any)
|
||||||
|
assert.GreaterOrEqual(t, len(arr), 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_ModelInit_TooLarge 驗證超過 MaxUploadSizeMB 回 413。
|
||||||
|
func TestB5_ModelInit_TooLarge(t *testing.T) {
|
||||||
|
// 自行 spin fixture 並設一個很小的 max size
|
||||||
|
f := setupFixtureWithMaxUpload(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}), 1) // 1 MB
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
initBody := map[string]any{
|
||||||
|
"name": "too big",
|
||||||
|
"file_size": int64(2) * 1024 * 1024, // 2 MB > 1 MB limit
|
||||||
|
}
|
||||||
|
initBytes, _ := json.Marshal(initBody)
|
||||||
|
resp, err := client.Post(f.apiServer.URL+"/api/models/init", "application/json", bytes.NewReader(initBytes))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusRequestEntityTooLarge, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
errObj := body["error"].(map[string]any)
|
||||||
|
assert.Equal(t, "PAYLOAD_TOO_LARGE", errObj["code"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_StoragePutDirect 先單獨驗證 /storage/* PUT 走得通(排除 model upload 流程變數)。
|
||||||
|
func TestB5_StoragePutDirect(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 先 init 一個 model 取得 upload_url
|
||||||
|
initBytes, _ := json.Marshal(map[string]any{"name": "x", "file_size": 11})
|
||||||
|
initResp, _ := client.Post(f.apiServer.URL+"/api/models/init", "application/json", bytes.NewReader(initBytes))
|
||||||
|
var ib map[string]any
|
||||||
|
_ = json.NewDecoder(initResp.Body).Decode(&ib)
|
||||||
|
initResp.Body.Close()
|
||||||
|
uploadURL := ib["data"].(map[string]any)["upload_url"].(string)
|
||||||
|
t.Logf("uploadURL=%s", uploadURL)
|
||||||
|
|
||||||
|
// PUT /storage/* 不走 AuthMiddleware(HMAC 簽章),用 default client 即可
|
||||||
|
putReq, _ := http.NewRequest(http.MethodPut, uploadURL, bytes.NewReader([]byte("hello world")))
|
||||||
|
putReq.ContentLength = 11
|
||||||
|
resp, err := http.DefaultClient.Do(putReq)
|
||||||
|
require.NoError(t, err, "err=%v", err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusNoContent, resp.StatusCode, "body=%s", readBody(resp.Body))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestB5_ClustersList 驗證 GET /api/clusters 回空陣列(雛形)。
|
||||||
|
func TestB5_ClustersList(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
resp, err := client.Get(f.apiServer.URL + "/api/clusters")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.Equal(t, true, body["success"])
|
||||||
|
arr := body["data"].([]any)
|
||||||
|
assert.Empty(t, arr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// readBody 非 destructive 讀 response body(供 failure message 用)。
|
||||||
|
func readBody(r io.Reader) string {
|
||||||
|
b, _ := io.ReadAll(r)
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
289
visionA-backend/cmd/api-server/e2e_full_flow_test.go
Normal file
289
visionA-backend/cmd/api-server/e2e_full_flow_test.go
Normal file
@ -0,0 +1,289 @@
|
|||||||
|
// e2e_full_flow_test.go — AB13:把 pairing exchange + tunnel connect + API forward
|
||||||
|
// 三段串成單一端到端測試,驗證雲端版完整鏈路。
|
||||||
|
//
|
||||||
|
// 這個 test 是雛形交付前的最終驗收 — 通過代表:
|
||||||
|
//
|
||||||
|
// 使用者在 agent 貼上 pairing token
|
||||||
|
// │
|
||||||
|
// ▼
|
||||||
|
// Agent 呼叫 POST /api/pairing/exchange(api-server)
|
||||||
|
// → 拿到 Session Token + Relay URL
|
||||||
|
// ▼
|
||||||
|
// Agent 用 Session Token 對 remote-proxy 的 /tunnel/connect 建 WebSocket
|
||||||
|
// → yamux session 註冊進 SessionStore
|
||||||
|
// ▼
|
||||||
|
// 前端打 GET /api/devices/scan(api-server)
|
||||||
|
// → api-server.Forwarder.ForwardHTTP
|
||||||
|
// → remote-proxy.handleInternalForward (hijack + yamux OpenStream)
|
||||||
|
// → agent.handleStream (RoundTrip to local-tool)
|
||||||
|
// → local-tool 回 JSON
|
||||||
|
// → 逐段原封轉回前端
|
||||||
|
//
|
||||||
|
// 設計取捨:
|
||||||
|
//
|
||||||
|
// - 為何不 cross-module import agent 原始碼:
|
||||||
|
// visionA-backend 與 visiona-agent 是獨立 go module,agent 又依賴 wails/v2
|
||||||
|
// (會把 Wails 的 UI 層傳遞依賴全拖進 backend 的 go.sum)。為了保持 backend
|
||||||
|
// 的依賴乾淨,我們用 b5_integration_test.go 早已驗證過的 startFakeTunnelClient
|
||||||
|
// —— 它用純 gorilla/websocket + yamux.Client 重現 agent 的 tunnel 邏輯,在
|
||||||
|
// 協議面上與 agent 的 tunnel.Client 等價(agent 的 Client 也是在 WS 上跑
|
||||||
|
// yamux.Client,handleStream 用 http.ReadRequest / RoundTrip / resp.Write)。
|
||||||
|
//
|
||||||
|
// 真正的 agent 程式碼路徑(tunnel.Client.handleStream)已由 AB6 的
|
||||||
|
// internal/tunnel/integration_test.go 用同樣的 fake relay 模式驗證過;
|
||||||
|
// 那邊用的是真 Manager + fake relay,這邊用的是真 backend + fake tunnel
|
||||||
|
// client。兩者覆蓋的是鏡像路徑,合起來 = 完整 e2e。
|
||||||
|
//
|
||||||
|
// - 為何不 spawn subprocess:go test 裡 exec.Command("go", "run", ...) 在 CI
|
||||||
|
// 上不穩(port 競爭、cleanup race、跨 module build),且測試時間會從秒級
|
||||||
|
// 拉到分鐘級。subprocess 方案我們另外提供為 manual script(scripts/
|
||||||
|
// e2e-manual-test.sh),給使用者要真驗證時跑。
|
||||||
|
//
|
||||||
|
// 參考:
|
||||||
|
// - .autoflow/04-architecture/visiona-agent-tdd.md §11(integration / e2e testing)
|
||||||
|
// - .autoflow/04-architecture/tunnel.md §3(資料流)
|
||||||
|
// - b5_integration_test.go / pairing_exchange_test.go(既有整合測試基礎)
|
||||||
|
// - local-agent/visiona-agent/internal/tunnel/integration_test.go(agent 端鏡像測試)
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/api"
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestE2E_FullFlow_PairingToForward 是 AB13 的核心驗收測試。
|
||||||
|
//
|
||||||
|
// 串起 pairing exchange → tunnel connect → API forward 三個里程碑,確認
|
||||||
|
// 整條雲端版架構在同一個 test run 裡能跑通。
|
||||||
|
//
|
||||||
|
// 這個 test 跟 TestAB11_PairingExchange_EndToEnd 的差別:AB11 驗到 tunnel
|
||||||
|
// connect 進 store 就停,這裡往下多走一段「打 API → forward 回 fake local」,
|
||||||
|
// 覆蓋 B5 forwarder handler 真實被呼叫的路徑。
|
||||||
|
func TestE2E_FullFlow_PairingToForward(t *testing.T) {
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// 1. fake local-tool(模擬 agent 背後的 local HTTP server)
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// 對 /api/devices/scan 回一段 JSON,驗證 request 真的穿過整條 tunnel。
|
||||||
|
// 同時 echo 出 X-Forwarded-For / X-Request-ID 之類的 header 供驗證。
|
||||||
|
localCalls := 0
|
||||||
|
localHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
localCalls++
|
||||||
|
// 驗證進來的 request 是我們預期的(證明 host rewrite 正確)
|
||||||
|
if r.URL.Path != "/api/devices/scan" {
|
||||||
|
http.Error(w, "unexpected path: "+r.URL.Path, http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
http.Error(w, "unexpected method: "+r.Method, http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Backend-Source", "e2e-fake-local")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"scanned": 2,
|
||||||
|
"devices": []map[string]any{
|
||||||
|
{"id": "kl520-e2e-01", "type": "kl520", "status": "online"},
|
||||||
|
{"id": "kl730-e2e-02", "type": "kl730", "status": "online"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
f := setupFixture(t, localHandler)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
authClient := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// Milestone 1: Pairing Exchange
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// 1a. 產 Pairing Token(OIDC cookie 放行 AuthMiddleware)
|
||||||
|
tokResp, err := authClient.Post(f.apiServer.URL+"/api/pairing/token", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tokResp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, tokResp.StatusCode)
|
||||||
|
|
||||||
|
var tokBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(tokResp.Body).Decode(&tokBody))
|
||||||
|
pairingTok := tokBody["data"].(map[string]any)["token"].(string)
|
||||||
|
require.True(t, auth.IsValidPairingToken(pairingTok),
|
||||||
|
"Milestone 1a: pairing token 格式應合法,實得 %q", pairingTok)
|
||||||
|
|
||||||
|
// 1b. 用 Pairing Token 換 Session Token(不走 AuthMiddleware)
|
||||||
|
exchBody, _ := json.Marshal(api.PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
exchResp, err := http.Post(f.apiServer.URL+"/api/pairing/exchange",
|
||||||
|
"application/json", bytes.NewReader(exchBody))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer exchResp.Body.Close()
|
||||||
|
exchRaw, _ := io.ReadAll(exchResp.Body)
|
||||||
|
require.Equal(t, http.StatusOK, exchResp.StatusCode,
|
||||||
|
"Milestone 1b: exchange 應成功,body: %s", string(exchRaw))
|
||||||
|
|
||||||
|
var exchBodyDecoded map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(exchRaw, &exchBodyDecoded))
|
||||||
|
exchData := exchBodyDecoded["data"].(map[string]any)
|
||||||
|
sessionTok := exchData["session_token"].(string)
|
||||||
|
require.True(t, auth.IsValidSessionToken(sessionTok),
|
||||||
|
"Milestone 1b: session token 格式應合法,實得 %q", sessionTok)
|
||||||
|
assert.NotEmpty(t, exchData["relay_url"], "relay_url 應由 api-server 回傳給 agent")
|
||||||
|
assert.NotEmpty(t, exchData["account"], "account 應由 api-server 回傳給 agent")
|
||||||
|
assert.NotEmpty(t, exchData["expires_at"], "expires_at 應由 api-server 回傳給 agent")
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// Milestone 2: Tunnel Connect
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// 用剛換到的 Session Token 對 remote-proxy 的 tunnel endpoint 建連線。
|
||||||
|
// startFakeTunnelClient 的協議行為 = agent 的 tunnel.Client(WS+yamux+handleStream)。
|
||||||
|
stop := startFakeTunnelClient(t, f.tunnelSrv.URL, sessionTok,
|
||||||
|
strings.TrimPrefix(f.localBackend.URL, "http://"))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
// 等 session 實際註冊進 remote-proxy 的 InMemoryStore
|
||||||
|
// (WS handshake + yamux client up 是非同步的)
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
ok, _ := f.store.Exists(ctx, sessionTok)
|
||||||
|
return ok
|
||||||
|
}, 3*time.Second, 20*time.Millisecond,
|
||||||
|
"Milestone 2: session token 應在 3 秒內出現在 SessionStore")
|
||||||
|
|
||||||
|
// 2b. 同步驗證 /api/system/health 看 api-server 也能透過 ProxyClient 讀到 session
|
||||||
|
healthResp, err := authClient.Get(f.apiServer.URL + "/api/system/health")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer healthResp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, healthResp.StatusCode)
|
||||||
|
var healthBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(healthResp.Body).Decode(&healthBody))
|
||||||
|
healthData := healthBody["data"].(map[string]any)
|
||||||
|
assert.Equal(t, true, healthData["tunnel_connected"],
|
||||||
|
"Milestone 2b: /api/system/health 應回 tunnel_connected=true(api-server → remote-proxy 讀 session)")
|
||||||
|
assert.EqualValues(t, 1, healthData["agent_session_count"],
|
||||||
|
"Milestone 2b: 應有 1 個 agent session")
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// Milestone 3: API Forward(完整鏈路)
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// 這是最終驗證:browser → api-server → Forwarder → remote-proxy
|
||||||
|
// → yamux stream → fake tunnel client → fake local-tool
|
||||||
|
// 任何一環出錯都會 fail。
|
||||||
|
scanResp, err := authClient.Post(f.apiServer.URL+"/api/devices/scan", "application/json", nil)
|
||||||
|
require.NoError(t, err, "Milestone 3: scan 應能 forward 成功")
|
||||||
|
defer scanResp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, scanResp.StatusCode,
|
||||||
|
"Milestone 3: scan 應回 200,實際 %d", scanResp.StatusCode)
|
||||||
|
|
||||||
|
var scanBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(scanResp.Body).Decode(&scanBody))
|
||||||
|
assert.EqualValues(t, 2, scanBody["scanned"],
|
||||||
|
"Milestone 3: response body 應原封穿過 tunnel 回來")
|
||||||
|
devices := scanBody["devices"].([]any)
|
||||||
|
require.Len(t, devices, 2, "Milestone 3: 應收到 2 個 device")
|
||||||
|
// 驗證 header 也穿過來了(Forwarder 會保留 upstream response header)
|
||||||
|
assert.Equal(t, "e2e-fake-local", scanResp.Header.Get("X-Backend-Source"),
|
||||||
|
"Milestone 3: fake local 的 response header 應被轉回來")
|
||||||
|
assert.Equal(t, 1, localCalls,
|
||||||
|
"Milestone 3: fake local 應被呼叫一次(證明 request 真的走到底)")
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// Milestone 4: 重複兌換保護
|
||||||
|
// -----------------------------------------------------------------
|
||||||
|
// 同一個 pairing token 再換一次應該被拒絕(PAIRING_TOKEN_USED),
|
||||||
|
// 避免有人竊聽到 pairing token 時能重放。
|
||||||
|
exchResp2, err := http.Post(f.apiServer.URL+"/api/pairing/exchange",
|
||||||
|
"application/json", bytes.NewReader(exchBody))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer exchResp2.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, exchResp2.StatusCode,
|
||||||
|
"Milestone 4: 重複兌換應 401")
|
||||||
|
body2, _ := io.ReadAll(exchResp2.Body)
|
||||||
|
assert.Contains(t, string(body2), "PAIRING_TOKEN_USED",
|
||||||
|
"Milestone 4: 錯誤碼應為 PAIRING_TOKEN_USED")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestE2E_ForwardFailsWhenTunnelDropped 驗證 tunnel 斷線後 API forward 會正確
|
||||||
|
// 回 502(TUNNEL_DISCONNECTED)。這模擬 agent 端進程崩潰 / 網路中斷後雲端的
|
||||||
|
// 反應,對齊 TDD §11.2 failure mode 清單。
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// 1. 建立完整 e2e 鏈路(exchange → connect → forward 一次成功)
|
||||||
|
// 2. 關掉 fake tunnel client(模擬 agent 崩潰)
|
||||||
|
// 3. 等 session 從 store 消失
|
||||||
|
// 4. 再打 /api/devices/scan → 預期 502 TUNNEL_DISCONNECTED
|
||||||
|
func TestE2E_ForwardFailsWhenTunnelDropped(t *testing.T) {
|
||||||
|
localHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{"scanned": 0})
|
||||||
|
})
|
||||||
|
|
||||||
|
f := setupFixture(t, localHandler)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
authClient := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 1. Exchange → Session Token
|
||||||
|
tokResp, err := authClient.Post(f.apiServer.URL+"/api/pairing/token", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var tokBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(tokResp.Body).Decode(&tokBody))
|
||||||
|
tokResp.Body.Close()
|
||||||
|
pairingTok := tokBody["data"].(map[string]any)["token"].(string)
|
||||||
|
|
||||||
|
exchReqBody, _ := json.Marshal(api.PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
exchResp, err := http.Post(f.apiServer.URL+"/api/pairing/exchange",
|
||||||
|
"application/json", bytes.NewReader(exchReqBody))
|
||||||
|
require.NoError(t, err)
|
||||||
|
var exchBodyDecoded map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(exchResp.Body).Decode(&exchBodyDecoded))
|
||||||
|
exchResp.Body.Close()
|
||||||
|
sessionTok := exchBodyDecoded["data"].(map[string]any)["session_token"].(string)
|
||||||
|
|
||||||
|
// 2. 建 tunnel + 先 forward 一次確認鏈路通
|
||||||
|
stop := startFakeTunnelClient(t, f.tunnelSrv.URL, sessionTok,
|
||||||
|
strings.TrimPrefix(f.localBackend.URL, "http://"))
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := f.store.Exists(context.Background(), sessionTok)
|
||||||
|
return ok
|
||||||
|
}, 3*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
firstResp, err := authClient.Post(f.apiServer.URL+"/api/devices/scan", "application/json", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, http.StatusOK, firstResp.StatusCode,
|
||||||
|
"前置條件:第一次 forward 應成功,代表鏈路有建起")
|
||||||
|
firstResp.Body.Close()
|
||||||
|
|
||||||
|
// 3. 關掉 fake tunnel client(agent 崩潰)
|
||||||
|
stop()
|
||||||
|
|
||||||
|
// 等 session 被清出 store(WS close → relay 偵測到 → remove session)
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := f.store.Exists(context.Background(), sessionTok)
|
||||||
|
return !ok
|
||||||
|
}, 3*time.Second, 50*time.Millisecond,
|
||||||
|
"session 應在 tunnel 斷線後從 store 消失")
|
||||||
|
|
||||||
|
// 4. 再 forward 應 502
|
||||||
|
secondResp, err := authClient.Post(f.apiServer.URL+"/api/devices/scan", "application/json", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer secondResp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadGateway, secondResp.StatusCode,
|
||||||
|
"tunnel 斷線後 forward 應回 502")
|
||||||
|
|
||||||
|
var errBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(secondResp.Body).Decode(&errBody))
|
||||||
|
errObj := errBody["error"].(map[string]any)
|
||||||
|
assert.Equal(t, "TUNNEL_DISCONNECTED", errObj["code"])
|
||||||
|
}
|
||||||
526
visionA-backend/cmd/api-server/integration_test.go
Normal file
526
visionA-backend/cmd/api-server/integration_test.go
Normal file
@ -0,0 +1,526 @@
|
|||||||
|
// integration_test.go — B4 完整端對端測試。
|
||||||
|
//
|
||||||
|
// 驗證雛形雙 binary 架構能跑通:
|
||||||
|
//
|
||||||
|
// prog test (HTTP client)
|
||||||
|
// └─► api-server: GET /api/system/health, /api/pairing/status
|
||||||
|
// └─► (system/health 內部呼叫 SessionStore.List → ProxyClient → remote-proxy)
|
||||||
|
// └─► remote-proxy: /internal/sessions
|
||||||
|
// └─► InMemoryStore(fake tunnel client 已註冊一個 session)
|
||||||
|
//
|
||||||
|
// 以及驗證 forwarder 能完成 raw forward:
|
||||||
|
//
|
||||||
|
// api-server (Forwarder.OpenStream)
|
||||||
|
// └─► remote-proxy: POST /internal/forward/raw
|
||||||
|
// └─► hijack + yamux stream → fake tunnel client
|
||||||
|
// └─► fake local server 回 chunked response
|
||||||
|
//
|
||||||
|
// 這是 B4 任務最關鍵的里程碑:證明整條雲端版架構能跑。
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/api"
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/converter"
|
||||||
|
"visiona-backend/internal/device"
|
||||||
|
"visiona-backend/internal/model"
|
||||||
|
"visiona-backend/internal/oidc"
|
||||||
|
"visiona-backend/internal/oidctest"
|
||||||
|
"visiona-backend/internal/relay"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
"visiona-backend/internal/storage"
|
||||||
|
"visiona-backend/internal/usersession"
|
||||||
|
"visiona-backend/internal/wsconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fixtureOIDCClientID / fixtureOIDCClientSecret 是測試用的 OIDC client 憑證。
|
||||||
|
// 與 fakeOIDC server 內 SetClientCredentials 對齊。
|
||||||
|
const (
|
||||||
|
fixtureOIDCClientID = "visiona-backend-fixture"
|
||||||
|
fixtureOIDCClientSecret = "fixture-test-secret"
|
||||||
|
fixtureSessionSecret = "fixture-session-signing-secret-32b!"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// 避免 gin debug log 汙染測試輸出
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// lazyHandler 是 swap-able 的 http.Handler 包裝,讓 fixture 可以先 Start 拿 URL,
|
||||||
|
// 再把真正的 router 裝進來(解循環依賴:storage.baseURL 需要 apiServer.URL,
|
||||||
|
// 而 storage 又是 router 的依賴)。
|
||||||
|
//
|
||||||
|
// 並發安全:Set 與 ServeHTTP 都透過 sync/atomic.Value 同步。
|
||||||
|
type lazyHandler struct {
|
||||||
|
v atomic.Value // http.Handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *lazyHandler) Set(h http.Handler) {
|
||||||
|
l.v.Store(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *lazyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
v := l.v.Load()
|
||||||
|
if v == nil {
|
||||||
|
http.Error(w, "router not initialized", http.StatusServiceUnavailable)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
v.(http.Handler).ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// testFixture 把 integration test 需要的所有 server 集中管理,方便 setup/teardown。
|
||||||
|
type testFixture struct {
|
||||||
|
apiServer *httptest.Server
|
||||||
|
internalSrv *httptest.Server
|
||||||
|
tunnelSrv *httptest.Server
|
||||||
|
localBackend *httptest.Server
|
||||||
|
store *session.InMemoryStore
|
||||||
|
forwarder *session.Forwarder
|
||||||
|
|
||||||
|
// fakeOIDC 是 OB5 起新增 — 所有 fixture 內建一個 fake OIDC server,
|
||||||
|
// 讓需要走 AuthMiddleware 的 test 可以用 AuthenticatedClient 一鍵完成登入流程。
|
||||||
|
fakeOIDC *oidctest.Server
|
||||||
|
|
||||||
|
// pairingStore / sessionMgr 暴露給少數需要直接操作 store 的 test。
|
||||||
|
pairingStore *auth.InMemoryPairingStore
|
||||||
|
sessionMgr *usersession.Manager
|
||||||
|
|
||||||
|
// router 暴露 *gin.Engine 給需要列出所有 route 的 test
|
||||||
|
// (目前用於 all_endpoints_require_auth_test.go — Phase 0.7 security regression)。
|
||||||
|
router *gin.Engine
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *testFixture) Close() {
|
||||||
|
if f.apiServer != nil {
|
||||||
|
f.apiServer.Close()
|
||||||
|
}
|
||||||
|
if f.internalSrv != nil {
|
||||||
|
f.internalSrv.Close()
|
||||||
|
}
|
||||||
|
if f.tunnelSrv != nil {
|
||||||
|
f.tunnelSrv.Close()
|
||||||
|
}
|
||||||
|
if f.localBackend != nil {
|
||||||
|
f.localBackend.Close()
|
||||||
|
}
|
||||||
|
// fakeOIDC 用 t.Cleanup 自動關(NewServer 內已註冊),這裡不需手動。
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupFixture 啟動完整的 5 段架構:
|
||||||
|
// - localBackend:扮演 local-tool(127.0.0.1:3721)
|
||||||
|
// - tunnel server:remote-proxy 對 local agent 的 WS port
|
||||||
|
// - internal server:remote-proxy 對 api-server 的 internal HTTP port
|
||||||
|
// - api-server:給前端用的 REST/Gin
|
||||||
|
//
|
||||||
|
// 注意:fake tunnel client 沒在這裡 spawn,因為各 test case 對 token 的需求不同。
|
||||||
|
func setupFixture(t *testing.T, localHandler http.Handler) *testFixture {
|
||||||
|
return setupFixtureWithMaxUpload(t, localHandler, 0) // 0 = 不限
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupFixtureWithMaxUpload 同 setupFixture 但可設定 MaxUploadSizeMB;
|
||||||
|
// B5 的 model-too-large test 需要這個。
|
||||||
|
//
|
||||||
|
// 另一個微妙的差別:storage 的 baseURL 設為 apiServer.URL + "/storage",
|
||||||
|
// 這樣 PUT /storage/{key} 的 presigned URL 能被同一個 apiServer 處理,
|
||||||
|
// b5_integration_test.go 的上傳流程才能端對端跑通。
|
||||||
|
//
|
||||||
|
// OB5 起內建 fake OIDC server + OIDC wiring(OIDC 是唯一認證路徑)。
|
||||||
|
// 走 AuthMiddleware 的 test 應透過 fixture.AuthenticatedClient 完成登入。
|
||||||
|
func setupFixtureWithMaxUpload(t *testing.T, localHandler http.Handler, maxUploadMB int) *testFixture {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// 1. fake local-tool (127.0.0.1:3721 模擬)
|
||||||
|
localBackend := httptest.NewServer(localHandler)
|
||||||
|
|
||||||
|
// 2. remote-proxy
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelWarn}))
|
||||||
|
relaySrv := relay.NewServer(store, logger, relay.Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
internalSrv := relay.NewInternalServer(store, logger)
|
||||||
|
|
||||||
|
// tunnel mux (面向 fake local agent)
|
||||||
|
tunnelMux := http.NewServeMux()
|
||||||
|
tunnelMux.HandleFunc("/tunnel/connect", relaySrv.HandleTunnelConnect)
|
||||||
|
tunnelTS := httptest.NewServer(tunnelMux)
|
||||||
|
|
||||||
|
// internal mux (面向 api-server)
|
||||||
|
internalMux := http.NewServeMux()
|
||||||
|
internalSrv.Routes(internalMux)
|
||||||
|
internalTS := httptest.NewServer(internalMux)
|
||||||
|
|
||||||
|
// 3. api-server — 透過 ProxyClient/Forwarder 指向上面的 internalTS
|
||||||
|
proxyClient := session.NewHTTPProxyClient(internalTS.URL, logger)
|
||||||
|
forwarder := session.NewForwarder(internalTS.URL, logger)
|
||||||
|
sessionStore := session.NewProxyClientStore(proxyClient, forwarder)
|
||||||
|
|
||||||
|
// 需要先知道 api-server URL 才能建 storage(presigned URL 的 baseURL),
|
||||||
|
// 但 storage 又是 router 的依賴。解法:用 lazyHandler — 一個可以被 swap 的
|
||||||
|
// http.Handler,讓我們先 Start server 拿 URL,再把真正的 router 裝進去。
|
||||||
|
storeDir := t.TempDir()
|
||||||
|
|
||||||
|
lazy := &lazyHandler{}
|
||||||
|
apiTS := httptest.NewServer(lazy)
|
||||||
|
storeStore, err := storage.NewLocalFSStore(storeDir, apiTS.URL+"/storage", "test-secret")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 4. fake OIDC + OIDC client(OB5:唯一認證路徑)
|
||||||
|
fakeOIDC := oidctest.NewServer(t,
|
||||||
|
oidctest.WithClientCredentials(fixtureOIDCClientID, fixtureOIDCClientSecret),
|
||||||
|
)
|
||||||
|
|
||||||
|
callbackURL := apiTS.URL + "/api/auth/callback"
|
||||||
|
oidcCtx, oidcCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
oidcProvider, err := oidc.NewProvider(oidcCtx, oidc.ProviderConfig{
|
||||||
|
IssuerURL: fakeOIDC.URL,
|
||||||
|
ClientID: fakeOIDC.ClientID,
|
||||||
|
ClientSecret: fakeOIDC.ClientSecret,
|
||||||
|
RedirectURL: callbackURL,
|
||||||
|
})
|
||||||
|
oidcCancel()
|
||||||
|
require.NoError(t, err, "fixture: OIDC provider init failed")
|
||||||
|
|
||||||
|
sessionMgr := usersession.NewManager(usersession.NewInMemoryStore(), usersession.CookieConfig{
|
||||||
|
Name: "visiona_session",
|
||||||
|
Path: "/",
|
||||||
|
HTTPOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: 86400,
|
||||||
|
SigningKey: []byte(fixtureSessionSecret),
|
||||||
|
})
|
||||||
|
|
||||||
|
pairingStore := auth.NewInMemoryPairingStore()
|
||||||
|
|
||||||
|
router := api.NewRouter(api.Deps{
|
||||||
|
Logger: logger,
|
||||||
|
PairingStore: pairingStore,
|
||||||
|
SessionTokenStore: auth.NewInMemorySessionTokenStore(),
|
||||||
|
SessionStore: sessionStore,
|
||||||
|
Forwarder: forwarder,
|
||||||
|
DeviceRepo: device.NewInMemoryRepository(),
|
||||||
|
ModelRepo: model.NewInMemoryRepository(),
|
||||||
|
Storage: storeStore,
|
||||||
|
Converter: converter.NewStubClient(),
|
||||||
|
// Phase 0.7 security fix C1:StaticUserID 已從 Deps 移除(見 internal/api/api.go:77-80 註解)。
|
||||||
|
// 整合測試走 fixture.AuthenticatedClient 完整 OIDC login flow 取 cookie,不再走 fallback 捷徑。
|
||||||
|
MaxUploadSizeMB: maxUploadMB,
|
||||||
|
RelayPublicURL: tunnelTS.URL, // 讓 exchange 測試能拿到真實 tunnel URL
|
||||||
|
|
||||||
|
// OIDC wiring(OB5)
|
||||||
|
OIDCProvider: oidcProvider,
|
||||||
|
SessionManager: sessionMgr,
|
||||||
|
OIDCPostLoginURL: apiTS.URL, // 把 frontend redirect 收回自己,方便測試 follow up
|
||||||
|
})
|
||||||
|
lazy.Set(router)
|
||||||
|
|
||||||
|
return &testFixture{
|
||||||
|
apiServer: apiTS,
|
||||||
|
internalSrv: internalTS,
|
||||||
|
tunnelSrv: tunnelTS,
|
||||||
|
localBackend: localBackend,
|
||||||
|
store: store,
|
||||||
|
forwarder: forwarder,
|
||||||
|
fakeOIDC: fakeOIDC,
|
||||||
|
pairingStore: pairingStore,
|
||||||
|
sessionMgr: sessionMgr,
|
||||||
|
router: router,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// startFakeTunnelClient 模擬 local agent:
|
||||||
|
// - 對 tunnel server 開 WebSocket 上 yamux client
|
||||||
|
// - 對每條 stream 用 http.ReadRequest 解出 request → 真 TCP 轉發到 localAddr
|
||||||
|
func startFakeTunnelClient(t *testing.T, tunnelHTTPURL, token, localAddr string) (stop func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(tunnelHTTPURL, "http") + "/tunnel/connect"
|
||||||
|
u, err := url.Parse(wsURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("token", token)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
rawWS, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||||
|
require.NoError(t, err, "fake tunnel client failed to dial")
|
||||||
|
|
||||||
|
netConn := wsconn.New(rawWS)
|
||||||
|
ym, err := yamux.Client(netConn, yamux.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
stream, aerr := ym.Accept()
|
||||||
|
if aerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func(s net.Conn) {
|
||||||
|
defer s.Close()
|
||||||
|
req, rerr := http.ReadRequest(bufio.NewReader(s))
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 改寫 req:scheme/host 指向 fake localBackend,重設 RequestURI
|
||||||
|
req.URL.Scheme = "http"
|
||||||
|
req.URL.Host = localAddr
|
||||||
|
req.Host = localAddr
|
||||||
|
req.RequestURI = ""
|
||||||
|
|
||||||
|
resp, rerr := http.DefaultTransport.RoundTrip(req)
|
||||||
|
if rerr != nil {
|
||||||
|
_ = (&http.Response{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
ProtoMajor: 1, ProtoMinor: 1,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||||
|
}).Write(s)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_ = resp.Write(s)
|
||||||
|
}(stream)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
_ = ym.Close()
|
||||||
|
_ = rawWS.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
// Test cases
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// TestIntegration_HealthEndpoint 驗證 /healthz(不需 auth)+ /api/system/info(OIDC)都能 200。
|
||||||
|
func TestIntegration_HealthEndpoint(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// /healthz:不走 AuthMiddleware
|
||||||
|
resp, err := http.Get(f.apiServer.URL + "/healthz")
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode, "path=/healthz")
|
||||||
|
|
||||||
|
// /api/system/info:走 AuthMiddleware → 需要 cookie
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
resp2, err := client.Get(f.apiServer.URL + "/api/system/info")
|
||||||
|
require.NoError(t, err)
|
||||||
|
resp2.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusOK, resp2.StatusCode, "path=/api/system/info")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_SystemHealth_NoTunnel 驗證沒 tunnel 時,
|
||||||
|
// /api/system/health 回 connected=false(且整條 ProxyClient → remote-proxy 路徑通)。
|
||||||
|
func TestIntegration_SystemHealth_NoTunnel(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
resp, err := client.Get(f.apiServer.URL + "/api/system/health")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
require.Equal(t, true, body["success"])
|
||||||
|
data := body["data"].(map[string]any)
|
||||||
|
assert.Equal(t, "ok", data["api_server"])
|
||||||
|
assert.Equal(t, false, data["tunnel_connected"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_SystemHealth_WithTunnel 驗證有 tunnel 時,
|
||||||
|
// /api/system/health 回 connected=true(證明整條 api-server → ProxyClient
|
||||||
|
// → remote-proxy → InMemoryStore 鏈路正常)。
|
||||||
|
func TestIntegration_SystemHealth_WithTunnel(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
const token = "vAc_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||||
|
stop := startFakeTunnelClient(t, f.tunnelSrv.URL, token,
|
||||||
|
strings.TrimPrefix(f.localBackend.URL, "http://"))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := f.store.Exists(context.Background(), token)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
resp, err := client.Get(f.apiServer.URL + "/api/system/health")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
data := body["data"].(map[string]any)
|
||||||
|
assert.Equal(t, true, data["tunnel_connected"], "預期 tunnel_connected=true,實際 body=%v", body)
|
||||||
|
assert.EqualValues(t, 1, data["agent_session_count"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_PairingTokenCreate 驗證 POST /api/pairing/token 能成功建立 token,
|
||||||
|
// 且回傳的 token 之後可以拿來連 tunnel(端到端覆蓋整條 pairing → tunnel 流程)。
|
||||||
|
func TestIntegration_PairingTokenCreate(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 1. POST /api/pairing/token
|
||||||
|
resp, err := client.Post(f.apiServer.URL+"/api/pairing/token", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
data := body["data"].(map[string]any)
|
||||||
|
tok, _ := data["token"].(string)
|
||||||
|
require.True(t, auth.IsValidPairingToken(tok), "token 應為合法 pairing 格式:%s", tok)
|
||||||
|
|
||||||
|
// 2. 拿這個 token 連 tunnel — 應該成功
|
||||||
|
stop := startFakeTunnelClient(t, f.tunnelSrv.URL, tok,
|
||||||
|
strings.TrimPrefix(f.localBackend.URL, "http://"))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := f.store.Exists(context.Background(), tok)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond, "新 pairing token 應能成功註冊 tunnel session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Forwarder_EndToEnd 是 B4 的核心驗證:
|
||||||
|
//
|
||||||
|
// 走完一整條
|
||||||
|
//
|
||||||
|
// api-server (Forwarder.OpenStream)
|
||||||
|
// └─► raw TCP → remote-proxy: POST /internal/forward/raw
|
||||||
|
// └─► hijack + OpenStream 走 yamux
|
||||||
|
// └─► fake tunnel client 收到 stream
|
||||||
|
// └─► 真 TCP forward 到 fake localBackend
|
||||||
|
// └─► local 回 200 + JSON body
|
||||||
|
//
|
||||||
|
// 這個測試證明 B4 完整 forwarder 鏈路可運作;B5 接 handler 時可以放心呼叫
|
||||||
|
// Forwarder.ForwardHTTP 而不必再驗證底層。
|
||||||
|
func TestIntegration_Forwarder_EndToEnd(t *testing.T) {
|
||||||
|
const expectedRoute = "/api/devices"
|
||||||
|
const expectedHeader = "from-api-server"
|
||||||
|
|
||||||
|
// fake localBackend 收到 / 後回一段 JSON
|
||||||
|
localHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Echo-Path", r.URL.Path)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"method": r.Method,
|
||||||
|
"path": r.URL.Path,
|
||||||
|
"echo_header": r.Header.Get("X-From-Api"),
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
f := setupFixture(t, localHandler)
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
const token = "vAc_bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
|
||||||
|
stop := startFakeTunnelClient(t, f.tunnelSrv.URL, token,
|
||||||
|
strings.TrimPrefix(f.localBackend.URL, "http://"))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := f.store.Exists(context.Background(), token)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
// 用 forwarder 直接 ForwardHTTP(模擬 B5 的 handler 會做的事)
|
||||||
|
req, err := http.NewRequest(http.MethodGet, expectedRoute, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set("X-From-Api", expectedHeader)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
resp, err := f.forwarder.ForwardHTTP(ctx, token, req)
|
||||||
|
require.NoError(t, err, "Forwarder.ForwardHTTP 應該成功")
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
assert.Equal(t, expectedRoute, resp.Header.Get("X-Echo-Path"))
|
||||||
|
|
||||||
|
var payload map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&payload))
|
||||||
|
assert.Equal(t, "GET", payload["method"])
|
||||||
|
assert.Equal(t, expectedRoute, payload["path"])
|
||||||
|
assert.Equal(t, expectedHeader, payload["echo_header"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Forwarder_TunnelDisconnected 驗證當 token 不存在時,
|
||||||
|
// Forwarder.OpenStream 回 ErrSessionNotFound(讓 caller handler 可以轉 502)。
|
||||||
|
func TestIntegration_Forwarder_TunnelDisconnected(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
_, err := f.forwarder.OpenStream(ctx, "vAc_doesnotexistdoesnotexistdoesno")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, session.ErrSessionNotFound,
|
||||||
|
"預期 ErrSessionNotFound,實際 err=%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_Stub_NotImplemented 驗證 B5/B7 仍未補的 endpoint 確實回 501。
|
||||||
|
//
|
||||||
|
// B5 後 /api/devices 已改為實際 handler(回空陣列),所以改驗 /api/converter/jobs —
|
||||||
|
// 那個是 Phase 1 才會做、目前仍為 stub。
|
||||||
|
func TestIntegration_Stub_NotImplemented(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
resp, err := client.Get(f.apiServer.URL + "/api/converter/jobs")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusNotImplemented, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestIntegration_CORS_Preflight 驗證 CORS preflight 對 localhost:3000 放行。
|
||||||
|
func TestIntegration_CORS_Preflight(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodOptions, f.apiServer.URL+"/api/system/health", nil)
|
||||||
|
req.Header.Set("Origin", "http://localhost:3000")
|
||||||
|
req.Header.Set("Access-Control-Request-Method", "GET")
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Contains(t, resp.Header.Get("Access-Control-Allow-Origin"), "localhost:3000",
|
||||||
|
"預期 Allow-Origin 帶 localhost:3000,實際 header: %v", resp.Header)
|
||||||
|
}
|
||||||
249
visionA-backend/cmd/api-server/main.go
Normal file
249
visionA-backend/cmd/api-server/main.go
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
// Command api-server 是 visionA-backend 的對前端 REST + WebSocket 伺服器。
|
||||||
|
//
|
||||||
|
// 雛形雙 binary 架構(Q1 裁決):
|
||||||
|
// - api-server **無狀態**:所有 session 狀態都在 remote-proxy 那邊
|
||||||
|
// - 透過 ProxyClientStore + Forwarder 走 internal HTTP 跟 remote-proxy 溝通
|
||||||
|
//
|
||||||
|
// 對應文件:
|
||||||
|
// - `.autoflow/04-architecture/TDD.md` §2.4(雙 binary 部署)/ §10(前端資料流)
|
||||||
|
// - `.autoflow/04-architecture/api/api-spec.md`(前端用的 REST API)
|
||||||
|
// - `.autoflow/04-architecture/api/api-internal.md`(api-server ↔ remote-proxy)
|
||||||
|
// - `.autoflow/04-architecture/tunnel.md` §5
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strconv"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/api"
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/config"
|
||||||
|
"visiona-backend/internal/converter"
|
||||||
|
"visiona-backend/internal/device"
|
||||||
|
"visiona-backend/internal/logger"
|
||||||
|
"visiona-backend/internal/model"
|
||||||
|
"visiona-backend/internal/oidc"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
"visiona-backend/internal/storage"
|
||||||
|
"visiona-backend/internal/usersession"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultSigningSecret 與 config/load.go 保持一致 — 用於啟動警告。
|
||||||
|
const defaultSigningSecret = "dev-signing-secret-do-not-use-in-prod"
|
||||||
|
|
||||||
|
// shutdownTimeout 是收到 SIGINT/SIGTERM 後等待進行中請求完成的最長時間。
|
||||||
|
const shutdownTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// sessionCleanupInterval 是 OIDC user session store 的後台清理頻率。
|
||||||
|
// 設 5 分鐘是 dev / prod 都合理的值:足夠頻繁讓 idle session 不久留,
|
||||||
|
// 又不會過度消耗 CPU。
|
||||||
|
const sessionCleanupInterval = 5 * time.Minute
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := config.Load()
|
||||||
|
log := logger.New(cfg.Logger.Level).With("service", "api-server")
|
||||||
|
|
||||||
|
// Validate config(特別是 OIDC 啟用時的必填欄位)。
|
||||||
|
if err := cfg.Validate(); err != nil {
|
||||||
|
log.Error("invalid config", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 啟動警告:signing secret 為預設值(同 remote-proxy 行為)。
|
||||||
|
// 此 secret 同時給 storage presigned URL 與(未來)pairing token hash 用。
|
||||||
|
if cfg.Auth.SigningSecret == defaultSigningSecret {
|
||||||
|
log.Warn("signing secret 仍為預設 dev 值(storage/pairing 共用)",
|
||||||
|
"action", "請在生產環境設定環境變數 VISIONA_STORAGE_SIGNING_SECRET",
|
||||||
|
"affects", "storage presigned URL, pairing token hash (phase 1)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== Storage =====
|
||||||
|
// 用 LocalFS(Phase 0 雛形);signing secret 共用同一份。
|
||||||
|
storageStore, err := storage.NewLocalFSStore(cfg.Storage.RootDir, cfg.Storage.BaseURL, cfg.Auth.SigningSecret)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to init storage", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
log.Info("storage initialized",
|
||||||
|
"backend", cfg.Storage.Backend,
|
||||||
|
"root", cfg.Storage.RootDir,
|
||||||
|
"base_url", cfg.Storage.BaseURL)
|
||||||
|
|
||||||
|
// ===== Pairing / Session Token(OIDC 之外的雛形 token store) =====
|
||||||
|
pairingStore := auth.NewInMemoryPairingStore()
|
||||||
|
sessionTokenStore := auth.NewInMemorySessionTokenStore()
|
||||||
|
|
||||||
|
// ===== OIDC + User Session(OB5:唯一認證路徑) =====
|
||||||
|
// cfg.Validate() 已確保所有必填欄位存在,這裡可以放心 wire。
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
oidcProvider, err := oidc.NewProvider(ctx, oidc.ProviderConfig{
|
||||||
|
IssuerURL: cfg.OIDC.IssuerURL,
|
||||||
|
ClientID: cfg.OIDC.ClientID,
|
||||||
|
ClientSecret: cfg.OIDC.ClientSecret,
|
||||||
|
RedirectURL: cfg.OIDC.RedirectURL,
|
||||||
|
})
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("failed to init OIDC provider",
|
||||||
|
"error", err,
|
||||||
|
"issuer", cfg.OIDC.IssuerURL,
|
||||||
|
"hint", "確認 IdP discovery (.well-known/openid-configuration) 可達")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
userSessionStore := usersession.NewInMemoryStore()
|
||||||
|
userSessionMgr := usersession.NewManager(userSessionStore, usersession.CookieConfig{
|
||||||
|
Name: cfg.UserSession.CookieName,
|
||||||
|
Domain: cfg.UserSession.CookieDomain,
|
||||||
|
Path: "/",
|
||||||
|
Secure: cfg.UserSession.CookieSecure,
|
||||||
|
HTTPOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: int(cfg.UserSession.AbsoluteTTL.Seconds()),
|
||||||
|
SigningKey: []byte(cfg.UserSession.Secret),
|
||||||
|
})
|
||||||
|
log.Info("OIDC initialized",
|
||||||
|
"issuer", cfg.OIDC.IssuerURL,
|
||||||
|
"client_id", cfg.OIDC.ClientID,
|
||||||
|
"redirect_url", cfg.OIDC.RedirectURL,
|
||||||
|
"frontend_url", cfg.OIDC.PostLoginURL,
|
||||||
|
"cookie_secure", cfg.UserSession.CookieSecure,
|
||||||
|
"absolute_ttl", cfg.UserSession.AbsoluteTTL,
|
||||||
|
"idle_ttl", cfg.UserSession.IdleTTL,
|
||||||
|
)
|
||||||
|
|
||||||
|
// ===== Session(api-server 端透過 ProxyClient 走 internal HTTP) =====
|
||||||
|
proxyClient := session.NewHTTPProxyClient(cfg.Session.ProxyInternalURL, log)
|
||||||
|
forwarder := session.NewForwarder(cfg.Session.ProxyInternalURL, log)
|
||||||
|
sessionStore := session.NewProxyClientStore(proxyClient, forwarder)
|
||||||
|
log.Info("session store initialized",
|
||||||
|
"backend", "proxy-client",
|
||||||
|
"proxy_internal_url", cfg.Session.ProxyInternalURL)
|
||||||
|
|
||||||
|
// ===== Repositories(in-memory,雛形) =====
|
||||||
|
deviceRepo := device.NewInMemoryRepository()
|
||||||
|
modelRepo := model.NewInMemoryRepository()
|
||||||
|
|
||||||
|
// ===== Converter(stub,Phase 2 才實作) =====
|
||||||
|
converterClient := converter.NewStubClient()
|
||||||
|
|
||||||
|
// ===== Seed demo data(可選) =====
|
||||||
|
if cfg.Server.SeedDemoData {
|
||||||
|
if err := seedDemoData(deviceRepo, modelRepo, pairingStore, cfg.Auth.StaticUserID, log); err != nil {
|
||||||
|
log.Warn("seed demo data failed", "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== API Router =====
|
||||||
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
// Phase 0.7 security fix C1:StaticUserID 不再注入 Deps(見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
// dev seed 仍直接讀 cfg.Auth.StaticUserID;stage/prod 不影響(VISIONA_SEED_DEMO_DATA=false)。
|
||||||
|
router := api.NewRouter(api.Deps{
|
||||||
|
Logger: log,
|
||||||
|
PairingStore: pairingStore,
|
||||||
|
SessionTokenStore: sessionTokenStore,
|
||||||
|
SessionStore: sessionStore,
|
||||||
|
Forwarder: forwarder,
|
||||||
|
DeviceRepo: deviceRepo,
|
||||||
|
ModelRepo: modelRepo,
|
||||||
|
Storage: storageStore,
|
||||||
|
Converter: converterClient,
|
||||||
|
MaxUploadSizeMB: cfg.Model.MaxSizeMB,
|
||||||
|
CORSAllowedOrigins: cfg.CORS.AllowedOrigins,
|
||||||
|
RelayPublicURL: cfg.Server.RelayPublicURL,
|
||||||
|
|
||||||
|
// OIDC(OB5:唯一認證路徑)
|
||||||
|
OIDCProvider: oidcProvider,
|
||||||
|
SessionManager: userSessionMgr,
|
||||||
|
OIDCPostLoginURL: cfg.OIDC.PostLoginURL,
|
||||||
|
})
|
||||||
|
|
||||||
|
addr := net.JoinHostPort(cfg.Server.Host, strconv.Itoa(cfg.Server.Port))
|
||||||
|
srv := &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: router,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second, // 防 slow-loris(對齊 security.md)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===== User session cleanup goroutine =====
|
||||||
|
cleanupCtx, cleanupCancel := context.WithCancel(context.Background())
|
||||||
|
defer cleanupCancel()
|
||||||
|
go runUserSessionCleanup(cleanupCtx, userSessionStore, cfg.UserSession.IdleTTL, cfg.UserSession.AbsoluteTTL, log)
|
||||||
|
|
||||||
|
// ===== 啟動 server =====
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
log.Info("api-server listening",
|
||||||
|
"addr", addr,
|
||||||
|
"proxy_internal_url", cfg.Session.ProxyInternalURL,
|
||||||
|
"seed_demo_data", cfg.Server.SeedDemoData,
|
||||||
|
"oidc_issuer", cfg.OIDC.IssuerURL,
|
||||||
|
)
|
||||||
|
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等 signal 或錯誤
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
select {
|
||||||
|
case <-quit:
|
||||||
|
log.Info("shutdown signal received")
|
||||||
|
case err := <-errCh:
|
||||||
|
log.Error("api-server error, shutting down", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
|
||||||
|
defer cancel()
|
||||||
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Warn("api-server shutdown error", "error", err)
|
||||||
|
}
|
||||||
|
cleanupCancel() // 停掉 user session cleanup goroutine
|
||||||
|
log.Info("api-server stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// runUserSessionCleanup 是 OIDC user session store 的 background cleanup 迴圈。
|
||||||
|
//
|
||||||
|
// 每 sessionCleanupInterval 跑一次 store.CleanupExpired,把 idle / absolute timeout
|
||||||
|
// 的 session 清掉。失敗只 log 不 panic(cleanup 不應拖垮主 process)。
|
||||||
|
//
|
||||||
|
// ctx 取消(process shutdown)即退出。
|
||||||
|
func runUserSessionCleanup(ctx context.Context, store usersession.Store, idleTTL, absTTL time.Duration, log loggerLike) {
|
||||||
|
ticker := time.NewTicker(sessionCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cctx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||||
|
removed, err := store.CleanupExpired(cctx, idleTTL, absTTL)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("user session cleanup failed", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if removed > 0 {
|
||||||
|
log.Info("user session cleanup", "removed", removed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loggerLike 是 runUserSessionCleanup 需要的最小 logger 介面,避免直接綁 *slog.Logger
|
||||||
|
// 而能在 test 中 stub。
|
||||||
|
type loggerLike interface {
|
||||||
|
Info(msg string, args ...any)
|
||||||
|
Warn(msg string, args ...any)
|
||||||
|
}
|
||||||
389
visionA-backend/cmd/api-server/oidc_e2e_test.go
Normal file
389
visionA-backend/cmd/api-server/oidc_e2e_test.go
Normal file
@ -0,0 +1,389 @@
|
|||||||
|
// oidc_e2e_test.go — OIDC BFF end-to-end 整合測試。
|
||||||
|
//
|
||||||
|
// OB5(2026-04-26)起 OIDC 是唯一認證路徑、setupFixture 預設就 wire 好 fake OIDC,
|
||||||
|
// 因此本檔案不再用 build tag 隔離 — 屬於主測試套件的一部分。
|
||||||
|
//
|
||||||
|
// 涵蓋情境:
|
||||||
|
// - Happy path:login → IdP → callback → me → logout
|
||||||
|
// - State mismatch(CSRF 防護)
|
||||||
|
// - Invalid nonce(replay 攻擊)
|
||||||
|
// - Token exchange 失敗(IdP 不可達)
|
||||||
|
// - Pairing token 綁到 OIDC sub(oidc-tdd.md §9 關鍵驗證)
|
||||||
|
// - 多 user isolation(兩 user 各自的 token 不混淆)
|
||||||
|
//
|
||||||
|
// # 對齊文件
|
||||||
|
//
|
||||||
|
// - .autoflow/04-architecture/oidc-tdd.md §3 BFF Flow 詳細時序圖
|
||||||
|
// - .autoflow/04-architecture/oidc-tdd.md §9 Pairing 流程確認 user binding 仍正確
|
||||||
|
// - .autoflow/04-architecture/adr/adr-010-oidc-bff.md
|
||||||
|
// - .autoflow/04-architecture/adr/adr-011-supersede-adr-005.md
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/oidctest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────────
|
||||||
|
// E2E TEST CASES
|
||||||
|
// ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// TestOIDCE2E_FullLoginFlow 是 OIDC e2e 的核心 happy path 測試。
|
||||||
|
//
|
||||||
|
// 完整流程:
|
||||||
|
//
|
||||||
|
// 1. GET /api/auth/login → 302 to fakeOIDC /authorize
|
||||||
|
// 2. (sim) GET fakeOIDC /authorize → 302 to backend /api/auth/callback?code=...&state=...
|
||||||
|
// 3. GET /api/auth/callback → backend 完成 token exchange + 建 cookie session → 302 to PostLoginURL
|
||||||
|
// 4. GET /api/auth/me → 200 + 預期 user_id (= OIDC sub)
|
||||||
|
// 5. POST /api/auth/logout → 200 + clear cookie
|
||||||
|
// 6. GET /api/auth/me → 401
|
||||||
|
func TestOIDCE2E_FullLoginFlow(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// 預先設定 fake server 下一個 /token 簽出來的 id_token claims
|
||||||
|
const wantSub = "sub-oidc-e2e-001"
|
||||||
|
const wantEmail = "alice@innovedus.com"
|
||||||
|
const wantName = "Alice OIDC"
|
||||||
|
f.fakeOIDC.SetNextIDTokenClaims(map[string]any{
|
||||||
|
"sub": wantSub,
|
||||||
|
"email": wantEmail,
|
||||||
|
"name": wantName,
|
||||||
|
})
|
||||||
|
|
||||||
|
client := newCookieClient(t)
|
||||||
|
|
||||||
|
// ─── 1. GET /api/auth/login → 302 to fake IdP /authorize ───
|
||||||
|
loc1 := getExpect302(t, client, f.apiServer.URL+"/api/auth/login")
|
||||||
|
require.True(t, strings.HasPrefix(loc1, f.fakeOIDC.URL+"/authorize"),
|
||||||
|
"login 應 302 to fake IdP /authorize,得 %s", loc1)
|
||||||
|
|
||||||
|
// 驗 backend 帶的 query 參數符合 OIDC spec
|
||||||
|
authorizeURL, err := url.Parse(loc1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
q := authorizeURL.Query()
|
||||||
|
assert.Equal(t, "code", q.Get("response_type"))
|
||||||
|
assert.Equal(t, fixtureOIDCClientID, q.Get("client_id"))
|
||||||
|
assert.NotEmpty(t, q.Get("state"), "必帶 state(CSRF 防護)")
|
||||||
|
assert.NotEmpty(t, q.Get("nonce"), "必帶 nonce(replay 防護)")
|
||||||
|
assert.NotEmpty(t, q.Get("code_challenge"), "必帶 PKCE challenge")
|
||||||
|
assert.Equal(t, "S256", q.Get("code_challenge_method"))
|
||||||
|
|
||||||
|
// ─── 2. 模擬使用者「登入並同意」→ fake IdP 回 callback URL ───
|
||||||
|
callbackURL := f.fakeOIDC.SimulateAuthorizationFlow(t, loc1)
|
||||||
|
|
||||||
|
// ─── 3. GET callback → backend 換 token + 建 session → 302 to PostLoginURL ───
|
||||||
|
loc2 := getExpect302(t, client, callbackURL)
|
||||||
|
assert.NotEmpty(t, loc2, "callback 應 302 to PostLoginURL")
|
||||||
|
|
||||||
|
// 驗 cookie 已 set
|
||||||
|
assertHasSessionCookie(t, client, f.apiServer.URL)
|
||||||
|
|
||||||
|
// ─── 4. GET /api/auth/me → 200 + 預期 user_id ───
|
||||||
|
meResp := getJSON(t, client, f.apiServer.URL+"/api/auth/me")
|
||||||
|
require.Equal(t, http.StatusOK, meResp.status, "body=%v", meResp.body)
|
||||||
|
data := meResp.body["data"].(map[string]any)
|
||||||
|
assert.Equal(t, wantSub, data["user_id"], "user_id 應為 OIDC sub")
|
||||||
|
assert.Equal(t, wantEmail, data["email"])
|
||||||
|
|
||||||
|
// ─── 5. POST /api/auth/logout ───
|
||||||
|
logoutReq, _ := http.NewRequest(http.MethodPost, f.apiServer.URL+"/api/auth/logout", nil)
|
||||||
|
logoutResp, err := client.Do(logoutReq)
|
||||||
|
require.NoError(t, err)
|
||||||
|
logoutResp.Body.Close()
|
||||||
|
assert.True(t, logoutResp.StatusCode == http.StatusNoContent || logoutResp.StatusCode == http.StatusOK,
|
||||||
|
"logout 應為 204 或 200,得 %d", logoutResp.StatusCode)
|
||||||
|
|
||||||
|
// ─── 6. GET /api/auth/me 應 401 ───
|
||||||
|
meResp2 := getJSON(t, client, f.apiServer.URL+"/api/auth/me")
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, meResp2.status,
|
||||||
|
"logout 後 /api/auth/me 應回 401")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCE2E_StateMismatch 驗 callback 收到的 state 與 pending session 不符 → 4xx。
|
||||||
|
//
|
||||||
|
// 真實攻擊場景:攻擊者在 victim 的 browser 上塞自己的 state,企圖讓 victim 用攻擊者
|
||||||
|
// 的帳號登入(CSRF)。這個 test 確保 BFF 真的有比 state。
|
||||||
|
func TestOIDCE2E_StateMismatch(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := newCookieClient(t)
|
||||||
|
|
||||||
|
// 1. /login → 取 authorize URL
|
||||||
|
loc1 := getExpect302(t, client, f.apiServer.URL+"/api/auth/login")
|
||||||
|
|
||||||
|
// 2. 模擬 IdP redirect 但「篡改 state」
|
||||||
|
cb := f.fakeOIDC.SimulateAuthorizationFlow(t, loc1)
|
||||||
|
tampered := tamperState(t, cb, "evil-state-not-the-one-backend-stored")
|
||||||
|
|
||||||
|
// 3. backend 應拒絕
|
||||||
|
resp, err := client.Get(tampered)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.True(t, resp.StatusCode >= 400 && resp.StatusCode < 500,
|
||||||
|
"state mismatch 應回 4xx,得 %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCE2E_InvalidNonce 驗 id_token 的 nonce 與 backend 期待的不符 → 認證失敗。
|
||||||
|
//
|
||||||
|
// 模擬 replay 攻擊:攻擊者拿到一個其他登入流程的 id_token,企圖用它通過驗證。
|
||||||
|
func TestOIDCE2E_InvalidNonce(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// 把 id_token 的 nonce 故意覆寫成「跟 authorize 收到的不同」
|
||||||
|
f.fakeOIDC.SetNextIDTokenClaims(map[string]any{
|
||||||
|
"sub": "sub-replay-attempt",
|
||||||
|
"nonce": "this-is-a-stale-or-stolen-nonce",
|
||||||
|
})
|
||||||
|
|
||||||
|
client := newCookieClient(t)
|
||||||
|
loc1 := getExpect302(t, client, f.apiServer.URL+"/api/auth/login")
|
||||||
|
cb := f.fakeOIDC.SimulateAuthorizationFlow(t, loc1)
|
||||||
|
|
||||||
|
resp, err := client.Get(cb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.True(t, resp.StatusCode >= 400 && resp.StatusCode < 500,
|
||||||
|
"nonce mismatch 應回 4xx,得 %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCE2E_TokenExchangeFails 驗 IdP 5xx 時 backend 優雅 fail 而非 panic / 500。
|
||||||
|
//
|
||||||
|
// 提前關掉 fake server 模擬 IdP 不可達。
|
||||||
|
func TestOIDCE2E_TokenExchangeFails(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := newCookieClient(t)
|
||||||
|
loc1 := getExpect302(t, client, f.apiServer.URL+"/api/auth/login")
|
||||||
|
cb := f.fakeOIDC.SimulateAuthorizationFlow(t, loc1)
|
||||||
|
|
||||||
|
// 在 callback 發送之前把 fake IdP 關掉,模擬「token endpoint 連不上」
|
||||||
|
f.fakeOIDC.Close()
|
||||||
|
|
||||||
|
resp, err := client.Get(cb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
// 預期 backend 回 502 / 503(IdP 不可達)— 重點是不 panic
|
||||||
|
assert.True(t, resp.StatusCode >= 500 && resp.StatusCode < 600,
|
||||||
|
"IdP 不可達應回 5xx,得 %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCE2E_PairingTokenBindsToOIDCUser 是本任務最關鍵的測試(oidc-tdd.md §9)。
|
||||||
|
//
|
||||||
|
// 驗證:OIDC 登入完成後,使用者建立的 Pairing Token 綁定的 user_id
|
||||||
|
// **是 OIDC sub**(不再是 StaticAuthProvider 的「demo-user」)。
|
||||||
|
//
|
||||||
|
// 為什麼關鍵:oidc-tdd.md §9 承諾「Pairing 流程零影響」— 但 user_id 從
|
||||||
|
// 「demo-user」變成「OIDC sub」這個改動會穿透到 PairingStore,如果沒把
|
||||||
|
// UserContext.UserID 正確改成 sub,pairing 會繼續用「demo-user」綁所有人,
|
||||||
|
// 多用戶上線時會直接災難性混亂(一個人的 device 連到別人的帳號上)。
|
||||||
|
func TestOIDCE2E_PairingTokenBindsToOIDCUser(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
const wantSub = "sub-pairing-binding-test-001"
|
||||||
|
f.fakeOIDC.SetNextIDTokenClaims(map[string]any{
|
||||||
|
"sub": wantSub,
|
||||||
|
"email": "pairing-test@innovedus.com",
|
||||||
|
"name": "Pairing Test",
|
||||||
|
})
|
||||||
|
|
||||||
|
client := newCookieClient(t)
|
||||||
|
|
||||||
|
// 1. 走完整 OIDC 登入
|
||||||
|
loc1 := getExpect302(t, client, f.apiServer.URL+"/api/auth/login")
|
||||||
|
cb := f.fakeOIDC.SimulateAuthorizationFlow(t, loc1)
|
||||||
|
getExpect302(t, client, cb)
|
||||||
|
assertHasSessionCookie(t, client, f.apiServer.URL)
|
||||||
|
|
||||||
|
// 2. 確認 /me 回 OIDC sub
|
||||||
|
meResp := getJSON(t, client, f.apiServer.URL+"/api/auth/me")
|
||||||
|
require.Equal(t, http.StatusOK, meResp.status)
|
||||||
|
assert.Equal(t, wantSub,
|
||||||
|
meResp.body["data"].(map[string]any)["user_id"],
|
||||||
|
"前置:OIDC sub 應正確注入 UserContext")
|
||||||
|
|
||||||
|
// 3. 建立 Pairing Token(走 AuthMiddleware)
|
||||||
|
tokResp := postJSON(t, client, f.apiServer.URL+"/api/pairing/token", nil)
|
||||||
|
require.Equal(t, http.StatusOK, tokResp.status, "body=%v", tokResp.body)
|
||||||
|
pairingToken := tokResp.body["data"].(map[string]any)["token"].(string)
|
||||||
|
require.True(t, auth.IsValidPairingToken(pairingToken),
|
||||||
|
"應為合法 pairing token:%s", pairingToken)
|
||||||
|
|
||||||
|
// 4. **核心斷言**:用 PairingStore Validate 檢查綁定的 user_id 是 OIDC sub。
|
||||||
|
//
|
||||||
|
// 從 fixture 取出 PairingStore 直接驗(OB5 起 testFixture 已 expose pairingStore 欄位)。
|
||||||
|
tokInfo, err := f.pairingStore.Validate(context.Background(), pairingToken)
|
||||||
|
require.NoError(t, err, "pairing token 應仍可驗證")
|
||||||
|
assert.Equal(t, wantSub, tokInfo.UserID,
|
||||||
|
"pairing token 應綁到 OIDC sub(不再是 demo-user);"+
|
||||||
|
"若失敗代表 OB3 沒把 UserContext.UserID 設為 OIDC sub,"+
|
||||||
|
"或 PairingStore.Create 沒收到正確的 user_id")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCE2E_MultiUserIsolation 確保兩個 OIDC 使用者建立的 pairing token 不會混淆。
|
||||||
|
//
|
||||||
|
// 從 oidc-tdd.md §9 的角度看,這是「demo-user → OIDC sub」遷移後最容易藏的 bug:
|
||||||
|
// 兩個 user A / B 各自登入,A 建一個 pairing token,B 應該看不到。
|
||||||
|
func TestOIDCE2E_MultiUserIsolation(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// ─ Alice ─
|
||||||
|
clientA := f.AuthenticatedClient(t, "user-alice", "alice@x.com")
|
||||||
|
|
||||||
|
// Alice 建一個 pairing token
|
||||||
|
tokRespA := postJSON(t, clientA, f.apiServer.URL+"/api/pairing/token", nil)
|
||||||
|
require.Equal(t, http.StatusOK, tokRespA.status)
|
||||||
|
pairingA := tokRespA.body["data"].(map[string]any)["token"].(string)
|
||||||
|
require.True(t, auth.IsValidPairingToken(pairingA))
|
||||||
|
|
||||||
|
// ─ Bob ─
|
||||||
|
clientB := f.AuthenticatedClient(t, "user-bob", "bob@x.com")
|
||||||
|
|
||||||
|
// Bob 列出自己的 tokens —— 不應該看到 Alice 的
|
||||||
|
listResp := getJSON(t, clientB, f.apiServer.URL+"/api/pairing/tokens")
|
||||||
|
require.Equal(t, http.StatusOK, listResp.status)
|
||||||
|
bobTokens, _ := listResp.body["data"].([]any)
|
||||||
|
for _, raw := range bobTokens {
|
||||||
|
tok := raw.(map[string]any)
|
||||||
|
// list 回的是 token_prefix(前 12 字元),對比 Alice token 的 prefix
|
||||||
|
prefix, _ := tok["token_prefix"].(string)
|
||||||
|
assert.NotEqual(t, pairingA[:len(prefix)], prefix,
|
||||||
|
"Bob 的 token 列表不應包含 Alice 的 token prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 額外直接驗 store:Alice 名下確實有,Bob 名下沒有
|
||||||
|
aliceTokens, err := f.pairingStore.List(context.Background(), "user-alice")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, aliceTokens, "Alice 名下應有 pairing token")
|
||||||
|
|
||||||
|
bobStoreTokens, err := f.pairingStore.List(context.Background(), "user-bob")
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, tok := range bobStoreTokens {
|
||||||
|
assert.NotEqual(t, "user-alice", tok.UserID,
|
||||||
|
"Bob 名下的 token UserID 不應為 user-alice")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ──────────────────────────────────────────────────────────────
|
||||||
|
// HTTP CLIENT HELPERS
|
||||||
|
// ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// newCookieClient 回傳一個會記 cookie 但「不自動跟隨 redirect」的 http.Client。
|
||||||
|
//
|
||||||
|
// 不自動 redirect 是必要的:BFF flow 一連串 302(login → IdP authorize → callback
|
||||||
|
// → PostLoginURL)要由我們自己一段一段控制,才能在中間 assert 每一步的 status / Location。
|
||||||
|
func newCookieClient(t *testing.T) *http.Client {
|
||||||
|
t.Helper()
|
||||||
|
jar, err := cookiejar.New(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return &http.Client{
|
||||||
|
Jar: jar,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getExpect302 GET 一個 URL 並斷言它回 302,回傳 Location header。
|
||||||
|
func getExpect302(t *testing.T, client *http.Client, target string) string {
|
||||||
|
t.Helper()
|
||||||
|
resp, err := client.Get(target)
|
||||||
|
require.NoError(t, err, "GET %s", target)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Truef(t, resp.StatusCode == http.StatusFound || resp.StatusCode == http.StatusSeeOther,
|
||||||
|
"預期 302/303,得 %d (%s)", resp.StatusCode, target)
|
||||||
|
loc := resp.Header.Get("Location")
|
||||||
|
require.NotEmpty(t, loc, "Location header 應非空 (%s)", target)
|
||||||
|
return loc
|
||||||
|
}
|
||||||
|
|
||||||
|
type jsonResp struct {
|
||||||
|
status int
|
||||||
|
body map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
func getJSON(t *testing.T, client *http.Client, target string) jsonResp {
|
||||||
|
t.Helper()
|
||||||
|
resp, err := client.Get(target)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
out := jsonResp{status: resp.StatusCode, body: map[string]any{}}
|
||||||
|
if len(body) > 0 {
|
||||||
|
_ = json.Unmarshal(body, &out.body)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func postJSON(t *testing.T, client *http.Client, target string, body io.Reader) jsonResp {
|
||||||
|
t.Helper()
|
||||||
|
contentType := "application/json"
|
||||||
|
if body == nil {
|
||||||
|
body = bytes.NewReader(nil)
|
||||||
|
contentType = ""
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(http.MethodPost, target, body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if contentType != "" {
|
||||||
|
req.Header.Set("Content-Type", contentType)
|
||||||
|
}
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
raw, _ := io.ReadAll(resp.Body)
|
||||||
|
out := jsonResp{status: resp.StatusCode, body: map[string]any{}}
|
||||||
|
if len(raw) > 0 {
|
||||||
|
_ = json.Unmarshal(raw, &out.body)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// assertHasSessionCookie 驗 cookie jar 含 visiona_session cookie。
|
||||||
|
func assertHasSessionCookie(t *testing.T, client *http.Client, baseURL string) {
|
||||||
|
t.Helper()
|
||||||
|
u, err := url.Parse(baseURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
for _, c := range client.Jar.Cookies(u) {
|
||||||
|
if c.Name == "visiona_session" {
|
||||||
|
require.NotEmpty(t, c.Value, "visiona_session cookie 應有值")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
t.Fatalf("未找到 visiona_session cookie;jar=%+v", client.Jar.Cookies(u))
|
||||||
|
}
|
||||||
|
|
||||||
|
// tamperState 把 callback URL 的 state 換成另一個值(模擬攻擊者)。
|
||||||
|
func tamperState(t *testing.T, callbackURL, newState string) string {
|
||||||
|
t.Helper()
|
||||||
|
u, err := url.Parse(callbackURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("state", newState)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
return u.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 確保 oidctest 一定 import 到(避免未來 helper 改動時被 lint 掉)。
|
||||||
|
var _ = oidctest.WithClientCredentials
|
||||||
92
visionA-backend/cmd/api-server/oidc_test_helper_test.go
Normal file
92
visionA-backend/cmd/api-server/oidc_test_helper_test.go
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
// oidc_test_helper_test.go — OB5 起測試共用的 OIDC 認證 helper。
|
||||||
|
//
|
||||||
|
// 為什麼需要這份:OB5 移除 StaticAuth 之後,所有走 AuthMiddleware 的 integration test
|
||||||
|
// 都必須先走完 OIDC login flow 拿 cookie。把這段樣板抽成 helper 讓每個 test
|
||||||
|
// 不必重複「fake OIDC server + login + callback + 取 cookie」這 30 行程式碼。
|
||||||
|
//
|
||||||
|
// 設計選擇:
|
||||||
|
// - 同個 oidctest.NewServer 在 fixture 整個生命週期共用 — 多個 client 各自登入即可
|
||||||
|
// - AuthenticatedClient 拿 cookie 後就和原本的 http.DefaultClient 行為一致,
|
||||||
|
// 之後每次打 /api/* 都自動帶 cookie,handler 看到的 UserContext = 預先 set 的 sub
|
||||||
|
// - 不暴露 fake OIDC URL 給 test caller — caller 透過 fixture method 操作即可
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/cookiejar"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AuthenticatedClient 走完整 OIDC login flow,回傳已帶 visiona_session cookie 的 *http.Client。
|
||||||
|
//
|
||||||
|
// userID / email:simulate 登入後 backend session 裡會記錄的 OIDC sub / email。
|
||||||
|
// 同個 fixture 可以呼叫多次(不同 userID),各自拿到獨立的 cookie jar,模擬多 user。
|
||||||
|
//
|
||||||
|
// 任何步驟出錯直接 t.Fatalf — caller 不必檢 err。
|
||||||
|
//
|
||||||
|
// 注意:回傳的 client 「會自動跟 redirect」(CheckRedirect=nil),方便 caller 直接打
|
||||||
|
// /api/* 端點不用自己處理 302。如果你需要做 BFF flow 的 step-by-step assert,請改用
|
||||||
|
// oidc_e2e_test.go 裡的 newCookieClient。
|
||||||
|
func (f *testFixture) AuthenticatedClient(t *testing.T, userID, email string) *http.Client {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if f.fakeOIDC == nil {
|
||||||
|
t.Fatalf("AuthenticatedClient: fixture.fakeOIDC is nil — fixture wasn't built with setupFixture (which now wires fake OIDC)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 設定下一輪 ExchangeCode 的 id_token claims
|
||||||
|
f.fakeOIDC.SetNextIDTokenClaims(map[string]any{
|
||||||
|
"sub": userID,
|
||||||
|
"email": email,
|
||||||
|
"name": userID, // 簡化:name 與 sub 一致,測試夠用
|
||||||
|
})
|
||||||
|
|
||||||
|
jar, err := cookiejar.New(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// login flow 期間用「不自動 redirect」客戶端控制 302 step-by-step
|
||||||
|
flowClient := &http.Client{
|
||||||
|
Jar: jar,
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: GET /api/auth/login → 應 302 to fake OIDC /authorize
|
||||||
|
loc := getExpect302(t, flowClient, f.apiServer.URL+"/api/auth/login")
|
||||||
|
require.True(t, strings.HasPrefix(loc, f.fakeOIDC.URL+"/authorize"),
|
||||||
|
"login 應 302 to fake OIDC /authorize,得 %s", loc)
|
||||||
|
|
||||||
|
// Step 2: 模擬 IdP 同意登入 → 拿 callback URL
|
||||||
|
callbackURL := f.fakeOIDC.SimulateAuthorizationFlow(t, loc)
|
||||||
|
|
||||||
|
// Step 3: GET callback → backend 完成 token exchange + 寫 cookie session → 302 to PostLoginURL
|
||||||
|
_ = getExpect302(t, flowClient, callbackURL)
|
||||||
|
|
||||||
|
// 驗 cookie 已 set
|
||||||
|
u, err := url.Parse(f.apiServer.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
cookies := flowClient.Jar.Cookies(u)
|
||||||
|
var sessCookie *http.Cookie
|
||||||
|
for _, c := range cookies {
|
||||||
|
if c.Name == "visiona_session" {
|
||||||
|
sessCookie = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, sessCookie, "expected visiona_session cookie after callback")
|
||||||
|
require.NotEmpty(t, sessCookie.Value, "visiona_session cookie 應有值")
|
||||||
|
|
||||||
|
// 回傳一個共用同個 cookie jar、但會自動跟 redirect 的 client,
|
||||||
|
// 讓 caller 寫 client.Get/Post 不必處理 302。
|
||||||
|
return &http.Client{
|
||||||
|
Jar: jar,
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
138
visionA-backend/cmd/api-server/pairing_exchange_test.go
Normal file
138
visionA-backend/cmd/api-server/pairing_exchange_test.go
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
// pairing_exchange_test.go — AB11: POST /api/pairing/exchange 的 end-to-end integration test。
|
||||||
|
//
|
||||||
|
// 覆蓋情境:
|
||||||
|
// - 產 Pairing Token(POST /api/pairing/token,走 AuthMiddleware → 需 OIDC cookie)
|
||||||
|
// - 拿 Pairing Token 換 Session Token(POST /api/pairing/exchange,不走 AuthMiddleware)
|
||||||
|
// - 拿 Session Token 連 tunnel(remote-proxy 只做格式驗證 → 應能接受 vAs_)
|
||||||
|
// - 驗證同一個 Pairing Token 無法重複兌換
|
||||||
|
//
|
||||||
|
// 雛形取捨:
|
||||||
|
// - remote-proxy 目前**不會**回頭驗證 Session Token 是否出自 api-server(選項 A)。
|
||||||
|
// 故本測試沒有驗證「跨進程 session store 同步」— 這留給 Phase 1 實作。
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/api"
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAB11_PairingExchange_EndToEnd 跑完整個雛形 exchange → tunnel-connect 流程。
|
||||||
|
func TestAB11_PairingExchange_EndToEnd(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
client := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
|
||||||
|
// 1. POST /api/pairing/token → 拿一個 Pairing Token(走 AuthMiddleware,OIDC cookie 放行)
|
||||||
|
tokResp, err := client.Post(f.apiServer.URL+"/api/pairing/token", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tokResp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, tokResp.StatusCode)
|
||||||
|
|
||||||
|
var tokBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(tokResp.Body).Decode(&tokBody))
|
||||||
|
pairingTok := tokBody["data"].(map[string]any)["token"].(string)
|
||||||
|
require.True(t, auth.IsValidPairingToken(pairingTok), "token 格式應合法:%s", pairingTok)
|
||||||
|
|
||||||
|
// 2. POST /api/pairing/exchange → 換 Session Token(不走 AuthMiddleware)
|
||||||
|
reqBody, _ := json.Marshal(api.PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
exchResp, err := http.Post(f.apiServer.URL+"/api/pairing/exchange",
|
||||||
|
"application/json", bytes.NewReader(reqBody))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer exchResp.Body.Close()
|
||||||
|
bodyBytes, _ := io.ReadAll(exchResp.Body)
|
||||||
|
require.Equal(t, http.StatusOK, exchResp.StatusCode, "body: %s", string(bodyBytes))
|
||||||
|
|
||||||
|
var exchBody map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(bodyBytes, &exchBody))
|
||||||
|
data := exchBody["data"].(map[string]any)
|
||||||
|
sessionTok := data["session_token"].(string)
|
||||||
|
require.True(t, auth.IsValidSessionToken(sessionTok), "session_token 格式應合法:%s", sessionTok)
|
||||||
|
assert.NotEmpty(t, data["relay_url"])
|
||||||
|
assert.NotEmpty(t, data["account"])
|
||||||
|
assert.NotEmpty(t, data["expires_at"])
|
||||||
|
|
||||||
|
// account 應綁到 OIDC sub(OB5 升級的關鍵驗證 — 不再是 demo-user@...)
|
||||||
|
assert.Equal(t, "demo-user@visionA.local", data["account"],
|
||||||
|
"OB5 起 account 應 = OIDC sub + suffix;本 test 用 demo-user 當 sub")
|
||||||
|
|
||||||
|
// 3. 拿 Session Token 連 tunnel — remote-proxy 只做格式驗證,應該接受
|
||||||
|
stop := startFakeTunnelClient(t, f.tunnelSrv.URL, sessionTok, f.localBackend.URL[len("http://"):])
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
// 等 session 建立(session 進 store 需要非同步 handshake)
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
summaries, err := f.store.List(ctx)
|
||||||
|
return err == nil && len(summaries) == 1
|
||||||
|
}, 2*time.Second, 50*time.Millisecond, "tunnel session 應該建立")
|
||||||
|
|
||||||
|
// 4. 同一 pairing token 再換一次 → 應該 401 PAIRING_TOKEN_USED
|
||||||
|
exchResp2, err := http.Post(f.apiServer.URL+"/api/pairing/exchange",
|
||||||
|
"application/json", bytes.NewReader(reqBody))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer exchResp2.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, exchResp2.StatusCode)
|
||||||
|
body2, _ := io.ReadAll(exchResp2.Body)
|
||||||
|
assert.Contains(t, string(body2), "PAIRING_TOKEN_USED")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAB11_PairingExchange_Unauth 驗證 /api/pairing/exchange 本身不受 AuthMiddleware 管控。
|
||||||
|
//
|
||||||
|
// OB5 起 AuthMiddleware 已是 OIDC(cookie),exchange endpoint 必須仍然能用「沒登入的
|
||||||
|
// 純 HTTP client」打通 — 因為 agent 端就是 unauthenticated 來換 session token 的。
|
||||||
|
func TestAB11_PairingExchange_Unauth(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// 先用 OIDC client 拿一個 pairing token(authenticated)
|
||||||
|
authClient := f.AuthenticatedClient(t, "demo-user", "demo@visiona.local")
|
||||||
|
tokResp, err := authClient.Post(f.apiServer.URL+"/api/pairing/token", "", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer tokResp.Body.Close()
|
||||||
|
var tokBody map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(tokResp.Body).Decode(&tokBody))
|
||||||
|
pairingTok := tokBody["data"].(map[string]any)["token"].(string)
|
||||||
|
|
||||||
|
// 送 exchange,刻意用「沒任何 cookie / Auth header」的 default client — 應該還是 200 OK
|
||||||
|
reqBody, _ := json.Marshal(api.PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
req, _ := http.NewRequest(http.MethodPost, f.apiServer.URL+"/api/pairing/exchange",
|
||||||
|
bytes.NewReader(reqBody))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAB11_PairingExchange_InvalidFormat 驗證不合法格式的 token 回 401 INVALID_PAIRING_TOKEN。
|
||||||
|
func TestAB11_PairingExchange_InvalidFormat(t *testing.T) {
|
||||||
|
f := setupFixture(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
resp, err := http.Post(f.apiServer.URL+"/api/pairing/exchange",
|
||||||
|
"application/json",
|
||||||
|
strings.NewReader(`{"pairing_token":"not-valid"}`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.Contains(t, string(body), "INVALID_PAIRING_TOKEN")
|
||||||
|
}
|
||||||
88
visionA-backend/cmd/api-server/seed.go
Normal file
88
visionA-backend/cmd/api-server/seed.go
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/device"
|
||||||
|
"visiona-backend/internal/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// seedDemoData 在啟動時塞入示範資料,方便本機開發 / demo 不必跑完整 pairing。
|
||||||
|
//
|
||||||
|
// 觸發條件:VISIONA_SEED_DEMO_DATA=true
|
||||||
|
//
|
||||||
|
// 內容:
|
||||||
|
// - 一個示範 device(KL520)
|
||||||
|
// - 一個示範 model(YOLOv5 Face)
|
||||||
|
// - 一個示範 pairing token(log 出來方便手動 copy)
|
||||||
|
//
|
||||||
|
// 注意:
|
||||||
|
// - 失敗只 log warning,不阻擋啟動
|
||||||
|
// - 重複呼叫會產生重複資料;本函式只該被呼叫一次(main 已保證)
|
||||||
|
// - **不要**在生產環境啟用此 flag
|
||||||
|
func seedDemoData(
|
||||||
|
devRepo device.Repository,
|
||||||
|
mdlRepo model.Repository,
|
||||||
|
pairings auth.PairingStore,
|
||||||
|
userID string,
|
||||||
|
log *slog.Logger,
|
||||||
|
) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
// 1. Demo device
|
||||||
|
dev := &device.Device{
|
||||||
|
ID: "demo-device-" + uuid.NewString()[:8],
|
||||||
|
OwnerUserID: userID,
|
||||||
|
Name: "Demo KL520 (seeded)",
|
||||||
|
DeviceType: "kl520",
|
||||||
|
SerialNumber: "DEMO-SN-001",
|
||||||
|
RemoteStatus: device.RemoteStatusOffline,
|
||||||
|
Status: device.USBStatusUnknown,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
}
|
||||||
|
if err := devRepo.Save(ctx, dev); err != nil {
|
||||||
|
log.Warn("seed: device save failed", "error", err)
|
||||||
|
} else {
|
||||||
|
log.Info("seed: demo device created", "id", dev.ID, "name", dev.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Demo model
|
||||||
|
mdl := &model.Model{
|
||||||
|
ID: "demo-model-" + uuid.NewString()[:8],
|
||||||
|
OwnerUserID: userID,
|
||||||
|
Name: "YOLOv5 Face (seeded)",
|
||||||
|
TargetChip: "kl520",
|
||||||
|
FileSize: 1024 * 1024, // 1 MB
|
||||||
|
Source: model.SourceUploaded,
|
||||||
|
StorageKey: "models/" + userID + "/demo.nef",
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
UploadedAt: &now,
|
||||||
|
}
|
||||||
|
if err := mdlRepo.Save(ctx, mdl); err != nil {
|
||||||
|
log.Warn("seed: model save failed", "error", err)
|
||||||
|
} else {
|
||||||
|
log.Info("seed: demo model created", "id", mdl.ID, "name", mdl.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Demo pairing token(log plaintext 方便開發 — 雛形 demo 用,生產禁用)
|
||||||
|
pt, _, err := pairings.Create(ctx, userID, 24*time.Hour)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("seed: pairing token create failed", "error", err)
|
||||||
|
} else {
|
||||||
|
log.Info("seed: demo pairing token created (use for local-tool tunnel)",
|
||||||
|
"token", pt,
|
||||||
|
"ttl", "24h")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
0
visionA-backend/cmd/remote-proxy/.gitkeep
Normal file
0
visionA-backend/cmd/remote-proxy/.gitkeep
Normal file
198
visionA-backend/cmd/remote-proxy/main.go
Normal file
198
visionA-backend/cmd/remote-proxy/main.go
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
// Command remote-proxy 是 visionA-backend 的 tunnel server 端(雛形雙 binary 之一)。
|
||||||
|
//
|
||||||
|
// 它:
|
||||||
|
// - 接受 local agent 的 WebSocket upgrade(`/tunnel/connect`),建立 yamux tunnel
|
||||||
|
// - 唯一持有 session state(in-memory,不走 Redis;見 ADR-006)
|
||||||
|
// - 對 api-server 提供 internal HTTP API(`/internal/forward/http`、`/internal/session/:token`)
|
||||||
|
// - 定期清理過期 session(對齊 tunnel.md §4.2:10s 心跳、30s 判定掉線)
|
||||||
|
//
|
||||||
|
// 對應文件:
|
||||||
|
// - `.autoflow/04-architecture/TDD.md` §2.5 relay / §2.9 wsconn
|
||||||
|
// - `.autoflow/04-architecture/tunnel.md` §7.1 remote-proxy main 流程
|
||||||
|
// - `.autoflow/04-architecture/api/api-internal.md`
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"visiona-backend/internal/config"
|
||||||
|
"visiona-backend/internal/logger"
|
||||||
|
"visiona-backend/internal/relay"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultSigningSecret 與 config/load.go 保持一致 — 用於啟動時警告提示。
|
||||||
|
const defaultSigningSecret = "dev-signing-secret-do-not-use-in-prod"
|
||||||
|
|
||||||
|
// sessionCleanupInterval 清理過期 session 的週期,對齊 tunnel.md §4.2。
|
||||||
|
const sessionCleanupInterval = 30 * time.Second
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cfg := config.Load()
|
||||||
|
log := logger.New(cfg.Logger.Level).With("service", "remote-proxy")
|
||||||
|
|
||||||
|
// B2 M2 修補:storage signing secret 為預設值時印 warning。
|
||||||
|
// 雖然 remote-proxy 本身不直接用 storage,但 remote-proxy / api-server 共用
|
||||||
|
// 同一份 config,若 env 忘了設,兩個 binary 都該提醒。
|
||||||
|
if cfg.Auth.SigningSecret == defaultSigningSecret {
|
||||||
|
log.Warn("VISIONA_STORAGE_SIGNING_SECRET 仍為預設 dev 值",
|
||||||
|
"action", "請在生產環境設定環境變數 VISIONA_STORAGE_SIGNING_SECRET")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session store — remote-proxy 是 session state 的唯一來源
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
|
||||||
|
// Relay server(面向 local agent)
|
||||||
|
relaySrv := relay.NewServer(store, log, relay.Options{
|
||||||
|
KeepAliveInterval: cfg.Tunnel.HeartbeatInterval,
|
||||||
|
ConnectionWriteTimeout: 10 * time.Second,
|
||||||
|
})
|
||||||
|
// Internal server(面向 api-server)
|
||||||
|
internalSrv := relay.NewInternalServer(store, log)
|
||||||
|
|
||||||
|
// 對外 mux(tunnel port,面向 local agent)
|
||||||
|
tunnelMux := http.NewServeMux()
|
||||||
|
tunnelMux.HandleFunc("/tunnel/connect", relaySrv.HandleTunnelConnect)
|
||||||
|
tunnelMux.HandleFunc("/relay/status", relaySrv.HandleRelayStatus)
|
||||||
|
tunnelMux.HandleFunc("/healthz", healthzHandler)
|
||||||
|
|
||||||
|
// 內部 mux(internal port,面向 api-server)
|
||||||
|
internalMux := http.NewServeMux()
|
||||||
|
internalSrv.Routes(internalMux)
|
||||||
|
internalMux.HandleFunc("/healthz", healthzHandler)
|
||||||
|
|
||||||
|
tunnelAddr := net.JoinHostPort(cfg.Server.Host, strconv.Itoa(cfg.Server.TunnelPort))
|
||||||
|
internalAddr := net.JoinHostPort(cfg.Server.Host, strconv.Itoa(cfg.Server.InternalPort))
|
||||||
|
|
||||||
|
tunnelServer := &http.Server{
|
||||||
|
Addr: tunnelAddr,
|
||||||
|
Handler: tunnelMux,
|
||||||
|
// ReadHeaderTimeout 防 slow-loris(對齊 security.md)
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
internalServer := &http.Server{
|
||||||
|
Addr: internalAddr,
|
||||||
|
Handler: internalMux,
|
||||||
|
ReadHeaderTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup goroutine — 每 30s 掃一次過期 session
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
sessionCleanupLoop(ctx, store, cfg.Tunnel.IdleTimeout, log)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 啟動兩個 HTTP server
|
||||||
|
errCh := make(chan error, 2)
|
||||||
|
go func() {
|
||||||
|
log.Info("tunnel server listening",
|
||||||
|
"addr", tunnelAddr,
|
||||||
|
"keepalive_interval", cfg.Tunnel.HeartbeatInterval.String(),
|
||||||
|
"idle_timeout", cfg.Tunnel.IdleTimeout.String())
|
||||||
|
if err := tunnelServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
log.Info("internal server listening", "addr", internalAddr)
|
||||||
|
if err := internalServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
errCh <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等 signal 或錯誤
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
select {
|
||||||
|
case <-quit:
|
||||||
|
log.Info("shutdown signal received")
|
||||||
|
case err := <-errCh:
|
||||||
|
log.Error("server error, shutting down", "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Graceful shutdown
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
_ = tunnelServer.Shutdown(shutdownCtx)
|
||||||
|
_ = internalServer.Shutdown(shutdownCtx)
|
||||||
|
relaySrv.Shutdown()
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// 結束時關閉所有 session 釋放資源
|
||||||
|
// B3 Review Minor #3 修補:原本用 CleanupExpired(ctx, 0) 當「清掉全部」的 flag
|
||||||
|
// 語意隱晦。改用明確命名的 helper,讓意圖清楚。
|
||||||
|
if removed, err := closeAllSessions(shutdownCtx, store); err != nil {
|
||||||
|
log.Warn("final session cleanup failed", "error", err, "removed", removed)
|
||||||
|
} else if removed > 0 {
|
||||||
|
log.Info("final session cleanup done", "removed", removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("remote-proxy stopped")
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeAllSessions 在關機時關閉所有 active session。
|
||||||
|
//
|
||||||
|
// 實作上仍複用 `store.CleanupExpired(ctx, 0)`(cutoff = now,幾乎所有
|
||||||
|
// LastHeartbeat.Before(now) 為 true),但把「0 表示清全部」這個
|
||||||
|
// 隱晦 convention 包在 helper 裡,讓 main.go 的意圖清晰。
|
||||||
|
//
|
||||||
|
// B3 Review Minor #3 修補:避免日後 CleanupExpired 若改語意(如「0 = 不清」)
|
||||||
|
// 造成 shutdown 靜默失敗。
|
||||||
|
func closeAllSessions(ctx context.Context, store session.Store) (int, error) {
|
||||||
|
return store.CleanupExpired(ctx, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// healthzHandler 簡易健康檢查 — K8s liveness / readiness 用。
|
||||||
|
func healthzHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"})
|
||||||
|
}
|
||||||
|
|
||||||
|
// sessionCleanupLoop 週期性呼叫 store.CleanupExpired。
|
||||||
|
//
|
||||||
|
// 行為對齊 tunnel.md §4.2:每 30s 掃一次,idleTimeout 預設 30s。
|
||||||
|
func sessionCleanupLoop(ctx context.Context, store session.Store, idleTimeout time.Duration, log *slog.Logger) {
|
||||||
|
if idleTimeout <= 0 {
|
||||||
|
log.Warn("idle_timeout 設為 0 或負數,停用 session cleanup")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(sessionCleanupInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
log.Info("session cleanup loop started",
|
||||||
|
"interval", sessionCleanupInterval.String(),
|
||||||
|
"idle_timeout", idleTimeout.String())
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
removed, err := store.CleanupExpired(ctx, idleTimeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("session cleanup failed", "error", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if removed > 0 {
|
||||||
|
log.Info("session cleanup removed expired sessions", "count", removed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
0
visionA-backend/docker/.gitkeep
Normal file
0
visionA-backend/docker/.gitkeep
Normal file
72
visionA-backend/docker/Dockerfile.api-server
Normal file
72
visionA-backend/docker/Dockerfile.api-server
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
# syntax=docker/dockerfile:1.6
|
||||||
|
#
|
||||||
|
# visionA-backend / api-server — multi-stage Docker image
|
||||||
|
#
|
||||||
|
# 設計原則(對齊 build-deploy.md §2 與 backend CLAUDE.md §9):
|
||||||
|
# - Multi-stage:builder 階段負責編譯,runtime 階段只帶 binary(image 最小化)
|
||||||
|
# - CGO_ENABLED=0:產出 static binary,可直接放進 alpine/distroless
|
||||||
|
# - Non-root user:降低 container escape 風險
|
||||||
|
# - HEALTHCHECK:container 層級健康檢查(K8s / docker-compose 會用到)
|
||||||
|
#
|
||||||
|
# Build:
|
||||||
|
# docker build -f docker/Dockerfile.api-server -t visiona/api-server:dev .
|
||||||
|
# 執行目錄為 visionA-backend/,因此 COPY . . 會把整個 backend 帶進 builder。
|
||||||
|
|
||||||
|
# ---- Stage 1: builder ----------------------------------------------------
|
||||||
|
FROM golang:1.26-alpine AS builder
|
||||||
|
|
||||||
|
# git 給 go mod download 用(部分 module path 會需要)
|
||||||
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
|
WORKDIR /src
|
||||||
|
|
||||||
|
# 先 COPY go.mod / go.sum,讓 dependency layer 可以被 cache(只有改依賴才重跑)
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
# 複製其餘原始碼
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# 編譯 api-server:
|
||||||
|
# - CGO_ENABLED=0:pure Go static binary(alpine 可以直接跑)
|
||||||
|
# - -ldflags="-s -w":strip debug info,縮小 binary 大小
|
||||||
|
# - -trimpath:去掉原始碼路徑,避免洩漏 builder 主機資訊
|
||||||
|
ENV CGO_ENABLED=0 GOOS=linux
|
||||||
|
RUN go build -trimpath -ldflags="-s -w" -o /out/api-server ./cmd/api-server
|
||||||
|
|
||||||
|
# ---- Stage 2: runtime ----------------------------------------------------
|
||||||
|
FROM alpine:3.19
|
||||||
|
|
||||||
|
# 安裝 curl 給 HEALTHCHECK 用 + ca-certificates 給未來 HTTPS out-bound 用
|
||||||
|
# (Phase 0 雛形沒 outbound HTTPS,但預裝不增加太多體積,避免日後踩雷)
|
||||||
|
RUN apk add --no-cache ca-certificates curl tzdata && \
|
||||||
|
addgroup -S -g 1001 visiona && \
|
||||||
|
adduser -S -u 1001 -G visiona visiona
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 建立 storage 目錄(讓 LocalFS backend 預設路徑可寫)
|
||||||
|
RUN mkdir -p /app/data/storage && chown -R visiona:visiona /app
|
||||||
|
|
||||||
|
# 複製 binary
|
||||||
|
COPY --from=builder --chown=visiona:visiona /out/api-server /app/api-server
|
||||||
|
|
||||||
|
# 切到非 root
|
||||||
|
USER visiona:visiona
|
||||||
|
|
||||||
|
# api-server 預設 listen 3721(對齊 local-tool,見 config/config.go ServerConfig.Port)
|
||||||
|
EXPOSE 3721
|
||||||
|
|
||||||
|
# 預設環境變數:容器內的儲存路徑。
|
||||||
|
# 實際部署時由 docker-compose 或 K8s ConfigMap / Secret 覆蓋。
|
||||||
|
ENV VISIONA_HOST=0.0.0.0 \
|
||||||
|
VISIONA_API_PORT=3721 \
|
||||||
|
VISIONA_STORAGE_LOCALFS_ROOT=/app/data/storage \
|
||||||
|
VISIONA_LOG_LEVEL=info
|
||||||
|
|
||||||
|
# Container 層級 healthcheck — docker / compose 會用。
|
||||||
|
# 30s 週期、3s timeout、連續 3 次失敗視為 unhealthy。
|
||||||
|
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||||
|
CMD curl -fsS http://localhost:3721/healthz || exit 1
|
||||||
|
|
||||||
|
ENTRYPOINT ["/app/api-server"]
|
||||||
52
visionA-backend/docker/Dockerfile.remote-proxy
Normal file
52
visionA-backend/docker/Dockerfile.remote-proxy
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
# syntax=docker/dockerfile:1.6
|
||||||
|
#
|
||||||
|
# visionA-backend / remote-proxy — multi-stage Docker image
|
||||||
|
#
|
||||||
|
# 設計原則同 Dockerfile.api-server(見該檔 header)。
|
||||||
|
# 唯一差別:
|
||||||
|
# - build 的是 ./cmd/remote-proxy
|
||||||
|
# - 對外 expose 3800(tunnel WS,local agent 用)+ 3801(internal HTTP,api-server 用)
|
||||||
|
# - HEALTHCHECK 打 tunnel port 的 /healthz
|
||||||
|
|
||||||
|
# ---- Stage 1: builder ----------------------------------------------------
|
||||||
|
FROM golang:1.26-alpine AS builder
|
||||||
|
|
||||||
|
RUN apk add --no-cache git ca-certificates
|
||||||
|
|
||||||
|
WORKDIR /src
|
||||||
|
|
||||||
|
COPY go.mod go.sum ./
|
||||||
|
RUN go mod download
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
ENV CGO_ENABLED=0 GOOS=linux
|
||||||
|
RUN go build -trimpath -ldflags="-s -w" -o /out/remote-proxy ./cmd/remote-proxy
|
||||||
|
|
||||||
|
# ---- Stage 2: runtime ----------------------------------------------------
|
||||||
|
FROM alpine:3.19
|
||||||
|
|
||||||
|
RUN apk add --no-cache ca-certificates curl tzdata && \
|
||||||
|
addgroup -S -g 1001 visiona && \
|
||||||
|
adduser -S -u 1001 -G visiona visiona
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder --chown=visiona:visiona /out/remote-proxy /app/remote-proxy
|
||||||
|
|
||||||
|
USER visiona:visiona
|
||||||
|
|
||||||
|
# 3800:tunnel server(面向 local agent,WebSocket upgrade)
|
||||||
|
# 3801:internal HTTP(面向 api-server,同 compose network 內互通)
|
||||||
|
EXPOSE 3800 3801
|
||||||
|
|
||||||
|
ENV VISIONA_HOST=0.0.0.0 \
|
||||||
|
VISIONA_TUNNEL_PORT=3800 \
|
||||||
|
VISIONA_PROXY_INTERNAL_PORT=3801 \
|
||||||
|
VISIONA_LOG_LEVEL=info
|
||||||
|
|
||||||
|
# Healthcheck 打 tunnel listener 的 /healthz(internal port 雖然也有但不對外)
|
||||||
|
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
||||||
|
CMD curl -fsS http://localhost:3800/healthz || exit 1
|
||||||
|
|
||||||
|
ENTRYPOINT ["/app/remote-proxy"]
|
||||||
92
visionA-backend/docker/docker-compose.yml
Normal file
92
visionA-backend/docker/docker-compose.yml
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
# visionA-backend docker-compose
|
||||||
|
#
|
||||||
|
# 對應:.autoflow/04-architecture/build-deploy.md §4
|
||||||
|
#
|
||||||
|
# 服務拓撲(Phase 0 雛形):
|
||||||
|
#
|
||||||
|
# browser ──(3721)──▶ api-server ──(internal:3801)──▶ remote-proxy
|
||||||
|
# ▲
|
||||||
|
# │
|
||||||
|
# local agent ──(WS:3800)──┘
|
||||||
|
#
|
||||||
|
# - api-server 無狀態;session state 全在 remote-proxy in-memory
|
||||||
|
# - 兩個 service 用 compose 的 default bridge network(service 名稱互通)
|
||||||
|
# 因此 api-server 用 `http://remote-proxy:3801` 打 internal API
|
||||||
|
#
|
||||||
|
# 使用:
|
||||||
|
# cd visionA-backend
|
||||||
|
# cp .env.example .env # 首次
|
||||||
|
# docker compose -f docker/docker-compose.yml up -d
|
||||||
|
# docker compose -f docker/docker-compose.yml logs -f
|
||||||
|
# docker compose -f docker/docker-compose.yml down
|
||||||
|
|
||||||
|
services:
|
||||||
|
api-server:
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: docker/Dockerfile.api-server
|
||||||
|
image: visiona/api-server:dev
|
||||||
|
container_name: visiona-api-server
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
# 對外暴露給瀏覽器 / curl
|
||||||
|
- "${VISIONA_API_PORT:-3721}:3721"
|
||||||
|
env_file:
|
||||||
|
- ../.env
|
||||||
|
environment:
|
||||||
|
# 容器內固定值(不要讓 .env 的 host-specific URL 汙染到容器)
|
||||||
|
# .env 裡的 VISIONA_PROXY_INTERNAL_URL 可被這裡覆蓋,改指向 compose service name
|
||||||
|
VISIONA_HOST: "0.0.0.0"
|
||||||
|
VISIONA_API_PORT: "3721"
|
||||||
|
VISIONA_PROXY_INTERNAL_URL: "http://remote-proxy:3801"
|
||||||
|
# api-server 的 storage BaseURL — 注意:生產會是 https://api.example.com/storage,
|
||||||
|
# 雛形 demo 直接用 host 的 API port(LocalFS presigned URL 從瀏覽器打回來)
|
||||||
|
VISIONA_STORAGE_LOCALFS_BASE_URL: "${VISIONA_STORAGE_BASE_URL:-http://localhost:3721/storage}"
|
||||||
|
volumes:
|
||||||
|
# 模型檔持久化 — 避免 container 重建時上傳過的模型消失
|
||||||
|
- ./data/storage:/app/data/storage
|
||||||
|
depends_on:
|
||||||
|
remote-proxy:
|
||||||
|
condition: service_healthy
|
||||||
|
networks:
|
||||||
|
- visiona-net
|
||||||
|
healthcheck:
|
||||||
|
# 覆蓋 Dockerfile 裡的預設(放寬 start_period 給冷啟動用)
|
||||||
|
test: ["CMD", "curl", "-fsS", "http://localhost:3721/healthz"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 3s
|
||||||
|
start_period: 10s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
remote-proxy:
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: docker/Dockerfile.remote-proxy
|
||||||
|
image: visiona/remote-proxy:dev
|
||||||
|
container_name: visiona-remote-proxy
|
||||||
|
restart: unless-stopped
|
||||||
|
ports:
|
||||||
|
# 3800:tunnel WS,對外(local agent 要能從 host 或外部連進來)
|
||||||
|
- "${VISIONA_TUNNEL_PORT:-3800}:3800"
|
||||||
|
# 3801 internal 不對外 — 只有同 compose network 的 api-server 會打
|
||||||
|
# 若本機要 debug internal API,可臨時 uncomment 下行:
|
||||||
|
# - "${VISIONA_PROXY_INTERNAL_PORT:-3801}:3801"
|
||||||
|
env_file:
|
||||||
|
- ../.env
|
||||||
|
environment:
|
||||||
|
VISIONA_HOST: "0.0.0.0"
|
||||||
|
VISIONA_TUNNEL_PORT: "3800"
|
||||||
|
VISIONA_PROXY_INTERNAL_PORT: "3801"
|
||||||
|
networks:
|
||||||
|
- visiona-net
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-fsS", "http://localhost:3800/healthz"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 3s
|
||||||
|
start_period: 10s
|
||||||
|
retries: 3
|
||||||
|
|
||||||
|
networks:
|
||||||
|
visiona-net:
|
||||||
|
driver: bridge
|
||||||
|
name: visiona-net
|
||||||
58
visionA-backend/go.mod
Normal file
58
visionA-backend/go.mod
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
module visiona-backend
|
||||||
|
|
||||||
|
go 1.26
|
||||||
|
|
||||||
|
// 依賴狀態:
|
||||||
|
// - B2 引入 testify(單元測試)
|
||||||
|
// - B3 引入 gorilla/websocket + hashicorp/yamux(relay + tunnel client + wsconn)
|
||||||
|
// - B4 引入 gin-gonic/gin + gin-contrib/cors + google/uuid(api-server router / middleware / id)
|
||||||
|
// 後續任務會加入:
|
||||||
|
// - github.com/go-playground/validator/v10 (B5 request validation)
|
||||||
|
// - github.com/aws/aws-sdk-go-v2 (可選,S3 儲存層)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/coreos/go-oidc/v3 v3.18.0
|
||||||
|
github.com/gin-contrib/cors v1.7.7
|
||||||
|
github.com/gin-gonic/gin v1.12.0
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.4
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
|
github.com/hashicorp/yamux v0.1.2
|
||||||
|
github.com/stretchr/testify v1.11.1
|
||||||
|
golang.org/x/oauth2 v0.36.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/bytedance/gopkg v0.1.3 // indirect
|
||||||
|
github.com/bytedance/sonic v1.15.0 // indirect
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 // indirect
|
||||||
|
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||||
|
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12 // indirect
|
||||||
|
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||||
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1 // indirect
|
||||||
|
github.com/goccy/go-json v0.10.5 // indirect
|
||||||
|
github.com/goccy/go-yaml v1.19.2 // indirect
|
||||||
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||||
|
github.com/leodido/go-urn v1.4.0 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/quic-go/qpack v0.6.0 // indirect
|
||||||
|
github.com/quic-go/quic-go v0.59.0 // indirect
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
|
github.com/ugorji/go/codec v1.3.1 // indirect
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0 // indirect
|
||||||
|
golang.org/x/arch v0.23.0 // indirect
|
||||||
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
|
golang.org/x/net v0.51.0 // indirect
|
||||||
|
golang.org/x/sys v0.41.0 // indirect
|
||||||
|
golang.org/x/text v0.35.0 // indirect
|
||||||
|
google.golang.org/protobuf v1.36.10 // indirect
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
|
)
|
||||||
111
visionA-backend/go.sum
Normal file
111
visionA-backend/go.sum
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
github.com/bytedance/gopkg v0.1.3 h1:TPBSwH8RsouGCBcMBktLt1AymVo2TVsBVCY4b6TnZ/M=
|
||||||
|
github.com/bytedance/gopkg v0.1.3/go.mod h1:576VvJ+eJgyCzdjS+c4+77QF3p7ubbtiKARP3TxducM=
|
||||||
|
github.com/bytedance/sonic v1.15.0 h1:/PXeWFaR5ElNcVE84U0dOHjiMHQOwNIx3K4ymzh/uSE=
|
||||||
|
github.com/bytedance/sonic v1.15.0/go.mod h1:tFkWrPz0/CUCLEF4ri4UkHekCIcdnkqXw9VduqpJh0k=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0 h1:gXH3KVnatgY7loH5/TkeVyXPfESoqSBSBEiDd5VjlgE=
|
||||||
|
github.com/bytedance/sonic/loader v0.5.0/go.mod h1:AR4NYCk5DdzZizZ5djGqQ92eEhCCcdf5x77udYiSJRo=
|
||||||
|
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||||
|
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||||
|
github.com/coreos/go-oidc/v3 v3.18.0 h1:V9orjXynvu5wiC9SemFTWnG4F45v403aIcjWo0d41+A=
|
||||||
|
github.com/coreos/go-oidc/v3 v3.18.0/go.mod h1:DYCf24+ncYi+XkIH97GY1+dqoRlbaSI26KVTCI9SrY4=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12 h1:e9hWvmLYvtp846tLHam2o++qitpguFiYCKbn0w9jyqw=
|
||||||
|
github.com/gabriel-vasile/mimetype v1.4.12/go.mod h1:d+9Oxyo1wTzWdyVUPMmXFvp4F9tea18J8ufA774AB3s=
|
||||||
|
github.com/gin-contrib/cors v1.7.7 h1:Oh9joP463x7Mw72vhvJ61YQm8ODh9b04YR7vsOErD0Q=
|
||||||
|
github.com/gin-contrib/cors v1.7.7/go.mod h1:K5tW0RkzJtWSiOdikXloy8VEZlgdVNpHNw8FpjUPNrE=
|
||||||
|
github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w=
|
||||||
|
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||||
|
github.com/gin-gonic/gin v1.12.0 h1:b3YAbrZtnf8N//yjKeU2+MQsh2mY5htkZidOM7O0wG8=
|
||||||
|
github.com/gin-gonic/gin v1.12.0/go.mod h1:VxccKfsSllpKshkBWgVgRniFFAzFb9csfngsqANjnLc=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA=
|
||||||
|
github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||||
|
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||||
|
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||||
|
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||||
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1 h1:f3zDSN/zOma+w6+1Wswgd9fLkdwy06ntQJp0BBvFG0w=
|
||||||
|
github.com/go-playground/validator/v10 v10.30.1/go.mod h1:oSuBIQzuJxL//3MelwSLD5hc2Tu889bF0Idm9Dg26cM=
|
||||||
|
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||||
|
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||||
|
github.com/goccy/go-yaml v1.19.2 h1:PmFC1S6h8ljIz6gMRBopkjP1TVT7xuwrButHID66PoM=
|
||||||
|
github.com/goccy/go-yaml v1.19.2/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||||
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
|
github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8=
|
||||||
|
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
|
||||||
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||||
|
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
|
||||||
|
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||||
|
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
|
||||||
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
|
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||||
|
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||||
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
|
||||||
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||||
|
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8=
|
||||||
|
github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII=
|
||||||
|
github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw=
|
||||||
|
github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
|
||||||
|
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||||
|
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||||
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
|
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
|
github.com/ugorji/go/codec v1.3.1 h1:waO7eEiFDwidsBN6agj1vJQ4AG7lh2yqXyOXqhgQuyY=
|
||||||
|
github.com/ugorji/go/codec v1.3.1/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0 h1:yXUhImUjjAInNcpTcAlPHiT7bIXhshCTL3jVBkF3xaE=
|
||||||
|
go.mongodb.org/mongo-driver/v2 v2.5.0/go.mod h1:yOI9kBsufol30iFsl1slpdq1I0eHPzybRWdyYUs8K/0=
|
||||||
|
go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y=
|
||||||
|
go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU=
|
||||||
|
golang.org/x/arch v0.23.0 h1:lKF64A2jF6Zd8L0knGltUnegD62JMFBiCPBmQpToHhg=
|
||||||
|
golang.org/x/arch v0.23.0/go.mod h1:dNHoOeKiyja7GTvF9NJS1l3Z2yntpQNzgrjh1cU103A=
|
||||||
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
|
golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo=
|
||||||
|
golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y=
|
||||||
|
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||||
|
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
|
||||||
|
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
|
||||||
|
google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE=
|
||||||
|
google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||||
|
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||||
|
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
0
visionA-backend/internal/api/.gitkeep
Normal file
0
visionA-backend/internal/api/.gitkeep
Normal file
181
visionA-backend/internal/api/api.go
Normal file
181
visionA-backend/internal/api/api.go
Normal file
@ -0,0 +1,181 @@
|
|||||||
|
// Package api 實作 api-server 的 REST + WebSocket 入口。
|
||||||
|
//
|
||||||
|
// 對齊 `.autoflow/04-architecture/api/api-spec.md`。
|
||||||
|
//
|
||||||
|
// **B4 範圍**:Router / Middleware / 結構化錯誤回應骨架 + 少數 handler(/healthz、
|
||||||
|
// /api/system/health、/api/system/info、/api/pairing/token、/api/pairing/status)。
|
||||||
|
//
|
||||||
|
// **B5 範圍**(本檔):
|
||||||
|
// - Auth:login / logout / me(stub),register → 501
|
||||||
|
// - Pairing:list tokens / revoke token
|
||||||
|
// - Devices:list / get(讀雲端 repo + 合併 tunnel 狀態),scan / connect /
|
||||||
|
// disconnect / flash / inference.start/stop 走 proxy,unpair 軟刪
|
||||||
|
// - Models:list / get / init upload / finalize / delete
|
||||||
|
// - System:/system/deps(走 proxy)
|
||||||
|
// - Clusters:GET /clusters 回空陣列;其他 stub
|
||||||
|
// - Storage:/storage/* 的 LocalFS 假 presigned URL 代理(GET/PUT)
|
||||||
|
// - WebSocket:保留 501 stub(詳見 stubs.go;B7 TODO)
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/converter"
|
||||||
|
"visiona-backend/internal/device"
|
||||||
|
"visiona-backend/internal/model"
|
||||||
|
"visiona-backend/internal/oidc"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
"visiona-backend/internal/storage"
|
||||||
|
"visiona-backend/internal/usersession"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Deps 匯整 api router 所需的所有依賴;由 cmd/api-server/main.go 在啟動時注入。
|
||||||
|
//
|
||||||
|
// 之所以集中在一個 struct,是為了:
|
||||||
|
// 1. 讓 NewRouter 簽章穩定(之後加新依賴只改 struct,不破壞既有 caller)
|
||||||
|
// 2. integration test 容易組裝(只需建一個 Deps 物件)
|
||||||
|
// 3. 各 handler 透過 closure 取得依賴,不需要 global state
|
||||||
|
//
|
||||||
|
// OB5(2026-04-26)起 OIDC 是唯一認證路徑:OIDCProvider + SessionManager 必填,
|
||||||
|
// validate() 在 NewRouter 啟動時就會 panic 提前暴露 misconfiguration。
|
||||||
|
type Deps struct {
|
||||||
|
Logger *slog.Logger
|
||||||
|
|
||||||
|
// PairingStore 管理 Pairing Token 生命週期。為 nil 時 /api/pairing/* 會回 501。
|
||||||
|
PairingStore auth.PairingStore
|
||||||
|
|
||||||
|
// ─── OIDC(OB5 起為必填) ───
|
||||||
|
|
||||||
|
// OIDCProvider 封裝 OIDC client(authorization URL 組裝、token exchange、id_token 驗證)。
|
||||||
|
OIDCProvider oidc.Provider
|
||||||
|
|
||||||
|
// SessionManager 管理 cookie session(StartSession / GetSession / EndSession)。
|
||||||
|
SessionManager *usersession.Manager
|
||||||
|
|
||||||
|
// OIDCPostLoginURL 是 callback 完成後 302 回 frontend 的 base URL。
|
||||||
|
// 例:http://localhost:3000(dev)/ https://app.visiona.cloud(prod)。
|
||||||
|
// 為空字串時 callback handler 會 fallback 到 same-origin "/"(不建議生產配置)。
|
||||||
|
OIDCPostLoginURL string
|
||||||
|
|
||||||
|
SessionStore session.Store
|
||||||
|
Forwarder *session.Forwarder
|
||||||
|
|
||||||
|
DeviceRepo device.Repository
|
||||||
|
ModelRepo model.Repository
|
||||||
|
|
||||||
|
Storage storage.Store
|
||||||
|
Converter converter.Client
|
||||||
|
|
||||||
|
// CORSAllowedOrigins 是允許的瀏覽器 Origin 白名單;空 slice 預設放行
|
||||||
|
// http://localhost:3000(前端 dev server)。
|
||||||
|
CORSAllowedOrigins []string
|
||||||
|
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
// StaticUserID 欄位已移除:multi-tenant 環境下 fallback 到固定 user 是 latent multi-user
|
||||||
|
// 隔離破口(OWASP A01 + A04)。改 OIDC 後 AuthMiddleware 會擋下未登入請求,
|
||||||
|
// handler 拿不到 UserContext 一律 500(safer than silent fallback)。
|
||||||
|
// dev seed / unit test 仍可獨立讀 cfg.Auth.StaticUserID env,不再注入 Deps。
|
||||||
|
|
||||||
|
// MaxUploadSizeMB 是模型上傳大小上限(MB);0 代表不限(測試友善)。
|
||||||
|
// 對齊 feature-model-management.md:Phase 0 預設 100 MB(由 config.Model.MaxSizeMB 注入)。
|
||||||
|
MaxUploadSizeMB int
|
||||||
|
|
||||||
|
// SessionTokenStore 保存 Pairing → Session 交換後發出的 Session Token。
|
||||||
|
// Phase 0 雛形用 in-memory 實作(由 main.go 注入);Phase 1 改為 DB-backed。
|
||||||
|
// 為 nil 時 /api/pairing/exchange 會回 501 NOT_IMPLEMENTED。
|
||||||
|
SessionTokenStore auth.SessionTokenStore
|
||||||
|
|
||||||
|
// RelayPublicURL 是 agent 連 tunnel 用的 WSS URL(對外可訪問)。
|
||||||
|
// 由 `POST /api/pairing/exchange` 回給 agent;若為空會回預設 `wss://relay.visionA.cloud`(雛形 placeholder)。
|
||||||
|
// 對齊 build-deploy.md 的 VISIONA_RELAY_PUBLIC_URL 環境變數。
|
||||||
|
RelayPublicURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate 確認必要欄位都有;在 NewRouter 啟動時呼叫,避免 nil pointer panic 推到 runtime。
|
||||||
|
//
|
||||||
|
// 嚴格欄位(缺則 panic — fail fast,避免半設定狀態跑進生產):
|
||||||
|
// - OIDCProvider — OB5 起 OIDC 是唯一認證路徑
|
||||||
|
// - SessionManager — OIDC cookie session 必須
|
||||||
|
//
|
||||||
|
// 寬鬆欄位(缺有預設):Logger / CORSAllowedOrigins
|
||||||
|
//
|
||||||
|
// 其他欄位(PairingStore / SessionStore 等)若為 nil 不擋 — 個別 handler 會回 501,
|
||||||
|
// 允許「最小骨架」啟動跑 /healthz。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security fix C1:移除 StaticUserID 預設 "demo-user" 的 fallback。
|
||||||
|
func (d *Deps) validate() {
|
||||||
|
if d.Logger == nil {
|
||||||
|
d.Logger = slog.Default()
|
||||||
|
}
|
||||||
|
if len(d.CORSAllowedOrigins) == 0 {
|
||||||
|
d.CORSAllowedOrigins = []string{"http://localhost:3000"}
|
||||||
|
}
|
||||||
|
if d.OIDCProvider == nil {
|
||||||
|
panic("api.NewRouter: Deps.OIDCProvider is required (OB5: OIDC is the only auth path)")
|
||||||
|
}
|
||||||
|
if d.SessionManager == nil {
|
||||||
|
panic("api.NewRouter: Deps.SessionManager is required (OB5: OIDC cookie session is mandatory)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter 建立 Gin engine 並註冊所有路由與中介層。
|
||||||
|
//
|
||||||
|
// 為何回 *gin.Engine 而非 http.Handler:cmd/api-server/main.go 需要 access
|
||||||
|
// engine.Run(也可拿 .Handler() 給標準 http.Server 用,所以這個選擇沒讓
|
||||||
|
// caller 失去彈性)。
|
||||||
|
func NewRouter(deps Deps) *gin.Engine {
|
||||||
|
deps.validate()
|
||||||
|
|
||||||
|
// gin 的 ReleaseMode 由 caller 視環境設定(cmd/api-server/main.go);
|
||||||
|
// 這裡不主動設,避免測試環境被汙染。
|
||||||
|
r := gin.New()
|
||||||
|
|
||||||
|
// 註冊全域 middleware(順序很重要:Recovery 第一,Logger 接著,CORS 之後)
|
||||||
|
r.Use(RecoveryMiddleware(deps.Logger))
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(LoggerMiddleware(deps.Logger))
|
||||||
|
r.Use(CORSMiddleware(deps.CORSAllowedOrigins))
|
||||||
|
r.Use(ErrorMiddleware()) // 統一把 c.Errors 轉成 JSON
|
||||||
|
|
||||||
|
// /healthz 不需要 auth — K8s liveness/readiness 用
|
||||||
|
r.GET("/healthz", HealthzHandler())
|
||||||
|
|
||||||
|
// /storage/* 不走 AuthMiddleware(改用 HMAC 簽章)— 對齊 api-spec.md §10
|
||||||
|
registerStorageRoutes(r, deps)
|
||||||
|
|
||||||
|
// /api/pairing/exchange 刻意不走 AuthMiddleware:
|
||||||
|
// agent 尚未有 session token 時就得用 Pairing Token 換 Session Token,
|
||||||
|
// Pairing Token 本身就是這個 endpoint 的憑證。詳見 security.md §1.2。
|
||||||
|
registerPairingPublicRoutes(r, deps)
|
||||||
|
|
||||||
|
// /ws/* 雛形全部 501;B7 補齊 WebSocket proxy
|
||||||
|
registerWebSocketStubs(r)
|
||||||
|
|
||||||
|
// OIDC public routes(不走 AuthMiddleware):
|
||||||
|
// - GET /api/auth/login — 起始登入流程(user 還沒登入)
|
||||||
|
// - GET /api/auth/callback — OIDC IdP 302 回來
|
||||||
|
// 必須註冊在 AuthMiddleware 群組之外,否則使用者沒登入根本進不去。
|
||||||
|
registerOIDCPublicRoutes(r, deps)
|
||||||
|
|
||||||
|
// /api 群組:所有路由都走 OIDC AuthMiddleware(cookie session → UserContext)
|
||||||
|
apiGroup := r.Group("/api")
|
||||||
|
apiGroup.Use(AuthMiddleware(deps))
|
||||||
|
|
||||||
|
// B4 核心
|
||||||
|
registerSystemRoutes(apiGroup, deps)
|
||||||
|
registerPairingRoutes(apiGroup, deps)
|
||||||
|
|
||||||
|
// B5 新增:實際 handler
|
||||||
|
registerAuthRoutes(apiGroup, deps)
|
||||||
|
registerDeviceRoutes(apiGroup, deps)
|
||||||
|
registerModelRoutes(apiGroup, deps)
|
||||||
|
registerClusterRoutes(apiGroup, deps)
|
||||||
|
|
||||||
|
// Stubs(只註冊「還沒有實際 handler」的那些 endpoint)
|
||||||
|
registerStubRoutes(apiGroup, deps)
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
54
visionA-backend/internal/api/auth.go
Normal file
54
visionA-backend/internal/api/auth.go
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
// auth.go — /api/auth/* 的 handler 註冊。
|
||||||
|
//
|
||||||
|
// OB5(2026-04-26)起,OIDC 是唯一認證路徑:
|
||||||
|
// - GET /api/auth/login → 302 to IdP(registerOIDCPublicRoutes 在 apiGroup 之外)
|
||||||
|
// - GET /api/auth/callback → token exchange + cookie session(同上)
|
||||||
|
// - POST /api/auth/login → 410 Gone(指引使用者改用 GET)
|
||||||
|
// - POST /api/auth/logout → 清 cookie session(idempotent)
|
||||||
|
// - GET /api/auth/me → 從 cookie session 拿 user info
|
||||||
|
// - POST /api/auth/register → 501(註冊去 Member Center)
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md §1(Auth)+ oidc-tdd.md §3 / §4.5。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerAuthRoutes 註冊 /api/auth/* 的 routes(OIDC 模式,唯一路徑)。
|
||||||
|
//
|
||||||
|
// 注意:GET /api/auth/login 與 /api/auth/callback 是「不需登入即可呼叫」的
|
||||||
|
// public endpoint,由 NewRouter 透過 registerOIDCPublicRoutes 直接註冊在 r 上
|
||||||
|
// (不在 apiGroup 中),不在這裡。
|
||||||
|
func registerAuthRoutes(g *gin.RouterGroup, deps Deps) {
|
||||||
|
g.POST("/auth/login", oidcLoginNotSupportedHandler())
|
||||||
|
g.POST("/auth/logout", oidcLogoutHandler(deps))
|
||||||
|
g.GET("/auth/me", oidcMeHandler(deps))
|
||||||
|
g.POST("/auth/register", authRegisterHandler())
|
||||||
|
}
|
||||||
|
|
||||||
|
// oidcLoginNotSupportedHandler 回 410 Gone,告訴 caller 改用 GET /api/auth/login。
|
||||||
|
//
|
||||||
|
// 為什麼選 410 而非 405:
|
||||||
|
// - 405 Method Not Allowed 暗示「同 URL 換 method 就行」 — 但語意上 OIDC login
|
||||||
|
// 是 redirect flow,不只是 method 換掉。
|
||||||
|
// - 410 Gone 明確表示「此資源已不存在於此 URL/method」,並可在訊息裡指引到正確端點。
|
||||||
|
func oidcLoginNotSupportedHandler() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
WriteError(c, http.StatusGone, ErrCodeNotImplemented,
|
||||||
|
"OIDC mode: use GET /api/auth/login to start the redirect flow", nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// authRegisterHandler 實作 POST /api/auth/register。
|
||||||
|
//
|
||||||
|
// OIDC 模式下 visionA 不負責註冊 — 註冊是 Member Center 的職責。
|
||||||
|
// 一律回 501,前端可顯示「請至 Member Center 註冊」。
|
||||||
|
func authRegisterHandler() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
WriteNotImplemented(c, "auth.register — registration is handled by Innovedus Member Center")
|
||||||
|
}
|
||||||
|
}
|
||||||
46
visionA-backend/internal/api/auth_test.go
Normal file
46
visionA-backend/internal/api/auth_test.go
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAuthLogin_OIDCMode_Returns410 驗證 POST /api/auth/login 在 OIDC 模式下回 410。
|
||||||
|
//
|
||||||
|
// OIDC 模式只接受 GET /api/auth/login(redirect flow),POST 一律 410 並指引使用者
|
||||||
|
// 改用 GET。完整 OIDC flow 測試見 oidc_auth_test.go。
|
||||||
|
func TestAuthLogin_OIDCMode_Returns410(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerAuthRoutes(g, Deps{})
|
||||||
|
|
||||||
|
body := strings.NewReader(`{"email":"foo","password":"bar"}`)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusGone, w.Code, "POST /api/auth/login 應回 410 Gone")
|
||||||
|
assert.Contains(t, w.Body.String(), "GET /api/auth/login")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAuthRegister_Returns501 驗證雛形不做註冊(永遠 501)。
|
||||||
|
//
|
||||||
|
// OIDC 模式下註冊由 Member Center 負責,visionA 不接這條。
|
||||||
|
func TestAuthRegister_Returns501(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerAuthRoutes(g, Deps{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/auth/register", nil))
|
||||||
|
assert.Equal(t, http.StatusNotImplemented, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "Member Center")
|
||||||
|
}
|
||||||
25
visionA-backend/internal/api/clusters.go
Normal file
25
visionA-backend/internal/api/clusters.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
// clusters.go — /api/clusters/* handler。
|
||||||
|
//
|
||||||
|
// **Phase 0 / B5 範圍很窄**(對齊 PM 對「叢集推論」Phase 1 才做深入的裁決):
|
||||||
|
// - GET /api/clusters 回空陣列,讓前端能完成基本渲染
|
||||||
|
// - 其他 endpoint 保持 501(由 stubs.go 負責)
|
||||||
|
//
|
||||||
|
// 詳細規格見 api-spec.md §5。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerClusterRoutes 註冊 /api/clusters 的「雛形可呼叫」endpoint。
|
||||||
|
//
|
||||||
|
// 目前只有 GET /api/clusters 回空陣列;其他 endpoint 的 stub 由 stubs.go 註冊
|
||||||
|
// (本 function 只覆寫需要有 body 的那一條,避免衝突)。
|
||||||
|
func registerClusterRoutes(g *gin.RouterGroup, _ Deps) {
|
||||||
|
g.GET("/clusters", func(c *gin.Context) {
|
||||||
|
WriteSuccess(c, http.StatusOK, []any{})
|
||||||
|
})
|
||||||
|
}
|
||||||
292
visionA-backend/internal/api/devices.go
Normal file
292
visionA-backend/internal/api/devices.go
Normal file
@ -0,0 +1,292 @@
|
|||||||
|
// devices.go — /api/devices/* 的 handler 實作。
|
||||||
|
//
|
||||||
|
// 雛形分兩種資料來源:
|
||||||
|
// 1. 純雲端(讀 DeviceRepo):GET /api/devices、GET /api/devices/:id
|
||||||
|
// — 回報使用者已配對的裝置清單,合併即時 tunnel 連線狀態
|
||||||
|
// 2. 走 tunnel proxy(呼叫 local agent):scan / connect / disconnect / flash / inference
|
||||||
|
// — 這些操作實際執行在 local agent(USB 插的那台機器)
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md §3 + feature-device-management.md。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/device"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerDeviceRoutes 註冊 /api/devices/* 的 routes。
|
||||||
|
func registerDeviceRoutes(g *gin.RouterGroup, deps Deps) {
|
||||||
|
// 純雲端讀取類
|
||||||
|
g.GET("/devices", devicesListHandler(deps))
|
||||||
|
g.GET("/devices/:id", devicesGetHandler(deps))
|
||||||
|
|
||||||
|
// 走 tunnel proxy 的操作類
|
||||||
|
proxy := newProxyHandler(deps, proxyOptions{})
|
||||||
|
g.POST("/devices/scan", proxy)
|
||||||
|
g.POST("/devices/:id/connect", proxy)
|
||||||
|
g.POST("/devices/:id/disconnect", proxy)
|
||||||
|
g.POST("/devices/:id/flash", proxy)
|
||||||
|
g.POST("/devices/:id/inference/start", proxy)
|
||||||
|
g.POST("/devices/:id/inference/stop", proxy)
|
||||||
|
|
||||||
|
// Unpair(雛形實作:軟刪 DeviceRepo + CloseSession)
|
||||||
|
g.POST("/devices/:id/unpair", devicesUnpairHandler(deps))
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceListItem 是 GET /api/devices 回應中的單筆裝置。
|
||||||
|
//
|
||||||
|
// 合併雲端 DeviceRepo 的 metadata 與 Session 狀態(tunnel_online):
|
||||||
|
type DeviceListItem struct {
|
||||||
|
// 基本 metadata(來自 DeviceRepo)
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
DeviceType string `json:"device_type"`
|
||||||
|
SerialNumber string `json:"serial_number,omitempty"`
|
||||||
|
|
||||||
|
// 狀態
|
||||||
|
RemoteStatus string `json:"remote_status"`
|
||||||
|
LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
|
||||||
|
LastConnectedAt *time.Time `json:"last_connected_at,omitempty"`
|
||||||
|
USBStatus string `json:"status"` // USB-level
|
||||||
|
|
||||||
|
// Tunnel 即時狀態(若有)
|
||||||
|
TunnelOnline bool `json:"tunnel_online"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// devicesListHandler 實作 GET /api/devices。
|
||||||
|
//
|
||||||
|
// 行為:從 DeviceRepo 列出當前 user 的裝置,再合併 SessionStore 的 tunnel 狀態:
|
||||||
|
// - 若該 user 有 active session → tunnel_online = true,last_seen_at 從 session 更新
|
||||||
|
// - 無 active session → 仍列出,但 tunnel_online = false
|
||||||
|
//
|
||||||
|
// Phase 1 會改為 DB JOIN + presigned URL;雛形 in-memory 足夠。
|
||||||
|
func devicesListHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.DeviceRepo == nil {
|
||||||
|
WriteSuccess(c, http.StatusOK, []DeviceListItem{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
devices, err := deps.DeviceRepo.List(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"list devices failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查 tunnel 狀態(雛形:列全部 session 找當前 user 的;為空不致命)
|
||||||
|
tunnelAlive, lastSeen := resolveTunnelStatus(ctx, deps.SessionStore, userID)
|
||||||
|
|
||||||
|
out := make([]DeviceListItem, 0, len(devices))
|
||||||
|
for _, d := range devices {
|
||||||
|
item := DeviceListItem{
|
||||||
|
ID: d.ID,
|
||||||
|
Name: d.Name,
|
||||||
|
DeviceType: d.DeviceType,
|
||||||
|
SerialNumber: d.SerialNumber,
|
||||||
|
RemoteStatus: d.RemoteStatus,
|
||||||
|
LastSeenAt: d.LastSeenAt,
|
||||||
|
LastConnectedAt: d.LastConnectedAt,
|
||||||
|
USBStatus: d.Status,
|
||||||
|
TunnelOnline: tunnelAlive,
|
||||||
|
CreatedAt: d.CreatedAt,
|
||||||
|
UpdatedAt: d.UpdatedAt,
|
||||||
|
}
|
||||||
|
// 如果雲端沒記錄 LastSeenAt 但 tunnel 活著,就用 session 的 lastSeen 填
|
||||||
|
if item.LastSeenAt == nil && tunnelAlive && !lastSeen.IsZero() {
|
||||||
|
ls := lastSeen
|
||||||
|
item.LastSeenAt = &ls
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
WriteSuccess(c, http.StatusOK, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// devicesGetHandler 實作 GET /api/devices/:id。
|
||||||
|
//
|
||||||
|
// 雛形:直接從 DeviceRepo 讀(不走 tunnel)。ownership 檢查以 OwnerUserID 比對。
|
||||||
|
// 若要即時查 USB 狀態,前端可再打 POST /api/devices/:id/connect 等 proxy 端點。
|
||||||
|
func devicesGetHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.DeviceRepo == nil {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "device not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
if id == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed, "device id required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
d, err := deps.DeviceRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, device.ErrNotFound) {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "device not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"get device failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ownership 檢查(雛形單一 user,但仍守住這道)
|
||||||
|
if d.OwnerUserID != userID {
|
||||||
|
WriteError(c, http.StatusForbidden, ErrCodeForbidden,
|
||||||
|
"not owner of this device", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tunnelAlive, lastSeen := resolveTunnelStatus(ctx, deps.SessionStore, userID)
|
||||||
|
item := DeviceListItem{
|
||||||
|
ID: d.ID,
|
||||||
|
Name: d.Name,
|
||||||
|
DeviceType: d.DeviceType,
|
||||||
|
SerialNumber: d.SerialNumber,
|
||||||
|
RemoteStatus: d.RemoteStatus,
|
||||||
|
LastSeenAt: d.LastSeenAt,
|
||||||
|
LastConnectedAt: d.LastConnectedAt,
|
||||||
|
USBStatus: d.Status,
|
||||||
|
TunnelOnline: tunnelAlive,
|
||||||
|
CreatedAt: d.CreatedAt,
|
||||||
|
UpdatedAt: d.UpdatedAt,
|
||||||
|
}
|
||||||
|
if item.LastSeenAt == nil && tunnelAlive && !lastSeen.IsZero() {
|
||||||
|
ls := lastSeen
|
||||||
|
item.LastSeenAt = &ls
|
||||||
|
}
|
||||||
|
WriteSuccess(c, http.StatusOK, item)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// devicesUnpairHandler 實作 POST /api/devices/:id/unpair。
|
||||||
|
//
|
||||||
|
// 雛形行為:
|
||||||
|
// 1. 驗證 device ownership
|
||||||
|
// 2. 軟刪 DeviceRepo entry
|
||||||
|
// 3. 若該 user 有 active session → 發 CloseSession(best-effort)
|
||||||
|
//
|
||||||
|
// 真正的 Session Token 撤銷(Phase 1)需要 PairingStore/SessionTokenStore 支援。
|
||||||
|
func devicesUnpairHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.DeviceRepo == nil {
|
||||||
|
WriteNotImplemented(c, "device repo not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := c.Param("id")
|
||||||
|
if id == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed, "device id required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
d, err := deps.DeviceRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, device.ErrNotFound) {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "device not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"get device failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if d.OwnerUserID != userID {
|
||||||
|
WriteError(c, http.StatusForbidden, ErrCodeForbidden, "not owner", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 軟刪
|
||||||
|
if err := deps.DeviceRepo.Delete(ctx, id); err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"delete device failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// best-effort:關閉該 user 的 session(雛形單裝置假設)
|
||||||
|
if deps.SessionStore != nil {
|
||||||
|
if token, tokErr := pickActiveSessionToken(ctx, deps.SessionStore, userID, deps.Logger); tokErr == nil {
|
||||||
|
_ = deps.SessionStore.Unregister(ctx, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logOrDefault(deps.Logger).Info("devices: unpaired",
|
||||||
|
"device_id", id,
|
||||||
|
"user_id", userID,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
WriteSuccess(c, http.StatusOK, gin.H{"id": id, "unpaired": true})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveTunnelStatus 回報當前 user 是否有 active tunnel,以及最新心跳時間。
|
||||||
|
//
|
||||||
|
// 雛形單裝置假設:只看第一筆 match 的 session。多裝置時 Phase 1 擴充。
|
||||||
|
// 失敗一律 return (false, zero time) 不 raise — 給 list/get 用,不該因此 fail。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security audit M2:寬鬆比對暫保留待人工介入。
|
||||||
|
// 詳細理由見 pickActiveSessionToken 註解:relay 端 LocalHandle.Summary 不帶 UserID。
|
||||||
|
// 修復 caller (handler) 已先做 strict UserContext 檢查,userID 必非空。
|
||||||
|
func resolveTunnelStatus(ctx context.Context, store session.Store, userID string) (bool, time.Time) {
|
||||||
|
if store == nil || userID == "" {
|
||||||
|
return false, time.Time{}
|
||||||
|
}
|
||||||
|
summaries, err := store.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, time.Time{}
|
||||||
|
}
|
||||||
|
for _, s := range summaries {
|
||||||
|
// 寬鬆比對:暫接受 s.UserID == "" 直到 relay 端 backfill UserID(M2 待人工介入)。
|
||||||
|
if s.UserID == "" || s.UserID == userID {
|
||||||
|
return true, s.LastHeartbeat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false, time.Time{}
|
||||||
|
}
|
||||||
116
visionA-backend/internal/api/devices_test.go
Normal file
116
visionA-backend/internal/api/devices_test.go
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newDevicesFixture 建立 router 並塞好必要依賴(InMemory repo + fakeSessionStore)。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security fix C1:移除 Deps.StaticUserID(見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)。
|
||||||
|
// 改由 injectStaticUserContext 顯式注入 UserContext,handler 強制要求 UserContext 非空。
|
||||||
|
func newDevicesFixture(t *testing.T, sessions []any) *gin.Engine {
|
||||||
|
t.Helper()
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(injectStaticUserContext("demo-user", ""))
|
||||||
|
g := r.Group("/api")
|
||||||
|
_ = sessions // 暫用,下方 helper 內建
|
||||||
|
registerDeviceRoutes(g, Deps{
|
||||||
|
DeviceRepo: device.NewInMemoryRepository(),
|
||||||
|
SessionStore: &fakeSessionStore{}, // 無 session
|
||||||
|
})
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDevicesList_Empty 驗證沒 device 時回空陣列。
|
||||||
|
func TestDevicesList_Empty(t *testing.T) {
|
||||||
|
r := newDevicesFixture(t, nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/devices", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var sb SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &sb))
|
||||||
|
arr, ok := sb.Data.([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Empty(t, arr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDevicesList_ReturnsOwnDevicesOnly 驗證只回當前 user 的 device。
|
||||||
|
func TestDevicesList_ReturnsOwnDevicesOnly(t *testing.T) {
|
||||||
|
repo := device.NewInMemoryRepository()
|
||||||
|
ctx := context.Background()
|
||||||
|
now := time.Now().UTC()
|
||||||
|
require.NoError(t, repo.Save(ctx, &device.Device{
|
||||||
|
ID: "mine", OwnerUserID: "demo-user", Name: "A", DeviceType: "kl520",
|
||||||
|
RemoteStatus: device.RemoteStatusOnline, Status: device.USBStatusOnline,
|
||||||
|
CreatedAt: now,
|
||||||
|
}))
|
||||||
|
require.NoError(t, repo.Save(ctx, &device.Device{
|
||||||
|
ID: "theirs", OwnerUserID: "other", Name: "B", DeviceType: "kl520",
|
||||||
|
CreatedAt: now,
|
||||||
|
}))
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(injectStaticUserContext("demo-user", ""))
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerDeviceRoutes(g, Deps{
|
||||||
|
DeviceRepo: repo,
|
||||||
|
SessionStore: &fakeSessionStore{},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/devices", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var sb SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &sb))
|
||||||
|
arr := sb.Data.([]any)
|
||||||
|
require.Len(t, arr, 1, "只應看到自己的 device")
|
||||||
|
first := arr[0].(map[string]any)
|
||||||
|
assert.Equal(t, "mine", first["id"])
|
||||||
|
assert.Equal(t, false, first["tunnel_online"], "沒 session → tunnel_online=false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDevicesGet_NotOwner 驗證非 owner 被擋 403。
|
||||||
|
func TestDevicesGet_NotOwner(t *testing.T) {
|
||||||
|
repo := device.NewInMemoryRepository()
|
||||||
|
require.NoError(t, repo.Save(context.Background(), &device.Device{
|
||||||
|
ID: "x", OwnerUserID: "other", Name: "a", DeviceType: "kl520",
|
||||||
|
}))
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(injectStaticUserContext("demo-user", ""))
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerDeviceRoutes(g, Deps{
|
||||||
|
DeviceRepo: repo,
|
||||||
|
SessionStore: &fakeSessionStore{},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/devices/x", nil))
|
||||||
|
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDevicesGet_NotFound 驗證不存在回 404。
|
||||||
|
func TestDevicesGet_NotFound(t *testing.T) {
|
||||||
|
r := newDevicesFixture(t, nil)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/devices/ghost", nil))
|
||||||
|
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||||
|
}
|
||||||
86
visionA-backend/internal/api/errors.go
Normal file
86
visionA-backend/internal/api/errors.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 錯誤碼常數 — 對齊 api-spec.md §11。
|
||||||
|
const (
|
||||||
|
ErrCodeUnauthorized = "UNAUTHORIZED"
|
||||||
|
ErrCodeForbidden = "FORBIDDEN"
|
||||||
|
ErrCodeNotFound = "NOT_FOUND"
|
||||||
|
ErrCodeValidationFailed = "VALIDATION_FAILED"
|
||||||
|
ErrCodeTunnelDisconnect = "TUNNEL_DISCONNECTED"
|
||||||
|
ErrCodeTunnelError = "TUNNEL_ERROR"
|
||||||
|
ErrCodeNotImplemented = "NOT_IMPLEMENTED"
|
||||||
|
ErrCodeRateLimited = "RATE_LIMITED"
|
||||||
|
ErrCodeInternalError = "INTERNAL_ERROR"
|
||||||
|
// ErrCodePayloadTooLarge 對齊 HTTP 413(例:模型上傳超過 MaxUploadSizeMB)。
|
||||||
|
ErrCodePayloadTooLarge = "PAYLOAD_TOO_LARGE"
|
||||||
|
// ErrCodeInvalidSignature 用於 /storage/* 驗簽失敗 / URL 過期。
|
||||||
|
ErrCodeInvalidSignature = "INVALID_SIGNATURE"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorBody 是 API 錯誤回應的 envelope 結構。
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md:
|
||||||
|
//
|
||||||
|
// { "success": false, "error": { "code": "...", "message": "...", "request_id": "..." } }
|
||||||
|
//
|
||||||
|
// 為什麼用 envelope 而非裸 error:方便前端統一處理 + 與成功回應形狀一致。
|
||||||
|
type ErrorBody struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Error *ErrorDetail `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorDetail 是錯誤的具體資訊。
|
||||||
|
type ErrorDetail struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Details []FieldError `json:"details,omitempty"` // 例如 validation 細節
|
||||||
|
RequestID string `json:"request_id,omitempty"`
|
||||||
|
Extra map[string]any `json:"extra,omitempty"` // 給 specific error 帶結構化資料
|
||||||
|
}
|
||||||
|
|
||||||
|
// FieldError 描述單一欄位的驗證錯誤。
|
||||||
|
type FieldError struct {
|
||||||
|
Field string `json:"field"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// SuccessBody 是成功回應的 envelope。
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md:`{ "success": true, "data": ... }`。
|
||||||
|
type SuccessBody struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
Data any `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteError 統一寫錯誤回應(會自動帶上 request_id)。
|
||||||
|
//
|
||||||
|
// 注意:呼叫後 caller 仍需自行 c.Abort()(如果是在 middleware 中要終止 chain);
|
||||||
|
// 在 handler 中只需 return 即可。
|
||||||
|
func WriteError(c *gin.Context, status int, code, message string, details []FieldError) {
|
||||||
|
c.JSON(status, ErrorBody{
|
||||||
|
Success: false,
|
||||||
|
Error: &ErrorDetail{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
Details: details,
|
||||||
|
RequestID: RequestIDFrom(c),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteSuccess 統一寫成功回應。
|
||||||
|
func WriteSuccess(c *gin.Context, status int, data any) {
|
||||||
|
c.JSON(status, SuccessBody{
|
||||||
|
Success: true,
|
||||||
|
Data: data,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteNotImplemented 回應 501,給 B5 還沒實作的 handler 用。
|
||||||
|
func WriteNotImplemented(c *gin.Context, hint string) {
|
||||||
|
WriteError(c, 501, ErrCodeNotImplemented, hint, nil)
|
||||||
|
}
|
||||||
124
visionA-backend/internal/api/health.go
Normal file
124
visionA-backend/internal/api/health.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HealthzHandler 是 K8s liveness / readiness 用的最小健康檢查。
|
||||||
|
//
|
||||||
|
// 不檢查任何依賴(remote-proxy、DB),只代表 process 還活著。
|
||||||
|
// readiness 想檢查依賴的話應該用 /api/system/health。
|
||||||
|
func HealthzHandler() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerSystemRoutes 註冊 /api/system/* 的 routes。
|
||||||
|
//
|
||||||
|
// MVP 範圍(B4 + B5):
|
||||||
|
// - GET /api/system/health → 回 api-server 自己 + tunnel 連線狀態
|
||||||
|
// - GET /api/system/info → 回版本資訊(雛形 hard-coded)
|
||||||
|
// - GET /api/system/deps → 走 tunnel proxy 查 local agent 的依賴狀態(B5)
|
||||||
|
func registerSystemRoutes(g *gin.RouterGroup, deps Deps) {
|
||||||
|
g.GET("/system/health", systemHealthHandler(deps))
|
||||||
|
g.GET("/system/info", systemInfoHandler())
|
||||||
|
// /api/system/deps 透過 tunnel proxy 到 local agent 的同路徑。
|
||||||
|
g.GET("/system/deps", newProxyHandler(deps, proxyOptions{}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemHealthResponse 是 GET /api/system/health 的 data payload。
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md §7:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "api_server": "ok",
|
||||||
|
// "tunnel_connected": true,
|
||||||
|
// "agent_last_seen_at": "..."
|
||||||
|
// }
|
||||||
|
type SystemHealthResponse struct {
|
||||||
|
APIServer string `json:"api_server"`
|
||||||
|
TunnelConnected bool `json:"tunnel_connected"`
|
||||||
|
AgentLastSeenAt *time.Time `json:"agent_last_seen_at,omitempty"`
|
||||||
|
AgentSessionCount int `json:"agent_session_count"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// systemHealthHandler 回報 api-server + tunnel 狀態。
|
||||||
|
//
|
||||||
|
// 「tunnel_connected」的判定方式:呼叫 SessionStore.List。若有任一 session
|
||||||
|
// 在線就視為 connected。雛形是單一 user 場景,所以這個語義足以呈現「我這邊
|
||||||
|
// 有沒有 agent 連著」;多 user / 多 device 階段會改成 per-user 查詢(B5)。
|
||||||
|
func systemHealthHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
resp := SystemHealthResponse{
|
||||||
|
APIServer: "ok",
|
||||||
|
}
|
||||||
|
|
||||||
|
if deps.SessionStore != nil {
|
||||||
|
// 給一個短 timeout — health 檢查不該卡住整個 request
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
summaries, err := deps.SessionStore.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// SessionStore 失敗不致命;只回 tunnel_connected=false 加 warning
|
||||||
|
logOrDefault(deps.Logger).Warn("system/health: list sessions failed",
|
||||||
|
"error", err,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
} else {
|
||||||
|
resp.AgentSessionCount = len(summaries)
|
||||||
|
if len(summaries) > 0 {
|
||||||
|
resp.TunnelConnected = true
|
||||||
|
// 取最新一個 LastHeartbeat 作為 agent_last_seen_at
|
||||||
|
var latest time.Time
|
||||||
|
for _, s := range summaries {
|
||||||
|
if s.LastHeartbeat.After(latest) {
|
||||||
|
latest = s.LastHeartbeat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !latest.IsZero() {
|
||||||
|
resp.AgentLastSeenAt = &latest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
WriteSuccess(c, http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SystemInfoResponse 是 GET /api/system/info 的 data payload。
|
||||||
|
type SystemInfoResponse struct {
|
||||||
|
Service string `json:"service"`
|
||||||
|
Version string `json:"version"`
|
||||||
|
Phase string `json:"phase"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// logOrDefault 是 nil-safe slog 取用 helper;給 handler 共用。
|
||||||
|
//
|
||||||
|
// Deps.validate 已會把 nil logger fallback 到 slog.Default,但測試直接呼叫
|
||||||
|
// register*Routes 時可能跳過 validate;這個 helper 讓 handler 不必每處都 nil 檢查。
|
||||||
|
func logOrDefault(l *slog.Logger) *slog.Logger {
|
||||||
|
if l == nil {
|
||||||
|
return slog.Default()
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// systemInfoHandler 回報版本與環境階段。
|
||||||
|
//
|
||||||
|
// 雛形版本字串 hard-coded;B6(CI/CD)會改用 build flag 注入 git commit hash。
|
||||||
|
func systemInfoHandler() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
WriteSuccess(c, http.StatusOK, SystemInfoResponse{
|
||||||
|
Service: "visiona-api-server",
|
||||||
|
Version: "0.0.0-phase0",
|
||||||
|
Phase: "phase-0-prototype",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
124
visionA-backend/internal/api/health_test.go
Normal file
124
visionA-backend/internal/api/health_test.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeSessionStore 是測試用 Store 實作,只回 List 結果;其他方法 panic 表示
|
||||||
|
// 不應被呼叫(以利早期偵錯)。
|
||||||
|
type fakeSessionStore struct {
|
||||||
|
sessions []*session.Summary
|
||||||
|
listErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeSessionStore) Register(context.Context, string, session.Handle) error {
|
||||||
|
panic("Register should not be called")
|
||||||
|
}
|
||||||
|
func (f *fakeSessionStore) Unregister(context.Context, string) error {
|
||||||
|
panic("Unregister should not be called")
|
||||||
|
}
|
||||||
|
func (f *fakeSessionStore) Lookup(context.Context, string) (session.Handle, error) {
|
||||||
|
panic("Lookup should not be called")
|
||||||
|
}
|
||||||
|
func (f *fakeSessionStore) Exists(context.Context, string) (bool, error) {
|
||||||
|
panic("Exists should not be called")
|
||||||
|
}
|
||||||
|
func (f *fakeSessionStore) List(context.Context) ([]*session.Summary, error) {
|
||||||
|
return f.sessions, f.listErr
|
||||||
|
}
|
||||||
|
func (f *fakeSessionStore) Heartbeat(context.Context, string) error {
|
||||||
|
panic("Heartbeat should not be called")
|
||||||
|
}
|
||||||
|
func (f *fakeSessionStore) CleanupExpired(context.Context, time.Duration) (int, error) {
|
||||||
|
panic("CleanupExpired should not be called")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHealthzHandler 驗證 /healthz 回 200 + status:ok。
|
||||||
|
func TestHealthzHandler(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.GET("/healthz", HealthzHandler())
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/healthz", nil))
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), `"status":"ok"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSystemHealth_TunnelDisconnected 驗證沒 session 時回 connected=false。
|
||||||
|
func TestSystemHealth_TunnelDisconnected(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerSystemRoutes(g, Deps{
|
||||||
|
SessionStore: &fakeSessionStore{sessions: nil},
|
||||||
|
Logger: nil,
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/system/health", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var body SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||||
|
require.True(t, body.Success)
|
||||||
|
data, _ := body.Data.(map[string]any)
|
||||||
|
require.NotNil(t, data)
|
||||||
|
assert.Equal(t, "ok", data["api_server"])
|
||||||
|
assert.Equal(t, false, data["tunnel_connected"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSystemHealth_TunnelConnected 驗證有 session 時回 connected=true 並帶 last_seen_at。
|
||||||
|
func TestSystemHealth_TunnelConnected(t *testing.T) {
|
||||||
|
now := time.Now().UTC().Truncate(time.Second)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerSystemRoutes(g, Deps{
|
||||||
|
SessionStore: &fakeSessionStore{
|
||||||
|
sessions: []*session.Summary{
|
||||||
|
{Token: "vAc_a", LastHeartbeat: now.Add(-5 * time.Second)},
|
||||||
|
{Token: "vAc_b", LastHeartbeat: now}, // 最新
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/system/health", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var body SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||||
|
data := body.Data.(map[string]any)
|
||||||
|
assert.Equal(t, true, data["tunnel_connected"])
|
||||||
|
assert.EqualValues(t, 2, data["agent_session_count"])
|
||||||
|
assert.NotEmpty(t, data["agent_last_seen_at"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSystemInfo 驗證 GET /api/system/info 回基本欄位。
|
||||||
|
func TestSystemInfo(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerSystemRoutes(g, Deps{})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/system/info", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var body SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||||
|
data := body.Data.(map[string]any)
|
||||||
|
assert.Equal(t, "visiona-api-server", data["service"])
|
||||||
|
assert.Equal(t, "phase-0-prototype", data["phase"])
|
||||||
|
}
|
||||||
280
visionA-backend/internal/api/middleware.go
Normal file
280
visionA-backend/internal/api/middleware.go
Normal file
@ -0,0 +1,280 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/cors"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/usersession"
|
||||||
|
)
|
||||||
|
|
||||||
|
// gin context key 常數 — 集中管理避免拼寫錯誤。
|
||||||
|
const (
|
||||||
|
// ctxKeyUserContext 是儲存 *auth.UserContext 的 gin context key。
|
||||||
|
ctxKeyUserContext = "auth.userContext"
|
||||||
|
// ctxKeyRequestID 是請求追蹤 ID 的 gin context key(同時也會寫到 response header)。
|
||||||
|
ctxKeyRequestID = "request.id"
|
||||||
|
// ctxKeyUserSession 是儲存 OIDC 模式下 *usersession.Session 的 gin context key。
|
||||||
|
// 由 AuthMiddleware 設定;handler 可選用以避免再次 lookup。
|
||||||
|
ctxKeyUserSession = "auth.userSession"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestIDMiddleware 給每個 request 產生 UUID 作為追蹤 ID。
|
||||||
|
//
|
||||||
|
// 行為:
|
||||||
|
// - 若 request 帶 X-Request-ID header,直接沿用(讓上游 LB / mesh 串起來)
|
||||||
|
// - 否則產生新的 UUID v4
|
||||||
|
// - 寫到 gin.Context(給 logger / handler 用)+ response header
|
||||||
|
func RequestIDMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
rid := c.GetHeader("X-Request-ID")
|
||||||
|
if rid == "" {
|
||||||
|
rid = uuid.NewString()
|
||||||
|
}
|
||||||
|
c.Set(ctxKeyRequestID, rid)
|
||||||
|
c.Writer.Header().Set("X-Request-ID", rid)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoggerMiddleware 用結構化 slog 記錄每個請求的關鍵欄位。
|
||||||
|
//
|
||||||
|
// 對齊 backend/CLAUDE.md §6.1 的結構化日誌要求:
|
||||||
|
// - timestamp、level、service:由 logger 預設帶
|
||||||
|
// - request_id、http_method、http_path、http_status、duration_ms
|
||||||
|
// - user_id:若 AuthMiddleware 已執行則一併帶上
|
||||||
|
//
|
||||||
|
// logger 為 nil 時 fallback 到 slog.Default — 對 test fixture 友善。
|
||||||
|
func LoggerMiddleware(logger *slog.Logger) gin.HandlerFunc {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
start := time.Now()
|
||||||
|
path := c.Request.URL.Path
|
||||||
|
c.Next()
|
||||||
|
duration := time.Since(start)
|
||||||
|
|
||||||
|
// 取出 request ID / user ID(若有)
|
||||||
|
rid, _ := c.Get(ctxKeyRequestID)
|
||||||
|
var userID string
|
||||||
|
if uc, ok := UserContextFrom(c); ok {
|
||||||
|
userID = uc.UserID
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根據 status code 決定 log level。
|
||||||
|
// 501 NOT_IMPLEMENTED 是「刻意設計」的回應,不應該觸發 error 告警 → 降為 INFO。
|
||||||
|
status := c.Writer.Status()
|
||||||
|
level := slog.LevelInfo
|
||||||
|
switch {
|
||||||
|
case status == http.StatusNotImplemented:
|
||||||
|
level = slog.LevelInfo
|
||||||
|
case status >= 500:
|
||||||
|
level = slog.LevelError
|
||||||
|
case status >= 400:
|
||||||
|
level = slog.LevelWarn
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.LogAttrs(c.Request.Context(), level, "http request",
|
||||||
|
slog.String("request_id", asString(rid)),
|
||||||
|
slog.String("user_id", userID),
|
||||||
|
slog.String("action", "http.request"),
|
||||||
|
slog.String("http_method", c.Request.Method),
|
||||||
|
slog.String("http_path", path),
|
||||||
|
slog.Int("http_status", status),
|
||||||
|
slog.Int64("duration_ms", duration.Milliseconds()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecoveryMiddleware 攔截 handler panic,記錄並回 500。
|
||||||
|
//
|
||||||
|
// 不直接用 gin.Recovery() 是因為要走我們統一的 JSON error 格式。
|
||||||
|
// logger 為 nil 時 fallback 到 slog.Default — 這條路徑在測試環境會被觸發。
|
||||||
|
func RecoveryMiddleware(logger *slog.Logger) gin.HandlerFunc {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
logger.Error("panic recovered",
|
||||||
|
"error", rec,
|
||||||
|
"path", c.Request.URL.Path,
|
||||||
|
"method", c.Request.Method)
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "internal server error", nil)
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSMiddleware 用 gin-contrib/cors 設定 CORS 規則。
|
||||||
|
//
|
||||||
|
// 預設允許 http://localhost:3000(前端 dev server);
|
||||||
|
// allowedOrigins 為空則 fallback 到該預設值。生產環境應由 caller 注入正式網域。
|
||||||
|
//
|
||||||
|
// 允許的 method / header 對齊一般 REST API 需求;不開放 wildcard '*' Origin 以
|
||||||
|
// 避免「攜帶 cookie 的 cross-origin 請求」被瀏覽器擋下。
|
||||||
|
func CORSMiddleware(allowedOrigins []string) gin.HandlerFunc {
|
||||||
|
if len(allowedOrigins) == 0 {
|
||||||
|
allowedOrigins = []string{"http://localhost:3000"}
|
||||||
|
}
|
||||||
|
return cors.New(cors.Config{
|
||||||
|
AllowOrigins: allowedOrigins,
|
||||||
|
AllowMethods: []string{
|
||||||
|
http.MethodGet, http.MethodPost, http.MethodPut, http.MethodPatch,
|
||||||
|
http.MethodDelete, http.MethodOptions,
|
||||||
|
},
|
||||||
|
AllowHeaders: []string{
|
||||||
|
"Origin", "Content-Type", "Accept", "Authorization",
|
||||||
|
"X-Request-ID", "X-Idempotency-Key",
|
||||||
|
},
|
||||||
|
ExposeHeaders: []string{
|
||||||
|
"X-Request-ID", "X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset",
|
||||||
|
},
|
||||||
|
AllowCredentials: true,
|
||||||
|
MaxAge: 12 * time.Hour,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthMiddleware 從 cookie 解析 OIDC session 並把 UserContext 放進 gin.Context。
|
||||||
|
//
|
||||||
|
// OB5(2026-04-26)起,OIDC 是唯一認證路徑:
|
||||||
|
// - 從 cookie 讀 session ID → SessionManager.GetSession
|
||||||
|
// - 必須是「已登入 session」(UserID 非空;空代表只是 OIDC pending session)
|
||||||
|
// - 注入 UserContext + Session 到 gin.Context
|
||||||
|
//
|
||||||
|
// 任何失敗一律 401 UNAUTHORIZED,由 frontend 處理 redirect 到 /api/auth/login。
|
||||||
|
//
|
||||||
|
// SessionManager 已由 NewRouter 的 validate() 確保非 nil(缺則啟動時就 panic)—
|
||||||
|
// 因此此 middleware 不需 nil check。
|
||||||
|
func AuthMiddleware(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
sess, err := deps.SessionManager.GetSession(c.Request.Context(), c.Request)
|
||||||
|
if err != nil {
|
||||||
|
// no session / cookie 過期 / store 找不到 → 一律 401
|
||||||
|
WriteError(c, http.StatusUnauthorized, ErrCodeUnauthorized, "no_session", nil)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 區分「pending session」(OIDC dance 進行中、UserID 還是空)vs「已登入 session」。
|
||||||
|
//
|
||||||
|
// ⚠️ 安全臨界檢查:此判斷是 ADR-012「合一 cookie 設計」的核心防線,**不可拿掉、不可放寬**。
|
||||||
|
//
|
||||||
|
// 背景:oidc-tdd.md §4.5 原設計為 pending 與 logged-in 各一個 cookie;OB2 / OB4 為了
|
||||||
|
// 簡化實作合一(兩種狀態共用 visiona_session cookie + 同一個 store record,由 UserID 是否
|
||||||
|
// 為空判斷階段)。合一設計的安全前提是「protected endpoint 在 middleware 層強制檢查
|
||||||
|
// UserID 非空」——若刪掉這個檢查,攻擊者拿到 pending session cookie(自己跑 /api/auth/login
|
||||||
|
// 即得)就能直接訪問所有 protected endpoint,雖然 UserID 為空看不到資料,但側通道風險
|
||||||
|
// 與後續 handler 的健壯性都成問題。
|
||||||
|
//
|
||||||
|
// 配套防護:login callback 完成時呼叫 SessionManager.RotateSessionID 換 ID(Fix-A1),
|
||||||
|
// pending 階段的舊 cookie 從此失效,搭配此處檢查雙保險。
|
||||||
|
//
|
||||||
|
// 詳見:.autoflow/04-architecture/adr/adr-012-pending-session-shared-cookie.md
|
||||||
|
if sess.UserID == "" {
|
||||||
|
WriteError(c, http.StatusUnauthorized, ErrCodeUnauthorized, "session_not_authenticated", nil)
|
||||||
|
c.Abort()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(ctxKeyUserContext, &auth.UserContext{
|
||||||
|
UserID: sess.UserID,
|
||||||
|
Email: sess.Email,
|
||||||
|
// Roles / OrgID 雛形未實作(Member Center 不回傳)
|
||||||
|
})
|
||||||
|
// 把 session 也放進 context,handler(如 /auth/me)可避免再次 lookup。
|
||||||
|
c.Set(ctxKeyUserSession, sess)
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserContextFrom 從 gin.Context 取出 *auth.UserContext。
|
||||||
|
// 第二個 return 為 false 表示 AuthMiddleware 未執行或解析失敗。
|
||||||
|
func UserContextFrom(c *gin.Context) (*auth.UserContext, bool) {
|
||||||
|
v, exists := c.Get(ctxKeyUserContext)
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
uc, ok := v.(*auth.UserContext)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return uc, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSessionFrom 從 gin.Context 取出 OIDC 模式下的 *usersession.Session。
|
||||||
|
// 第二個 return 為 false 表示 AuthMiddleware 未執行或 session 缺失。
|
||||||
|
//
|
||||||
|
// 用途:handler(例如 /api/auth/me)想拿 Email / Name 等 session 額外欄位時,
|
||||||
|
// 可以避免重複 cookie + store lookup(已由 middleware 完成)。
|
||||||
|
func UserSessionFrom(c *gin.Context) (*usersession.Session, bool) {
|
||||||
|
v, exists := c.Get(ctxKeyUserSession)
|
||||||
|
if !exists {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
sess, ok := v.(*usersession.Session)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return sess, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestIDFrom 從 gin.Context 取出本次請求的 request ID。
|
||||||
|
func RequestIDFrom(c *gin.Context) string {
|
||||||
|
v, exists := c.Get(ctxKeyRequestID)
|
||||||
|
if !exists {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return asString(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorMiddleware 把 handler 透過 c.Error() 推上來的錯誤統一處理。
|
||||||
|
//
|
||||||
|
// 目前的實作:若 handler 已經寫了 response(c.Writer.Written()),就不覆蓋;
|
||||||
|
// 否則寫一個泛用的 500。後續 B5 可以擴充對特定 error type 做客製化轉換。
|
||||||
|
func ErrorMiddleware() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Next()
|
||||||
|
if len(c.Errors) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.Writer.Written() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 取最後一個 error 作為主要訊息(gin 慣例)
|
||||||
|
last := c.Errors.Last()
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, last.Error(), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// asString 把 any 安全轉成 string(gin context 取值常用)。
|
||||||
|
func asString(v any) string {
|
||||||
|
if v == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if s, ok := v.(string); ok {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// StripBearerPrefix 從 Authorization header 取出 token;不是 Bearer 開頭則回原值。
|
||||||
|
//
|
||||||
|
// OIDC 路徑下 visionA-backend 不再透過 Authorization header 接 token,但保留此 helper
|
||||||
|
// 給未來 service-to-service auth(例:Phase 1 backup local provider、external API
|
||||||
|
// integration)直接複用。
|
||||||
|
func StripBearerPrefix(authHeader string) string {
|
||||||
|
const prefix = "Bearer "
|
||||||
|
if strings.HasPrefix(authHeader, prefix) {
|
||||||
|
return strings.TrimPrefix(authHeader, prefix)
|
||||||
|
}
|
||||||
|
return authHeader
|
||||||
|
}
|
||||||
131
visionA-backend/internal/api/middleware_test.go
Normal file
131
visionA-backend/internal/api/middleware_test.go
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"visiona-backend/internal/usersession"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTestSessionManager 是給 middleware_test 用的最小 SessionManager fixture。
|
||||||
|
//
|
||||||
|
// OB5 起 AuthMiddleware 必須有 SessionManager — Static fallback 已拔除。
|
||||||
|
func newTestSessionManager() *usersession.Manager {
|
||||||
|
return usersession.NewManager(usersession.NewInMemoryStore(), usersession.CookieConfig{
|
||||||
|
Name: "visiona_session",
|
||||||
|
Path: "/",
|
||||||
|
HTTPOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: 86400,
|
||||||
|
SigningKey: []byte("middleware-test-signing-key-32b-aa"),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequestIDMiddleware_GeneratesNew 驗證沒帶 header 時會產生新的 request id。
|
||||||
|
func TestRequestIDMiddleware_GeneratesNew(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.GET("/", func(c *gin.Context) {
|
||||||
|
c.String(http.StatusOK, RequestIDFrom(c))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
body := w.Body.String()
|
||||||
|
assert.NotEmpty(t, body, "request id 應寫入 context")
|
||||||
|
assert.Equal(t, body, w.Header().Get("X-Request-ID"), "header 應與 context 一致")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRequestIDMiddleware_PreservesIncoming 驗證帶 header 時會沿用。
|
||||||
|
func TestRequestIDMiddleware_PreservesIncoming(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.GET("/", func(c *gin.Context) {
|
||||||
|
c.String(http.StatusOK, RequestIDFrom(c))
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.Header.Set("X-Request-ID", "upstream-123")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, "upstream-123", w.Body.String())
|
||||||
|
assert.Equal(t, "upstream-123", w.Header().Get("X-Request-ID"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAuthMiddleware_NoCookie_Rejects 驗證沒 cookie 時 → 401 + "no_session"。
|
||||||
|
func TestAuthMiddleware_NoCookie_Rejects(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(AuthMiddleware(Deps{SessionManager: newTestSessionManager()}))
|
||||||
|
r.GET("/", func(c *gin.Context) {
|
||||||
|
c.String(http.StatusOK, "should not reach")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "no_session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 註:AuthMiddleware 的「pending session → 401」與「authenticated session → 通過」的
|
||||||
|
// 完整測試在 oidc_auth_test.go(TestOIDCMiddleware_Allows_AuthenticatedSession /
|
||||||
|
// TestOIDCMiddleware_Rejects_PendingSession),因為需要走完整 login flow 才能模擬。
|
||||||
|
|
||||||
|
// TestRecoveryMiddleware_CatchesPanic 驗證 handler panic 會被攔成 500 + INTERNAL_ERROR。
|
||||||
|
func TestRecoveryMiddleware_CatchesPanic(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(RecoveryMiddleware(nil))
|
||||||
|
r.GET("/", func(c *gin.Context) {
|
||||||
|
panic("boom")
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeInternalError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestStripBearerPrefix 驗證 Bearer token prefix 處理。
|
||||||
|
func TestStripBearerPrefix(t *testing.T) {
|
||||||
|
assert.Equal(t, "abc123", StripBearerPrefix("Bearer abc123"))
|
||||||
|
assert.Equal(t, "abc123", StripBearerPrefix("abc123"))
|
||||||
|
assert.Equal(t, "", StripBearerPrefix(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCORSMiddleware_AllowsConfiguredOrigin 驗證只放行白名單 Origin。
|
||||||
|
func TestCORSMiddleware_AllowsConfiguredOrigin(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(CORSMiddleware([]string{"http://localhost:3000"}))
|
||||||
|
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "ok") })
|
||||||
|
|
||||||
|
// Allowed origin
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
req.Header.Set("Origin", "http://localhost:3000")
|
||||||
|
req.Header.Set("Access-Control-Request-Method", "GET")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
assert.True(t, strings.Contains(w.Header().Get("Access-Control-Allow-Origin"), "localhost:3000"),
|
||||||
|
"預期 Allow-Origin 包含 localhost:3000,實際 header: %v", w.Header())
|
||||||
|
|
||||||
|
// Disallowed origin
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
req2 := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
req2.Header.Set("Origin", "http://evil.example")
|
||||||
|
req2.Header.Set("Access-Control-Request-Method", "GET")
|
||||||
|
r.ServeHTTP(w2, req2)
|
||||||
|
assert.NotContains(t, w2.Header().Get("Access-Control-Allow-Origin"), "evil.example")
|
||||||
|
}
|
||||||
433
visionA-backend/internal/api/models.go
Normal file
433
visionA-backend/internal/api/models.go
Normal file
@ -0,0 +1,433 @@
|
|||||||
|
// models.go — /api/models/* 的 handler 實作。
|
||||||
|
//
|
||||||
|
// 雛形重點:
|
||||||
|
// - GET /api/models:列當前 user 的模型(ModelRepo in-memory)
|
||||||
|
// - GET /api/models/:id:取單一模型 metadata
|
||||||
|
// - POST /api/models/init:兩階段上傳第一步 — 驗證輸入、產 storageKey 與 presigned PUT URL
|
||||||
|
// - POST /api/models/:id/finalize:第二步 — 驗證檔案已存在(storage.Exists)與大小,標為 ready
|
||||||
|
// - DELETE /api/models/:id:軟刪
|
||||||
|
//
|
||||||
|
// **兩階段上傳(Init → PUT → Finalize)的設計理由**:
|
||||||
|
// - 讓前端直接 PUT 到 storage,不佔 api-server 記憶體 / bandwidth
|
||||||
|
// - Phase 0 LocalFS 用假 presigned URL;Phase 1 S3 用原生 presigned URL
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md §4、feature-model-management.md。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"visiona-backend/internal/model"
|
||||||
|
"visiona-backend/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// modelUploadURLTTL 是 presigned PUT URL 的存活時間。
|
||||||
|
const modelUploadURLTTL = 15 * time.Minute
|
||||||
|
|
||||||
|
// registerModelRoutes 註冊 /api/models/* 的 routes。
|
||||||
|
func registerModelRoutes(g *gin.RouterGroup, deps Deps) {
|
||||||
|
g.GET("/models", modelsListHandler(deps))
|
||||||
|
g.GET("/models/:id", modelsGetHandler(deps))
|
||||||
|
g.POST("/models/init", modelsInitUploadHandler(deps))
|
||||||
|
g.POST("/models/:id/finalize", modelsFinalizeHandler(deps))
|
||||||
|
g.DELETE("/models/:id", modelsDeleteHandler(deps))
|
||||||
|
|
||||||
|
// load-to-device 雛形先 stub(完整實作需要 presigned GET + 透過 tunnel 送指令給 local agent)
|
||||||
|
g.POST("/models/:id/load-to-device", func(c *gin.Context) {
|
||||||
|
WriteNotImplemented(c, "models.load-to-device — pending Phase 1")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelResponse 是 API 回傳的 Model DTO;對應 api-spec.md §4 的格式。
|
||||||
|
type ModelResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
TargetChip string `json:"target_chip,omitempty"`
|
||||||
|
FileSize int64 `json:"file_size"`
|
||||||
|
Source string `json:"source"`
|
||||||
|
Status string `json:"status"` // "pending" / "ready"
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
UploadedAt *time.Time `json:"uploaded_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// toModelResponse 把 domain model 轉為 API DTO;「status」由 UploadedAt 是否 set 決定。
|
||||||
|
func toModelResponse(m *model.Model) ModelResponse {
|
||||||
|
status := "pending"
|
||||||
|
if m.UploadedAt != nil {
|
||||||
|
status = "ready"
|
||||||
|
}
|
||||||
|
return ModelResponse{
|
||||||
|
ID: m.ID,
|
||||||
|
Name: m.Name,
|
||||||
|
Description: m.Description,
|
||||||
|
TargetChip: m.TargetChip,
|
||||||
|
FileSize: m.FileSize,
|
||||||
|
Source: m.Source,
|
||||||
|
Status: status,
|
||||||
|
CreatedAt: m.CreatedAt,
|
||||||
|
UpdatedAt: m.UpdatedAt,
|
||||||
|
UploadedAt: m.UploadedAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelsListHandler 實作 GET /api/models。
|
||||||
|
func modelsListHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.ModelRepo == nil {
|
||||||
|
WriteSuccess(c, http.StatusOK, []ModelResponse{})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
models, err := deps.ModelRepo.List(ctx, model.ListFilter{OwnerUserID: userID})
|
||||||
|
if err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"list models failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := make([]ModelResponse, 0, len(models))
|
||||||
|
for _, m := range models {
|
||||||
|
out = append(out, toModelResponse(m))
|
||||||
|
}
|
||||||
|
WriteSuccess(c, http.StatusOK, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelsGetHandler 實作 GET /api/models/:id。
|
||||||
|
func modelsGetHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.ModelRepo == nil {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "model not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := c.Param("id")
|
||||||
|
if id == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed, "model id required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
m, err := deps.ModelRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, model.ErrNotFound) {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "model not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"get model failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.OwnerUserID != userID {
|
||||||
|
WriteError(c, http.StatusForbidden, ErrCodeForbidden, "not owner", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteSuccess(c, http.StatusOK, toModelResponse(m))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelInitRequest 是 POST /api/models/init 的 request body。
|
||||||
|
type ModelInitRequest struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
FileSize int64 `json:"file_size"`
|
||||||
|
Checksum string `json:"checksum,omitempty"` // sha256 hex(Phase 1 驗)
|
||||||
|
TargetChip string `json:"target_chip,omitempty"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelInitResponse 是 POST /api/models/init 的 response data。
|
||||||
|
type ModelInitResponse struct {
|
||||||
|
ModelID string `json:"model_id"`
|
||||||
|
UploadURL string `json:"upload_url"`
|
||||||
|
UploadExpiresAt time.Time `json:"upload_expires_at"`
|
||||||
|
StorageKey string `json:"storage_key"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelsInitUploadHandler 實作 POST /api/models/init。
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// 1. 驗證 request(name 必填、file_size 不能超過配置)
|
||||||
|
// 2. 產新 model_id + storage_key(`models/{userID}/{modelID}.nef`)
|
||||||
|
// 3. 用 storage.PresignedPutURL 取 PUT URL
|
||||||
|
// 4. 在 ModelRepo 建立 pending 紀錄(UploadedAt = nil)
|
||||||
|
// 5. 回應 model_id + upload_url
|
||||||
|
//
|
||||||
|
// 錯誤:413 PAYLOAD_TOO_LARGE、400 VALIDATION_FAILED、501(storage/repo 未設)
|
||||||
|
func modelsInitUploadHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.ModelRepo == nil || deps.Storage == nil {
|
||||||
|
WriteNotImplemented(c, "model repo or storage not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ModelInitRequest
|
||||||
|
if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"invalid JSON: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 驗證 name
|
||||||
|
if strings.TrimSpace(req.Name) == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"name is required", []FieldError{{Field: "name", Message: "cannot be empty"}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 驗證 file_size > 0 且不超過上限
|
||||||
|
if req.FileSize <= 0 {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"file_size must be > 0",
|
||||||
|
[]FieldError{{Field: "file_size", Message: "must be positive"}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 大小上限檢查(MaxUploadSizeMB 取自 Deps;若為 0 則不限,給測試友善)
|
||||||
|
if deps.MaxUploadSizeMB > 0 {
|
||||||
|
limit := int64(deps.MaxUploadSizeMB) * 1024 * 1024
|
||||||
|
if req.FileSize > limit {
|
||||||
|
WriteError(c, http.StatusRequestEntityTooLarge, ErrCodePayloadTooLarge,
|
||||||
|
"file_size exceeds upload limit",
|
||||||
|
[]FieldError{{Field: "file_size",
|
||||||
|
Message: "max allowed is configured by server"}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
modelID := uuid.NewString()
|
||||||
|
storageKey := "models/" + userID + "/" + modelID + ".nef"
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// 產 presigned PUT URL
|
||||||
|
uploadURL, err := deps.Storage.PresignedPutURL(ctx, storageKey, modelUploadURLTTL)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"presigned url failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 建立 pending 紀錄
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m := &model.Model{
|
||||||
|
ID: modelID,
|
||||||
|
OwnerUserID: userID,
|
||||||
|
Name: req.Name,
|
||||||
|
Description: req.Description,
|
||||||
|
TargetChip: req.TargetChip,
|
||||||
|
FileSize: req.FileSize,
|
||||||
|
FileChecksum: req.Checksum,
|
||||||
|
StorageKey: storageKey,
|
||||||
|
Source: model.SourceUploaded,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
// UploadedAt 保持 nil 直到 finalize
|
||||||
|
}
|
||||||
|
if err := deps.ModelRepo.Save(ctx, m); err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"save model failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logOrDefault(deps.Logger).Info("models: upload init",
|
||||||
|
"model_id", modelID,
|
||||||
|
"user_id", userID,
|
||||||
|
"file_size", req.FileSize,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
WriteSuccess(c, http.StatusOK, ModelInitResponse{
|
||||||
|
ModelID: modelID,
|
||||||
|
UploadURL: uploadURL,
|
||||||
|
UploadExpiresAt: now.Add(modelUploadURLTTL),
|
||||||
|
StorageKey: storageKey,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelsFinalizeHandler 實作 POST /api/models/:id/finalize。
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// 1. 取 model(ownership 檢查)
|
||||||
|
// 2. 透過 storage.Stat 驗證檔案已存在
|
||||||
|
// 3. 驗證 Stat().Size == model.FileSize(雛形只做 size 比對;Phase 1 加 checksum)
|
||||||
|
// 4. 更新 UploadedAt;存回 Repo
|
||||||
|
//
|
||||||
|
// 錯誤:
|
||||||
|
// - 檔案還沒 PUT → 400 VALIDATION_FAILED (file not uploaded)
|
||||||
|
// - Size 不符 → 400 VALIDATION_FAILED (size mismatch)
|
||||||
|
func modelsFinalizeHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.ModelRepo == nil || deps.Storage == nil {
|
||||||
|
WriteNotImplemented(c, "model repo or storage not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := c.Param("id")
|
||||||
|
if id == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed, "model id required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
m, err := deps.ModelRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, model.ErrNotFound) {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "model not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"get model failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.OwnerUserID != userID {
|
||||||
|
WriteError(c, http.StatusForbidden, ErrCodeForbidden, "not owner", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 驗證檔案已存在
|
||||||
|
obj, statErr := deps.Storage.Stat(ctx, m.StorageKey)
|
||||||
|
if statErr != nil {
|
||||||
|
if errors.Is(statErr, storage.ErrNotFound) {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"file not uploaded yet; PUT to upload_url first", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"stat storage failed: "+statErr.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Size 驗證(雛形只比對 size;Phase 1 加 checksum)
|
||||||
|
if obj.Size != m.FileSize {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"uploaded size mismatch",
|
||||||
|
[]FieldError{
|
||||||
|
{Field: "file_size", Message: "declared vs actual differ"},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 標記 ready
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m.UploadedAt = &now
|
||||||
|
m.UpdatedAt = now
|
||||||
|
if err := deps.ModelRepo.Save(ctx, m); err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"save model failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logOrDefault(deps.Logger).Info("models: upload finalized",
|
||||||
|
"model_id", m.ID,
|
||||||
|
"user_id", userID,
|
||||||
|
"size", m.FileSize,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
WriteSuccess(c, http.StatusOK, toModelResponse(m))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelsDeleteHandler 實作 DELETE /api/models/:id。
|
||||||
|
//
|
||||||
|
// 雛形行為:軟刪(ModelRepo.Delete 已做軟刪)。
|
||||||
|
// 是否一併刪 storage 檔案 — 雛形保留檔案(方便 debug);Phase 1 接 S3 後,
|
||||||
|
// 建議由後台 worker 在 grace period 後刪除(避免使用者誤刪)。
|
||||||
|
func modelsDeleteHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.ModelRepo == nil {
|
||||||
|
WriteNotImplemented(c, "model repo not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
id := c.Param("id")
|
||||||
|
if id == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed, "model id required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Ownership 檢查
|
||||||
|
m, err := deps.ModelRepo.Get(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, model.ErrNotFound) {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "model not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"get model failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.OwnerUserID != userID {
|
||||||
|
WriteError(c, http.StatusForbidden, ErrCodeForbidden, "not owner", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := deps.ModelRepo.Delete(ctx, id); err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"delete model failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logOrDefault(deps.Logger).Info("models: deleted",
|
||||||
|
"model_id", id,
|
||||||
|
"user_id", userID,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
208
visionA-backend/internal/api/models_test.go
Normal file
208
visionA-backend/internal/api/models_test.go
Normal file
@ -0,0 +1,208 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/model"
|
||||||
|
"visiona-backend/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 建一個 in-memory fixture(storage + model repo)給 models_test 用。
|
||||||
|
func newModelsFixture(t *testing.T) (*gin.Engine, *model.InMemoryRepository, *storage.LocalFSStore) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
st, err := storage.NewLocalFSStore(dir, "http://api/storage", "test-secret")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
repo := model.NewInMemoryRepository()
|
||||||
|
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
// Phase 0.7 security fix C1:injectStaticUserContext 顯式注入 UserContext。
|
||||||
|
r.Use(injectStaticUserContext("demo-user", ""))
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerModelRoutes(g, Deps{
|
||||||
|
ModelRepo: repo,
|
||||||
|
Storage: st,
|
||||||
|
MaxUploadSizeMB: 10,
|
||||||
|
})
|
||||||
|
return r, repo, st
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsInit_OK 驗證 init 能成功:建立 pending 紀錄並回 upload_url。
|
||||||
|
func TestModelsInit_OK(t *testing.T) {
|
||||||
|
r, repo, _ := newModelsFixture(t)
|
||||||
|
|
||||||
|
body := strings.NewReader(`{"name":"m1","file_size":1024}`)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/models/init", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code, "body=%s", w.Body.String())
|
||||||
|
|
||||||
|
var sb SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &sb))
|
||||||
|
data := sb.Data.(map[string]any)
|
||||||
|
|
||||||
|
modelID, _ := data["model_id"].(string)
|
||||||
|
require.NotEmpty(t, modelID)
|
||||||
|
assert.Contains(t, data["upload_url"].(string), "signature=")
|
||||||
|
|
||||||
|
// Repo 中應已有 pending 紀錄(UploadedAt == nil)
|
||||||
|
m, err := repo.Get(context.Background(), modelID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, m.UploadedAt)
|
||||||
|
assert.Equal(t, int64(1024), m.FileSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsInit_NameMissing 驗證沒 name 回 400。
|
||||||
|
func TestModelsInit_NameMissing(t *testing.T) {
|
||||||
|
r, _, _ := newModelsFixture(t)
|
||||||
|
|
||||||
|
body := strings.NewReader(`{"file_size":1024}`)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/models/init", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeValidationFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsInit_TooLarge 驗證超過限制回 413。
|
||||||
|
func TestModelsInit_TooLarge(t *testing.T) {
|
||||||
|
r, _, _ := newModelsFixture(t)
|
||||||
|
// MaxUploadSizeMB=10,送 11MB
|
||||||
|
body := strings.NewReader(`{"name":"big","file_size":11534336}`) // 11 MB
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/models/init", body)
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodePayloadTooLarge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsFinalize_FileNotUploaded 驗證 finalize 在沒實際 PUT 前回 400。
|
||||||
|
func TestModelsFinalize_FileNotUploaded(t *testing.T) {
|
||||||
|
r, repo, _ := newModelsFixture(t)
|
||||||
|
|
||||||
|
// 先塞一筆 pending model(沒實際檔案)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m := &model.Model{
|
||||||
|
ID: "mdl-1",
|
||||||
|
OwnerUserID: "demo-user",
|
||||||
|
Name: "x",
|
||||||
|
FileSize: 100,
|
||||||
|
StorageKey: "models/demo-user/mdl-1.nef",
|
||||||
|
Source: model.SourceUploaded,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
require.NoError(t, repo.Save(context.Background(), m))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/mdl-1/finalize", nil))
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "file not uploaded")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsFinalize_SizeMismatch 驗證實際檔案大小對不上 file_size 回 400。
|
||||||
|
func TestModelsFinalize_SizeMismatch(t *testing.T) {
|
||||||
|
r, repo, st := newModelsFixture(t)
|
||||||
|
|
||||||
|
// 塞 pending model(宣稱 100 bytes)
|
||||||
|
require.NoError(t, repo.Save(context.Background(), &model.Model{
|
||||||
|
ID: "mdl-2",
|
||||||
|
OwnerUserID: "demo-user",
|
||||||
|
Name: "x",
|
||||||
|
FileSize: 100,
|
||||||
|
StorageKey: "models/demo-user/mdl-2.nef",
|
||||||
|
Source: model.SourceUploaded,
|
||||||
|
}))
|
||||||
|
// 實際檔案寫 10 bytes(Size 不符)
|
||||||
|
require.NoError(t, st.Put(context.Background(), "models/demo-user/mdl-2.nef",
|
||||||
|
strings.NewReader("0123456789"), 10, nil))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/mdl-2/finalize", nil))
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "size mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsFinalize_OK 驗證 happy path:檔案已存在、size 對得上,標 ready。
|
||||||
|
func TestModelsFinalize_OK(t *testing.T) {
|
||||||
|
r, repo, st := newModelsFixture(t)
|
||||||
|
|
||||||
|
require.NoError(t, repo.Save(context.Background(), &model.Model{
|
||||||
|
ID: "mdl-3",
|
||||||
|
OwnerUserID: "demo-user",
|
||||||
|
Name: "x",
|
||||||
|
FileSize: 5,
|
||||||
|
StorageKey: "models/demo-user/mdl-3.nef",
|
||||||
|
Source: model.SourceUploaded,
|
||||||
|
}))
|
||||||
|
require.NoError(t, st.Put(context.Background(), "models/demo-user/mdl-3.nef",
|
||||||
|
strings.NewReader("hello"), 5, nil))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/models/mdl-3/finalize", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code, "body=%s", w.Body.String())
|
||||||
|
|
||||||
|
// Repo 中應已 UploadedAt 被設
|
||||||
|
m, err := repo.Get(context.Background(), "mdl-3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, m.UploadedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsDelete_NotOwner 驗證非 owner 不能刪。
|
||||||
|
func TestModelsDelete_NotOwner(t *testing.T) {
|
||||||
|
r, repo, _ := newModelsFixture(t)
|
||||||
|
|
||||||
|
// 塞一個「別人」的 model
|
||||||
|
require.NoError(t, repo.Save(context.Background(), &model.Model{
|
||||||
|
ID: "mdl-other",
|
||||||
|
OwnerUserID: "other-user",
|
||||||
|
Name: "x",
|
||||||
|
Source: model.SourceUploaded,
|
||||||
|
}))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodDelete, "/api/models/mdl-other", nil))
|
||||||
|
assert.Equal(t, http.StatusForbidden, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModelsList_FiltersByOwner 驗證 list 只回當前 user 的模型。
|
||||||
|
func TestModelsList_FiltersByOwner(t *testing.T) {
|
||||||
|
r, repo, _ := newModelsFixture(t)
|
||||||
|
|
||||||
|
require.NoError(t, repo.Save(context.Background(), &model.Model{
|
||||||
|
ID: "my", OwnerUserID: "demo-user", Name: "mine", Source: model.SourceUploaded,
|
||||||
|
}))
|
||||||
|
require.NoError(t, repo.Save(context.Background(), &model.Model{
|
||||||
|
ID: "other", OwnerUserID: "other-user", Name: "theirs", Source: model.SourceUploaded,
|
||||||
|
}))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/models", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var sb SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &sb))
|
||||||
|
arr, ok := sb.Data.([]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
assert.Len(t, arr, 1, "只應看到自己的 model")
|
||||||
|
first := arr[0].(map[string]any)
|
||||||
|
assert.Equal(t, "my", first["id"])
|
||||||
|
}
|
||||||
405
visionA-backend/internal/api/oidc_auth.go
Normal file
405
visionA-backend/internal/api/oidc_auth.go
Normal file
@ -0,0 +1,405 @@
|
|||||||
|
// oidc_auth.go — Phase 0.6 BFF OIDC handler 實作。
|
||||||
|
//
|
||||||
|
// 對齊文件:
|
||||||
|
// - oidc-tdd.md §3.1(首次登入流程)
|
||||||
|
// - oidc-tdd.md §3.3(登出)
|
||||||
|
// - oidc-tdd.md §4.5(handler 程式碼範例)
|
||||||
|
// - oidc-tdd.md §6(PKCE)
|
||||||
|
// - oidc-tdd.md §7(id_token 驗證)
|
||||||
|
// - ADR-010(BFF 模式)
|
||||||
|
//
|
||||||
|
// 與既有 auth.go(Static 路徑)並存,由 NewRouter 依 Deps.OIDCEnabled() 決定是否註冊。
|
||||||
|
//
|
||||||
|
// 設計選擇:
|
||||||
|
// - 把 OIDC pending state(state / nonce / code_verifier / return_to)合在
|
||||||
|
// usersession.Session 同一個 cookie 裡。雛形階段 pending 與已登入 session
|
||||||
|
// 共用同一個 store;callback 完成後 pending 欄位清空、寫入 UserID/Email/Name。
|
||||||
|
// 簡化實作、減少 cookie 數量;symmetrically pending 持續時間短(≤ 10 分鐘)。
|
||||||
|
// - 不另外發 visiona_pending_sid cookie(與 oidc-tdd.md §4.5 範例不同 — TDD 是文件示意,
|
||||||
|
// 雛形採取「合一 session」策略;這個權衡記錄於 OB4 任務說明)。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/subtle"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/oidc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// oidcCallbackTimeout 限制 token exchange + id_token verify 的總時間。
|
||||||
|
// 這兩步都有網路 I/O(IdP token endpoint、JWKS 抓取);30s 足以涵蓋 IdP 緩慢回應,
|
||||||
|
// 又不會讓 caller 端等到 default HTTP server timeout。
|
||||||
|
const oidcCallbackTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// MeResponseOIDC 是 OIDC 模式下 GET /api/auth/me 的 data payload。
|
||||||
|
//
|
||||||
|
// 故意與 Legacy MeResponse 區分:OIDC 沒有 Roles 概念(雛形),但有 Name。
|
||||||
|
type MeResponseOIDC struct {
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogoutResponse 是 POST /api/auth/logout 的 data payload。
|
||||||
|
type LogoutResponse struct {
|
||||||
|
Success bool `json:"success"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerOIDCPublicRoutes 註冊「不需登入即可訪問」的 OIDC endpoints。
|
||||||
|
//
|
||||||
|
// 這兩個 endpoint 必須在 AuthMiddleware 之前註冊,否則 user 沒登入根本進不來。
|
||||||
|
//
|
||||||
|
// 路徑刻意與 Legacy /api/auth/* 保持一致 — 因為 OIDC 啟用時 Legacy 的 /api/auth/login
|
||||||
|
// (在 apiGroup 下)會變成「已登入才能呼叫的端點」、且仍會回 501 因為 deps.AuthProvider 通常為 nil。
|
||||||
|
// 實際生效的是這裡註冊的 OIDC 版本。
|
||||||
|
func registerOIDCPublicRoutes(r *gin.Engine, deps Deps) {
|
||||||
|
r.GET("/api/auth/login", oidcLoginHandler(deps))
|
||||||
|
r.GET("/api/auth/callback", oidcCallbackHandler(deps))
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerOIDCAuthedRoutes 是被 OB4 規劃但實際整合在 registerAuthRoutes(auth.go)裡:
|
||||||
|
// /api/auth/me 和 /api/auth/logout 在 OIDC 模式下需要不同的 handler,
|
||||||
|
// 由 registerAuthRoutes 依 deps.OIDCEnabled() 動態選擇。
|
||||||
|
|
||||||
|
// oidcLoginHandler 實作 GET /api/auth/login(OIDC 模式)。
|
||||||
|
//
|
||||||
|
// 流程(對齊 oidc-tdd.md §3.1 步驟 3):
|
||||||
|
// 1. 解析 return_to query param(白名單檢查避免 open redirect)
|
||||||
|
// 2. 產 PKCE code_verifier / state / nonce(皆 32 byte 隨機)
|
||||||
|
// 3. 透過 SessionManager.StartSession 建立 pending session(含 cookie)
|
||||||
|
// 4. 把 OIDC state 寫入 session 並 Update(讓 callback 能讀到)
|
||||||
|
// 5. 算出 IdP authorize URL(含 state / nonce / code_challenge)
|
||||||
|
// 6. 302 redirect user 到 IdP
|
||||||
|
//
|
||||||
|
// 任何步驟失敗 → 500(沒 session 可清 → 不需 fallback handling)。
|
||||||
|
// 不直接回 JSON 錯誤;redirect 才是這個 endpoint 的合約。失敗時用 WriteError 較直觀。
|
||||||
|
func oidcLoginHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
log := logOrDefault(deps.Logger)
|
||||||
|
|
||||||
|
returnTo := sanitizeReturnTo(c.Query("return_to"))
|
||||||
|
|
||||||
|
verifier, err := oidc.GenerateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oidc.login: generate code verifier failed", "error", err, "request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "failed to start login flow", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state, err := oidc.GenerateState()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oidc.login: generate state failed", "error", err, "request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "failed to start login flow", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
nonce, err := oidc.GenerateNonce()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oidc.login: generate nonce failed", "error", err, "request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "failed to start login flow", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 開新 session(含 cookie)。先 Start 再 Update — Update 會把 OIDC state 寫進 store。
|
||||||
|
sess, err := deps.SessionManager.StartSession(c.Request.Context(), c.Writer)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oidc.login: start session failed", "error", err, "request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "failed to start session", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sess.OIDCState = state
|
||||||
|
sess.OIDCNonce = nonce
|
||||||
|
sess.OIDCCodeVerifier = verifier
|
||||||
|
if returnTo != "" {
|
||||||
|
if sess.Extra == nil {
|
||||||
|
sess.Extra = make(map[string]any, 1)
|
||||||
|
}
|
||||||
|
sess.Extra["return_to"] = returnTo
|
||||||
|
}
|
||||||
|
if err := deps.SessionManager.UpdateSession(c.Request.Context(), sess); err != nil {
|
||||||
|
// 清 cookie 避免 user 拿到沒對應 store record 的 zombie cookie
|
||||||
|
_ = deps.SessionManager.EndSession(c.Request.Context(), c.Writer, c.Request)
|
||||||
|
log.Error("oidc.login: update pending session failed", "error", err, "request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "failed to persist pending session", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
challenge := oidc.CodeChallenge(verifier)
|
||||||
|
authURL := deps.OIDCProvider.AuthorizationURL(state, nonce, challenge)
|
||||||
|
|
||||||
|
log.Info("oidc.login: redirecting to IdP",
|
||||||
|
"request_id", RequestIDFrom(c),
|
||||||
|
"action", "oidc.login.redirect",
|
||||||
|
"return_to", returnTo,
|
||||||
|
)
|
||||||
|
c.Redirect(http.StatusFound, authURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// oidcCallbackHandler 實作 GET /api/auth/callback(OIDC 模式)。
|
||||||
|
//
|
||||||
|
// 對齊 oidc-tdd.md §3.1 步驟 9-12 / §4.5:
|
||||||
|
// 1. 處理 IdP error response(user 取消、IdP 錯誤)
|
||||||
|
// 2. 從 cookie 拿 pending session
|
||||||
|
// 3. 比對 state(CSRF 防護)
|
||||||
|
// 4. ExchangeCode(PKCE)
|
||||||
|
// 5. VerifyIDToken(驗簽 + nonce)
|
||||||
|
// 6. RotateSessionID(Fix-A1:session fixation 防護,OWASP ASVS V3.2.1)
|
||||||
|
// 7. 把 claims 寫入新 session(UserID / Email / Name),清 OIDC pending state,清 return_to
|
||||||
|
// 8. UpdateSession(LastSeenAt 自動刷新)
|
||||||
|
// 9. 302 回 frontend 的 PostLoginURL + return_to
|
||||||
|
//
|
||||||
|
// 失敗一律回 JSON 錯誤(4xx / 5xx);callback 是「夾在中間」的 endpoint,
|
||||||
|
// 直接 redirect user 到 frontend 的 error 頁也是選項,但雛形先回 JSON 便於測試。
|
||||||
|
func oidcCallbackHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
log := logOrDefault(deps.Logger)
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), oidcCallbackTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// IdP 錯誤回應(OAuth 2.0 §4.1.2.1):user 拒絕授權、IdP 內部錯誤等
|
||||||
|
if errCode := c.Query("error"); errCode != "" {
|
||||||
|
errDesc := c.Query("error_description")
|
||||||
|
log.Warn("oidc.callback: IdP returned error",
|
||||||
|
"request_id", RequestIDFrom(c),
|
||||||
|
"error_code", errCode,
|
||||||
|
"error_description", errDesc,
|
||||||
|
)
|
||||||
|
// 清掉 pending session(即使存在),確保 cookie 不會殘留
|
||||||
|
_ = deps.SessionManager.EndSession(ctx, c.Writer, c.Request)
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeUnauthorized,
|
||||||
|
"identity provider returned error: "+errCode, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := c.Query("code")
|
||||||
|
state := c.Query("state")
|
||||||
|
if code == "" || state == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"missing code or state query parameter", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 從 cookie 取 pending session
|
||||||
|
sess, err := deps.SessionManager.GetSession(ctx, c.Request)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("oidc.callback: pending session not found",
|
||||||
|
"request_id", RequestIDFrom(c), "error", err)
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeUnauthorized, "no pending session", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 驗 state(CSRF 防護)— 用常數時間比對避免 timing attack
|
||||||
|
if subtle.ConstantTimeCompare([]byte(sess.OIDCState), []byte(state)) != 1 {
|
||||||
|
log.Warn("oidc.callback: state mismatch",
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
// state 不對 → 視為攻擊嘗試或過期 session,刪掉重來
|
||||||
|
_ = deps.SessionManager.EndSession(ctx, c.Writer, c.Request)
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeUnauthorized, "state mismatch", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 換 token
|
||||||
|
tok, err := deps.OIDCProvider.ExchangeCode(ctx, code, sess.OIDCCodeVerifier)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("oidc.callback: token exchange failed",
|
||||||
|
"request_id", RequestIDFrom(c), "error", err)
|
||||||
|
status := http.StatusBadGateway
|
||||||
|
if errors.Is(err, oidc.ErrInvalidGrant) {
|
||||||
|
status = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
WriteError(c, status, ErrCodeUnauthorized, "token exchange failed", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 驗 id_token(含 nonce 比對)
|
||||||
|
claims, err := deps.OIDCProvider.VerifyIDToken(ctx, tok.IDToken, sess.OIDCNonce)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn("oidc.callback: id_token verification failed",
|
||||||
|
"request_id", RequestIDFrom(c), "error", err)
|
||||||
|
WriteError(c, http.StatusUnauthorized, ErrCodeUnauthorized, "id_token verification failed", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session fixation 防護(OWASP ASVS V3.2.1)— Fix-A1 / Major-1。
|
||||||
|
//
|
||||||
|
// 在「驗 id_token 成功後、寫使用者 info 進 session 之前」rotate session ID。
|
||||||
|
// 這樣攻擊者預先誘騙受害者使用的 pending cookie 在這一刻失效,
|
||||||
|
// 即使攻擊者持有舊 cookie 也無法接續成「已登入」狀態。
|
||||||
|
//
|
||||||
|
// rotate 失敗 → 不能讓登入完成(fail-closed)。清掉舊 cookie,回 500。
|
||||||
|
newSess, err := deps.SessionManager.RotateSessionID(ctx, c.Writer, c.Request)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("oidc.callback: session rotation failed",
|
||||||
|
"request_id", RequestIDFrom(c), "error", err)
|
||||||
|
// 把舊 session 也清掉,避免 stale pending session 留著。
|
||||||
|
_ = deps.SessionManager.EndSession(ctx, c.Writer, c.Request)
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "failed to rotate session", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 後續所有 session 操作都用 newSess(舊的已不可達)。
|
||||||
|
sess = newSess
|
||||||
|
|
||||||
|
// 寫 session(清 pending state,填 user info)
|
||||||
|
sess.UserID = claims.Subject
|
||||||
|
sess.Email = claims.Email
|
||||||
|
sess.Name = claims.Name
|
||||||
|
// 雛形 access_token / id_token raw 仍保留在 session(未來 RP-initiated logout 用)。
|
||||||
|
// 注意:絕對不可進入 log(oidc-tdd.md §14.5)。
|
||||||
|
sess.AccessToken = tok.AccessToken
|
||||||
|
sess.IDTokenRaw = tok.IDToken
|
||||||
|
// 清掉 OIDC pending state
|
||||||
|
sess.OIDCState = ""
|
||||||
|
sess.OIDCNonce = ""
|
||||||
|
sess.OIDCCodeVerifier = ""
|
||||||
|
|
||||||
|
// 取 return_to(在 login handler 寫入 sess.Extra;經 rotation 後仍保留)
|
||||||
|
returnTo := "/"
|
||||||
|
if v, ok := sess.Extra["return_to"]; ok {
|
||||||
|
if s, ok := v.(string); ok && s != "" {
|
||||||
|
returnTo = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 把 return_to 清理併入同一次 UpdateSession(Major-4 修復:避免吞錯誤的二次 Update)。
|
||||||
|
// 之前是先 UpdateSession 寫 user info、再 UpdateSession 清 return_to 並 _ = err 吞錯誤;
|
||||||
|
// 現在合一:清 Extra → 一次 UpdateSession 把 user info + return_to 清理同時 commit。
|
||||||
|
if sess.Extra != nil {
|
||||||
|
delete(sess.Extra, "return_to")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := deps.SessionManager.UpdateSession(ctx, sess); err != nil {
|
||||||
|
log.Error("oidc.callback: update session failed",
|
||||||
|
"request_id", RequestIDFrom(c), "error", err)
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError, "failed to persist session", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 算 redirect URL:PostLoginURL + return_to。
|
||||||
|
//
|
||||||
|
// 用 url.Parse + ResolveReference 而非字串拼接:
|
||||||
|
// - 字串拼接會在 PostLoginURL 帶 trailing slash + returnTo 帶 leading slash
|
||||||
|
// 時產生 "//",被瀏覽器當 protocol-relative URL 跳到外部站。
|
||||||
|
// - ResolveReference 正確處理 trailing slash、保留 query / fragment、
|
||||||
|
// 且若 returnTo 不慎含 scheme/host(理論上 sanitizeReturnTo 已擋)會
|
||||||
|
// 被當成絕對 URL 取代 base — 我們再用 SameHost 檢查防禦性兜底。
|
||||||
|
//
|
||||||
|
// returnTo 已經 sanitizeReturnTo("/" 開頭、無 "//"、無 "://"),這裡是雙重防護。
|
||||||
|
redirectURL := returnTo
|
||||||
|
if deps.OIDCPostLoginURL != "" {
|
||||||
|
base, baseErr := url.Parse(deps.OIDCPostLoginURL)
|
||||||
|
ref, refErr := url.Parse(returnTo)
|
||||||
|
if baseErr != nil || refErr != nil || base.Host == "" {
|
||||||
|
// PostLoginURL / returnTo 不是合法 URL — 退回 same-origin。
|
||||||
|
log.Warn("oidc.callback: parse redirect base/ref failed, falling back to same-origin",
|
||||||
|
"request_id", RequestIDFrom(c), "base_err", baseErr, "ref_err", refErr)
|
||||||
|
redirectURL = returnTo
|
||||||
|
} else {
|
||||||
|
resolved := base.ResolveReference(ref)
|
||||||
|
// 防禦性檢查:resolve 後 host 必須仍等於 base.Host(避免 returnTo 偷渡 host)。
|
||||||
|
if resolved.Host != base.Host || resolved.Scheme != base.Scheme {
|
||||||
|
log.Warn("oidc.callback: resolved redirect host/scheme mismatch, falling back",
|
||||||
|
"request_id", RequestIDFrom(c),
|
||||||
|
"base_host", base.Host, "resolved_host", resolved.Host)
|
||||||
|
redirectURL = returnTo
|
||||||
|
} else {
|
||||||
|
redirectURL = resolved.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("oidc.callback: login success",
|
||||||
|
"request_id", RequestIDFrom(c),
|
||||||
|
"action", "oidc.callback.success",
|
||||||
|
"user_id", claims.Subject,
|
||||||
|
)
|
||||||
|
c.Redirect(http.StatusFound, redirectURL)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// oidcLogoutHandler 實作 POST /api/auth/logout(OIDC 模式)。
|
||||||
|
//
|
||||||
|
// 雛形不做 RP-initiated logout(不通知 IdP)— 只清本地 session + cookie。
|
||||||
|
// Idempotent:cookie 不存在或 session 已清也回 200。
|
||||||
|
//
|
||||||
|
// 對齊 oidc-tdd.md §3.3。
|
||||||
|
func oidcLogoutHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
log := logOrDefault(deps.Logger)
|
||||||
|
var userID string
|
||||||
|
if uc, ok := UserContextFrom(c); ok {
|
||||||
|
userID = uc.UserID
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := deps.SessionManager.EndSession(ctx, c.Writer, c.Request); err != nil {
|
||||||
|
// EndSession 內部已清 cookie;只 log 不 fail(保持 idempotent)
|
||||||
|
log.Warn("oidc.logout: end session reported error",
|
||||||
|
"request_id", RequestIDFrom(c), "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("oidc.logout",
|
||||||
|
"request_id", RequestIDFrom(c),
|
||||||
|
"action", "oidc.logout",
|
||||||
|
"user_id", userID,
|
||||||
|
)
|
||||||
|
WriteSuccess(c, http.StatusOK, LogoutResponse{Success: true})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// oidcMeHandler 實作 GET /api/auth/me(OIDC 模式)。
|
||||||
|
//
|
||||||
|
// 主要從 AuthMiddleware 注入的 UserContext / Session 取資料 — 不再呼叫 store。
|
||||||
|
// 對齊 oidc-tdd.md §4.5 Me 範例。
|
||||||
|
func oidcMeHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc == nil {
|
||||||
|
WriteError(c, http.StatusUnauthorized, ErrCodeUnauthorized, "not authenticated", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Session 含 Name;UserContext 沒有,所以從 session 拿
|
||||||
|
var name string
|
||||||
|
if sess, ok := UserSessionFrom(c); ok && sess != nil {
|
||||||
|
name = sess.Name
|
||||||
|
}
|
||||||
|
WriteSuccess(c, http.StatusOK, MeResponseOIDC{
|
||||||
|
UserID: uc.UserID,
|
||||||
|
Email: uc.Email,
|
||||||
|
Name: name,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeReturnTo 防止 open redirect 攻擊。
|
||||||
|
//
|
||||||
|
// 規則:
|
||||||
|
// - 必須以 "/" 開頭(同 origin path)
|
||||||
|
// - 不能以 "//" 開頭(protocol-relative URL,會跳到攻擊者站)
|
||||||
|
// - 不能含 "://" 或 "\"(避免各種 URL parsing trick)
|
||||||
|
//
|
||||||
|
// 不合規回空字串(caller 視為「沒指定」,會走預設 "/")。
|
||||||
|
func sanitizeReturnTo(raw string) string {
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(raw, "/") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(raw, "//") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.Contains(raw, "://") || strings.Contains(raw, "\\") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
|
||||||
604
visionA-backend/internal/api/oidc_auth_test.go
Normal file
604
visionA-backend/internal/api/oidc_auth_test.go
Normal file
@ -0,0 +1,604 @@
|
|||||||
|
// oidc_auth_test.go — OIDC handler 與 OIDC-mode middleware 的 unit test。
|
||||||
|
//
|
||||||
|
// 設計策略:用 mockOIDCProvider 取代真實 IdP(避免 IO、純 Go function call)。
|
||||||
|
// 這樣測試快且確定性高;真實 IdP 整合留給 OT1(fake server)+ OT2(end-to-end)。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/oidc"
|
||||||
|
"visiona-backend/internal/usersession"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---- mockOIDCProvider ---------------------------------------------------
|
||||||
|
|
||||||
|
// mockOIDCProvider 實作 oidc.Provider,回傳由 test 預先設定的固定值。
|
||||||
|
//
|
||||||
|
// 比 fake HTTP server 簡單很多:直接控制每個方法的回傳,可注入錯誤情境。
|
||||||
|
type mockOIDCProvider struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
// AuthorizationURL 行為控制
|
||||||
|
authURLBase string // 預設 "https://idp.example/authorize"
|
||||||
|
|
||||||
|
// ExchangeCode 行為控制
|
||||||
|
exchangeFn func(ctx context.Context, code, verifier string) (*oidc.TokenResponse, error)
|
||||||
|
|
||||||
|
// VerifyIDToken 行為控制
|
||||||
|
verifyFn func(ctx context.Context, raw, expectedNonce string) (*oidc.Claims, error)
|
||||||
|
|
||||||
|
// 記錄呼叫參數供 test assertion 用
|
||||||
|
gotCode string
|
||||||
|
gotVerifier string
|
||||||
|
gotIDToken string
|
||||||
|
gotNonce string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOIDCProvider) AuthorizationURL(state, nonce, codeChallenge string) string {
|
||||||
|
base := m.authURLBase
|
||||||
|
if base == "" {
|
||||||
|
base = "https://idp.example/authorize"
|
||||||
|
}
|
||||||
|
q := url.Values{}
|
||||||
|
q.Set("state", state)
|
||||||
|
q.Set("nonce", nonce)
|
||||||
|
q.Set("code_challenge", codeChallenge)
|
||||||
|
q.Set("code_challenge_method", "S256")
|
||||||
|
return base + "?" + q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOIDCProvider) ExchangeCode(ctx context.Context, code, verifier string) (*oidc.TokenResponse, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.gotCode = code
|
||||||
|
m.gotVerifier = verifier
|
||||||
|
m.mu.Unlock()
|
||||||
|
if m.exchangeFn != nil {
|
||||||
|
return m.exchangeFn(ctx, code, verifier)
|
||||||
|
}
|
||||||
|
return &oidc.TokenResponse{
|
||||||
|
AccessToken: "access-token-xyz",
|
||||||
|
IDToken: "id-token-xyz",
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockOIDCProvider) VerifyIDToken(ctx context.Context, raw, expectedNonce string) (*oidc.Claims, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.gotIDToken = raw
|
||||||
|
m.gotNonce = expectedNonce
|
||||||
|
m.mu.Unlock()
|
||||||
|
if m.verifyFn != nil {
|
||||||
|
return m.verifyFn(ctx, raw, expectedNonce)
|
||||||
|
}
|
||||||
|
return &oidc.Claims{
|
||||||
|
Subject: "user-123",
|
||||||
|
Email: "alice@example.com",
|
||||||
|
Name: "Alice",
|
||||||
|
Nonce: expectedNonce,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- helper: 建立啟用 OIDC 的測試 Deps + router ---------------------------
|
||||||
|
|
||||||
|
func newOIDCTestDeps(provider *mockOIDCProvider) Deps {
|
||||||
|
mgr := usersession.NewManager(usersession.NewInMemoryStore(), usersession.CookieConfig{
|
||||||
|
Name: "visiona_session",
|
||||||
|
Path: "/",
|
||||||
|
HTTPOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: 86400,
|
||||||
|
SigningKey: []byte("test-secret-32-byte-key-aaaaaaaaaaaa"),
|
||||||
|
})
|
||||||
|
return Deps{
|
||||||
|
OIDCProvider: provider,
|
||||||
|
SessionManager: mgr,
|
||||||
|
OIDCPostLoginURL: "http://localhost:3000",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// newOIDCRouter 建立完整 router(含 public + apiGroup AuthMiddleware)。
|
||||||
|
func newOIDCRouter(deps Deps) *gin.Engine {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
// public OIDC routes(必須在 AuthMiddleware 之外)
|
||||||
|
registerOIDCPublicRoutes(r, deps)
|
||||||
|
// /api 群組(含 AuthMiddleware + auth handlers)
|
||||||
|
g := r.Group("/api")
|
||||||
|
g.Use(AuthMiddleware(deps))
|
||||||
|
registerAuthRoutes(g, deps)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- TESTS: oidcLoginHandler --------------------------------------------
|
||||||
|
|
||||||
|
// TestOIDCLogin_RedirectsToIdPWithProperParams 驗證 /api/auth/login 會 302 到 IdP,
|
||||||
|
// 並設好 cookie + 在 session 中存好 PKCE state / nonce / verifier。
|
||||||
|
func TestOIDCLogin_RedirectsToIdPWithProperParams(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/auth/login?return_to=/dashboard", nil)
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusFound, w.Code)
|
||||||
|
loc := w.Header().Get("Location")
|
||||||
|
require.NotEmpty(t, loc)
|
||||||
|
parsed, err := url.Parse(loc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 應該有 state / nonce / code_challenge 三個 query
|
||||||
|
q := parsed.Query()
|
||||||
|
assert.NotEmpty(t, q.Get("state"))
|
||||||
|
assert.NotEmpty(t, q.Get("nonce"))
|
||||||
|
assert.NotEmpty(t, q.Get("code_challenge"))
|
||||||
|
assert.Equal(t, "S256", q.Get("code_challenge_method"))
|
||||||
|
|
||||||
|
// 應該有 Set-Cookie
|
||||||
|
cookies := w.Result().Cookies()
|
||||||
|
require.NotEmpty(t, cookies, "expected Set-Cookie")
|
||||||
|
var sessCookie *http.Cookie
|
||||||
|
for _, c := range cookies {
|
||||||
|
if c.Name == "visiona_session" {
|
||||||
|
sessCookie = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, sessCookie, "expected visiona_session cookie")
|
||||||
|
assert.True(t, sessCookie.HttpOnly)
|
||||||
|
assert.Equal(t, http.SameSiteLaxMode, sessCookie.SameSite)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCLogin_SanitizesReturnTo 驗證 open redirect 防護。
|
||||||
|
func TestOIDCLogin_SanitizesReturnTo(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
raw string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"empty", "", ""},
|
||||||
|
{"normal_path", "/dashboard", "/dashboard"},
|
||||||
|
{"path_with_query", "/devices?x=1", "/devices?x=1"},
|
||||||
|
{"absolute_url", "http://evil.example/", ""},
|
||||||
|
{"protocol_relative", "//evil.example", ""},
|
||||||
|
{"backslash_trick", "/\\evil.example", ""},
|
||||||
|
{"missing_leading_slash", "evil", ""},
|
||||||
|
{"scheme_in_path", "/foo://bar", ""},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.want, sanitizeReturnTo(tc.raw))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- TESTS: oidcCallbackHandler -----------------------------------------
|
||||||
|
|
||||||
|
// TestOIDCCallback_HappyPath 驗證 callback 完整流程跑通。
|
||||||
|
func TestOIDCCallback_HappyPath(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
// Step 1: 觸發 login,拿到 state + cookie
|
||||||
|
loginW := httptest.NewRecorder()
|
||||||
|
loginReq := httptest.NewRequest(http.MethodGet, "/api/auth/login?return_to=/devices", nil)
|
||||||
|
r.ServeHTTP(loginW, loginReq)
|
||||||
|
require.Equal(t, http.StatusFound, loginW.Code)
|
||||||
|
|
||||||
|
loc, _ := url.Parse(loginW.Header().Get("Location"))
|
||||||
|
state := loc.Query().Get("state")
|
||||||
|
require.NotEmpty(t, state)
|
||||||
|
|
||||||
|
// 提取 cookie
|
||||||
|
cookies := loginW.Result().Cookies()
|
||||||
|
require.NotEmpty(t, cookies)
|
||||||
|
|
||||||
|
// Step 2: 模擬 IdP 302 回 callback(帶上 cookie + state)
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
cbReq := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/api/auth/callback?code=auth-code-xyz&state="+url.QueryEscape(state), nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
cbReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(cbW, cbReq)
|
||||||
|
|
||||||
|
// 預期 302 回 frontend
|
||||||
|
require.Equal(t, http.StatusFound, cbW.Code, "body=%s", cbW.Body.String())
|
||||||
|
redirect := cbW.Header().Get("Location")
|
||||||
|
assert.Equal(t, "http://localhost:3000/devices", redirect)
|
||||||
|
|
||||||
|
// 驗 mock 收到正確的 code
|
||||||
|
assert.Equal(t, "auth-code-xyz", provider.gotCode)
|
||||||
|
assert.NotEmpty(t, provider.gotVerifier, "verifier should be passed to ExchangeCode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCCallback_StateMismatch 驗證 state 不符回 400 並清 session。
|
||||||
|
func TestOIDCCallback_StateMismatch(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
// 先 login 拿 cookie
|
||||||
|
loginW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(loginW, httptest.NewRequest(http.MethodGet, "/api/auth/login", nil))
|
||||||
|
cookies := loginW.Result().Cookies()
|
||||||
|
|
||||||
|
// Callback 帶錯的 state
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
cbReq := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/api/auth/callback?code=xyz&state=wrong-state", nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
cbReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(cbW, cbReq)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, cbW.Code)
|
||||||
|
assert.Contains(t, cbW.Body.String(), "state mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCCallback_NoCookie 驗證沒帶 cookie → 400。
|
||||||
|
func TestOIDCCallback_NoCookie(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
cbReq := httptest.NewRequest(http.MethodGet, "/api/auth/callback?code=xyz&state=abc", nil)
|
||||||
|
r.ServeHTTP(cbW, cbReq)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, cbW.Code)
|
||||||
|
assert.Contains(t, cbW.Body.String(), "no pending session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCCallback_MissingCodeOrState 驗證 missing query 回 400。
|
||||||
|
func TestOIDCCallback_MissingCodeOrState(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(cbW, httptest.NewRequest(http.MethodGet, "/api/auth/callback", nil))
|
||||||
|
assert.Equal(t, http.StatusBadRequest, cbW.Code)
|
||||||
|
assert.Contains(t, cbW.Body.String(), "missing code or state")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCCallback_IdPError 驗證 IdP 回 error param → 400。
|
||||||
|
func TestOIDCCallback_IdPError(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(cbW, httptest.NewRequest(http.MethodGet,
|
||||||
|
"/api/auth/callback?error=access_denied&error_description=user_cancelled", nil))
|
||||||
|
assert.Equal(t, http.StatusBadRequest, cbW.Code)
|
||||||
|
assert.Contains(t, cbW.Body.String(), "access_denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCCallback_TokenExchangeInvalidGrant 驗證 invalid_grant → 400。
|
||||||
|
func TestOIDCCallback_TokenExchangeInvalidGrant(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{
|
||||||
|
exchangeFn: func(ctx context.Context, code, verifier string) (*oidc.TokenResponse, error) {
|
||||||
|
return nil, oidc.ErrInvalidGrant
|
||||||
|
},
|
||||||
|
}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
loginW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(loginW, httptest.NewRequest(http.MethodGet, "/api/auth/login", nil))
|
||||||
|
state := mustExtractStateFromLoginRedirect(t, loginW)
|
||||||
|
cookies := loginW.Result().Cookies()
|
||||||
|
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
cbReq := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/api/auth/callback?code=xyz&state="+url.QueryEscape(state), nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
cbReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(cbW, cbReq)
|
||||||
|
assert.Equal(t, http.StatusBadRequest, cbW.Code)
|
||||||
|
assert.Contains(t, cbW.Body.String(), "token exchange failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCCallback_RotatesSessionID_PreventsFixation 驗證 Fix-A1(session fixation 防護):
|
||||||
|
//
|
||||||
|
// 攻擊情境:攻擊者預先取得一個 pending session cookie(自己跑 /api/auth/login),
|
||||||
|
// 誘騙受害者使用此 cookie 走完 OIDC flow。
|
||||||
|
//
|
||||||
|
// 防護驗證:
|
||||||
|
// - callback 完成時必須 rotate cookie value(瀏覽器收到的 Set-Cookie 與原 cookie value 不同)
|
||||||
|
// - 用「攻擊者持有的舊 cookie」訪 /api/auth/me 應該 401(pending session 已不存在於 store)
|
||||||
|
// - 用「callback 回傳的新 cookie」訪 /api/auth/me 應該 200(已登入)
|
||||||
|
func TestOIDCCallback_RotatesSessionID_PreventsFixation(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
// 模擬「攻擊者預先取得 pending cookie」
|
||||||
|
loginW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(loginW, httptest.NewRequest(http.MethodGet, "/api/auth/login", nil))
|
||||||
|
state := mustExtractStateFromLoginRedirect(t, loginW)
|
||||||
|
attackerCookies := loginW.Result().Cookies()
|
||||||
|
require.NotEmpty(t, attackerCookies)
|
||||||
|
attackerCookieValue := attackerCookies[0].Value
|
||||||
|
|
||||||
|
// 模擬「受害者用攻擊者的 cookie 走完 callback」
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
cbReq := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/api/auth/callback?code=auth-code&state="+url.QueryEscape(state), nil)
|
||||||
|
for _, c := range attackerCookies {
|
||||||
|
cbReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(cbW, cbReq)
|
||||||
|
require.Equal(t, http.StatusFound, cbW.Code, "callback should succeed; body=%s", cbW.Body.String())
|
||||||
|
|
||||||
|
// 驗證 1:callback 必須寫一個新 cookie,且 value 與舊 cookie 不同
|
||||||
|
newCookies := cbW.Result().Cookies()
|
||||||
|
require.NotEmpty(t, newCookies, "callback must write new Set-Cookie (rotation)")
|
||||||
|
var newSessCookie *http.Cookie
|
||||||
|
for _, c := range newCookies {
|
||||||
|
if c.Name == "visiona_session" {
|
||||||
|
newSessCookie = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, newSessCookie, "expected visiona_session cookie after callback")
|
||||||
|
assert.NotEqual(t, attackerCookieValue, newSessCookie.Value,
|
||||||
|
"session fixation: cookie value MUST change after login (rotate session ID)")
|
||||||
|
|
||||||
|
// 驗證 2:用攻擊者持有的舊 cookie 訪 /me → 401(攻擊者拿不到 victim 帳號)
|
||||||
|
attackerMeW := httptest.NewRecorder()
|
||||||
|
attackerMeReq := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
|
||||||
|
for _, c := range attackerCookies {
|
||||||
|
attackerMeReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(attackerMeW, attackerMeReq)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, attackerMeW.Code,
|
||||||
|
"attacker's old cookie must be rejected after rotation; body=%s", attackerMeW.Body.String())
|
||||||
|
|
||||||
|
// 驗證 3:用受害者的新 cookie 訪 /me → 200(合法)
|
||||||
|
victimMeW := httptest.NewRecorder()
|
||||||
|
victimMeReq := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
|
||||||
|
for _, c := range newCookies {
|
||||||
|
victimMeReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(victimMeW, victimMeReq)
|
||||||
|
assert.Equal(t, http.StatusOK, victimMeW.Code,
|
||||||
|
"victim's new cookie must be accepted; body=%s", victimMeW.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCCallback_VerifyFails 驗證 id_token 驗證失敗 → 401。
|
||||||
|
func TestOIDCCallback_VerifyFails(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{
|
||||||
|
verifyFn: func(ctx context.Context, raw, nonce string) (*oidc.Claims, error) {
|
||||||
|
return nil, oidc.ErrInvalidIDToken
|
||||||
|
},
|
||||||
|
}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
loginW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(loginW, httptest.NewRequest(http.MethodGet, "/api/auth/login", nil))
|
||||||
|
state := mustExtractStateFromLoginRedirect(t, loginW)
|
||||||
|
cookies := loginW.Result().Cookies()
|
||||||
|
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
cbReq := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/api/auth/callback?code=xyz&state="+url.QueryEscape(state), nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
cbReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(cbW, cbReq)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, cbW.Code)
|
||||||
|
assert.Contains(t, cbW.Body.String(), "id_token verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- TESTS: AuthMiddleware (OIDC 模式) + /api/auth/me + /api/auth/logout ----
|
||||||
|
|
||||||
|
// TestOIDCMiddleware_Allows_AuthenticatedSession 驗證已登入 session 通過 + me 回 user info。
|
||||||
|
func TestOIDCMiddleware_Allows_AuthenticatedSession(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
// 完整跑一次 login + callback 拿到登入 session
|
||||||
|
cookies := loginAndCallback(t, r, deps, provider)
|
||||||
|
|
||||||
|
// 訪 /api/auth/me — 應 200 + 帶 user info
|
||||||
|
meW := httptest.NewRecorder()
|
||||||
|
meReq := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
meReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(meW, meReq)
|
||||||
|
require.Equal(t, http.StatusOK, meW.Code, "body=%s", meW.Body.String())
|
||||||
|
|
||||||
|
var sb SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(meW.Body.Bytes(), &sb))
|
||||||
|
data := sb.Data.(map[string]any)
|
||||||
|
assert.Equal(t, "user-123", data["user_id"])
|
||||||
|
assert.Equal(t, "alice@example.com", data["email"])
|
||||||
|
assert.Equal(t, "Alice", data["name"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCMiddleware_Rejects_NoCookie 驗證沒 cookie → 401。
|
||||||
|
func TestOIDCMiddleware_Rejects_NoCookie(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/auth/me", nil))
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "no_session")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCMiddleware_Rejects_PendingSession 驗證 pending session(UserID 空)→ 401。
|
||||||
|
//
|
||||||
|
// 情境:使用者啟動 login 但還沒走完 callback,只有 pending session cookie
|
||||||
|
// 就直接訪 /api/auth/me — 應該被拒絕,而不是被當已登入。
|
||||||
|
func TestOIDCMiddleware_Rejects_PendingSession(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
loginW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(loginW, httptest.NewRequest(http.MethodGet, "/api/auth/login", nil))
|
||||||
|
cookies := loginW.Result().Cookies()
|
||||||
|
|
||||||
|
meW := httptest.NewRecorder()
|
||||||
|
meReq := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
meReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(meW, meReq)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, meW.Code)
|
||||||
|
assert.Contains(t, meW.Body.String(), "session_not_authenticated")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDCLogout_ClearsSession 驗證 logout 200 + 清 cookie。
|
||||||
|
func TestOIDCLogout_ClearsSession(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
cookies := loginAndCallback(t, r, deps, provider)
|
||||||
|
|
||||||
|
// POST /api/auth/logout
|
||||||
|
logoutW := httptest.NewRecorder()
|
||||||
|
logoutReq := httptest.NewRequest(http.MethodPost, "/api/auth/logout", nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
logoutReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(logoutW, logoutReq)
|
||||||
|
assert.Equal(t, http.StatusOK, logoutW.Code)
|
||||||
|
|
||||||
|
// Set-Cookie 應該帶過期 attribute
|
||||||
|
respCookies := logoutW.Result().Cookies()
|
||||||
|
var cleared *http.Cookie
|
||||||
|
for _, c := range respCookies {
|
||||||
|
if c.Name == "visiona_session" {
|
||||||
|
cleared = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NotNil(t, cleared, "expected visiona_session clearing cookie")
|
||||||
|
assert.True(t, cleared.MaxAge < 0, "expected MaxAge < 0 to clear cookie")
|
||||||
|
|
||||||
|
// 之後 /api/auth/me 應該 401
|
||||||
|
meW := httptest.NewRecorder()
|
||||||
|
meReq := httptest.NewRequest(http.MethodGet, "/api/auth/me", nil)
|
||||||
|
for _, c := range cookies {
|
||||||
|
meReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(meW, meReq)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, meW.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOIDC_LegacyLogin_Returns410 驗證 OIDC 模式下 POST /api/auth/login 回 410。
|
||||||
|
func TestOIDC_LegacyLogin_Returns410(t *testing.T) {
|
||||||
|
provider := &mockOIDCProvider{}
|
||||||
|
deps := newOIDCTestDeps(provider)
|
||||||
|
r := newOIDCRouter(deps)
|
||||||
|
|
||||||
|
// POST /api/auth/login 在 OIDC 模式下不支援 — 但會先過 AuthMiddleware
|
||||||
|
// (沒帶 cookie 就 401)。為了測 410 行為,先登入拿 cookie。
|
||||||
|
cookies := loginAndCallback(t, r, deps, provider)
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/auth/login", strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
for _, c := range cookies {
|
||||||
|
req.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusGone, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "GET /api/auth/login")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewRouterValidate_PanicsWithoutOIDC 驗證 OB5 起 NewRouter 在缺 OIDC 依賴時 panic。
|
||||||
|
func TestNewRouterValidate_PanicsWithoutOIDC(t *testing.T) {
|
||||||
|
t.Run("no provider, no manager", func(t *testing.T) {
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
(&Deps{}).validate()
|
||||||
|
}, "缺兩個 OIDC 依賴應 panic")
|
||||||
|
})
|
||||||
|
t.Run("only provider", func(t *testing.T) {
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
(&Deps{OIDCProvider: &mockOIDCProvider{}}).validate()
|
||||||
|
}, "缺 SessionManager 應 panic")
|
||||||
|
})
|
||||||
|
t.Run("only manager", func(t *testing.T) {
|
||||||
|
// 這裡的 SigningKey 長度必須 ≥ 32 bytes(usersession.MinSigningKeyBytes),否則 NewManager 會 panic。
|
||||||
|
d := &Deps{SessionManager: usersession.NewManager(usersession.NewInMemoryStore(), usersession.CookieConfig{SigningKey: []byte("test-key-test-key-test-key-1234!")})}
|
||||||
|
assert.Panics(t, func() {
|
||||||
|
d.validate()
|
||||||
|
}, "缺 OIDCProvider 應 panic")
|
||||||
|
})
|
||||||
|
t.Run("both set passes", func(t *testing.T) {
|
||||||
|
d := newOIDCTestDeps(&mockOIDCProvider{})
|
||||||
|
assert.NotPanics(t, func() {
|
||||||
|
d.validate()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- helper: 共用的 login + callback 流程 --------------------------------
|
||||||
|
|
||||||
|
// loginAndCallback 跑完整 login → callback 流程,回傳已登入 session 的 cookie。
|
||||||
|
//
|
||||||
|
// Fix-A1(session fixation 防護)後:callback 完成時會 rotate session ID,cookie 會被改寫。
|
||||||
|
// 因此優先回傳 callback 後的 Set-Cookie;若 callback 沒寫新 cookie(理論上不應該)
|
||||||
|
// 才 fallback 用 login 階段的 cookie。
|
||||||
|
func loginAndCallback(t *testing.T, r *gin.Engine, deps Deps, _ *mockOIDCProvider) []*http.Cookie {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
loginW := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(loginW, httptest.NewRequest(http.MethodGet, "/api/auth/login?return_to=/", nil))
|
||||||
|
require.Equal(t, http.StatusFound, loginW.Code)
|
||||||
|
state := mustExtractStateFromLoginRedirect(t, loginW)
|
||||||
|
loginCookies := loginW.Result().Cookies()
|
||||||
|
|
||||||
|
cbW := httptest.NewRecorder()
|
||||||
|
cbReq := httptest.NewRequest(http.MethodGet,
|
||||||
|
"/api/auth/callback?code=auth-code&state="+url.QueryEscape(state), nil)
|
||||||
|
for _, c := range loginCookies {
|
||||||
|
cbReq.AddCookie(c)
|
||||||
|
}
|
||||||
|
r.ServeHTTP(cbW, cbReq)
|
||||||
|
require.Equal(t, http.StatusFound, cbW.Code, "callback failed: %s", cbW.Body.String())
|
||||||
|
|
||||||
|
// callback 完成後的 Set-Cookie 是 rotation 後的新 cookie;用它做後續請求。
|
||||||
|
cbCookies := cbW.Result().Cookies()
|
||||||
|
if len(cbCookies) > 0 {
|
||||||
|
return cbCookies
|
||||||
|
}
|
||||||
|
return loginCookies
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustExtractStateFromLoginRedirect 從 login redirect 的 Location 取出 state。
|
||||||
|
func mustExtractStateFromLoginRedirect(t *testing.T, w *httptest.ResponseRecorder) string {
|
||||||
|
t.Helper()
|
||||||
|
loc, err := url.Parse(w.Header().Get("Location"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
state := loc.Query().Get("state")
|
||||||
|
require.NotEmpty(t, state)
|
||||||
|
return state
|
||||||
|
}
|
||||||
516
visionA-backend/internal/api/pairing.go
Normal file
516
visionA-backend/internal/api/pairing.go
Normal file
@ -0,0 +1,516 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pairingTokenTTL 是新發 pairing token 的存活時間。
|
||||||
|
//
|
||||||
|
// 對齊 security.md §1.3 的「短期一次性 token」設計:15 分鐘足夠完成
|
||||||
|
// 「使用者啟動 local-tool → 點配對按鈕 → token 進來」的流程。
|
||||||
|
const pairingTokenTTL = 15 * time.Minute
|
||||||
|
|
||||||
|
// registerPairingRoutes 註冊 /api/pairing/* 的 routes。
|
||||||
|
//
|
||||||
|
// MVP 全集合(B4 + B5):
|
||||||
|
// - POST /api/pairing/token → 建立 + 回傳 pairing token
|
||||||
|
// - GET /api/pairing/status → 回傳當前 user 的 tunnel 連線狀態
|
||||||
|
// - GET /api/pairing/tokens → 列當前 user 的所有 token(B5)
|
||||||
|
// - DELETE /api/pairing/tokens/:token → 撤銷指定 token(B5)
|
||||||
|
func registerPairingRoutes(g *gin.RouterGroup, deps Deps) {
|
||||||
|
g.POST("/pairing/token", pairingCreateTokenHandler(deps))
|
||||||
|
g.GET("/pairing/status", pairingStatusHandler(deps))
|
||||||
|
g.GET("/pairing/tokens", pairingListTokensHandler(deps))
|
||||||
|
g.DELETE("/pairing/tokens/:token", pairingRevokeTokenHandler(deps))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PairingTokenResponse 是 POST /api/pairing/token 的 data payload。
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md §2:
|
||||||
|
//
|
||||||
|
// { "token": "vAc_...", "expires_at": "..." }
|
||||||
|
type PairingTokenResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// pairingCreateTokenHandler 建立一個新的 pairing token。
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// 1. 從 UserContext 取出 userID(OIDC sub,由 AuthMiddleware 注入)
|
||||||
|
// 2. PairingStore.Create 產生一個合法格式的 vAc_ token
|
||||||
|
// 3. 回傳 token plaintext 給前端「只此一次」顯示
|
||||||
|
//
|
||||||
|
// 失敗回應:
|
||||||
|
// - PairingStore 未注入 → 501 NOT_IMPLEMENTED
|
||||||
|
// - Create 內部錯誤 → 500 INTERNAL_ERROR
|
||||||
|
//
|
||||||
|
// 安全提醒:plaintext 只回給合法登入的使用者,不寫進 server log。
|
||||||
|
func pairingCreateTokenHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.PairingStore == nil {
|
||||||
|
WriteNotImplemented(c, "pairing store not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 0.7 security fix M1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
// 移除 inline demo-user fallback:強制要求 AuthMiddleware 已注入合法 UserContext。
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
plaintext, info, err := deps.PairingStore.Create(ctx, userID, pairingTokenTTL)
|
||||||
|
if err != nil {
|
||||||
|
logOrDefault(deps.Logger).Error("pairing: create token failed",
|
||||||
|
"error", err,
|
||||||
|
"user_id", userID,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"failed to create pairing token", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 取 ExpiresAt — 雛形必定非 nil(pairingTokenTTL > 0),保險檢查
|
||||||
|
var expires time.Time
|
||||||
|
if info != nil && info.ExpiresAt != nil {
|
||||||
|
expires = *info.ExpiresAt
|
||||||
|
}
|
||||||
|
|
||||||
|
// 故意不 log plaintext(只 log token prefix 與 expires)
|
||||||
|
logOrDefault(deps.Logger).Info("pairing: token created",
|
||||||
|
"user_id", userID,
|
||||||
|
"token_prefix", tokenPrefix(plaintext),
|
||||||
|
"expires_at", expires,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
WriteSuccess(c, http.StatusOK, PairingTokenResponse{
|
||||||
|
Token: plaintext,
|
||||||
|
ExpiresAt: expires,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PairingStatusResponse 是 GET /api/pairing/status 的 data payload。
|
||||||
|
//
|
||||||
|
// 對齊 api-spec.md §2:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "connected": true,
|
||||||
|
// "connected_at": "...",
|
||||||
|
// "last_seen_at": "...",
|
||||||
|
// "device_id": "dev-xxx",
|
||||||
|
// "agent_version": "..."
|
||||||
|
// }
|
||||||
|
type PairingStatusResponse struct {
|
||||||
|
Connected bool `json:"connected"`
|
||||||
|
ConnectedAt *time.Time `json:"connected_at,omitempty"`
|
||||||
|
LastSeenAt *time.Time `json:"last_seen_at,omitempty"`
|
||||||
|
DeviceID string `json:"device_id,omitempty"`
|
||||||
|
AgentVersion string `json:"agent_version,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// pairingStatusHandler 回報當前 user 的 tunnel 連線狀態。
|
||||||
|
//
|
||||||
|
// 雛形實作:直接 List 所有 sessions(單 user 場景),找第一個。
|
||||||
|
// 多 user 階段(B5)會改成「按 user_id 過濾」並考慮多 device 場景。
|
||||||
|
func pairingStatusHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
resp := PairingStatusResponse{Connected: false}
|
||||||
|
|
||||||
|
if deps.SessionStore == nil {
|
||||||
|
WriteSuccess(c, http.StatusOK, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
summaries, err := deps.SessionStore.List(ctx)
|
||||||
|
if err != nil {
|
||||||
|
// 跟 system/health 一樣,list 失敗不致命,回 connected=false
|
||||||
|
logOrDefault(deps.Logger).Warn("pairing/status: list sessions failed",
|
||||||
|
"error", err,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
WriteSuccess(c, http.StatusOK, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 0.7 security audit M2:寬鬆比對暫保留待人工介入修復。
|
||||||
|
// 詳細理由見 proxy.go pickActiveSessionToken 註解:relay 端 LocalHandle.Summary
|
||||||
|
// 不帶 UserID,strict equality 會讓所有 e2e proxy 鏈路全斷。
|
||||||
|
// 此處仍要求 UserContext 非空(C1 加固),但 s.UserID 暫接受空字串。
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
var picked *session.Summary
|
||||||
|
if ok && uc != nil && uc.UserID != "" {
|
||||||
|
for _, s := range summaries {
|
||||||
|
// 寬鬆比對:暫接受 s.UserID == "" 直到 relay 端 backfill UserID。
|
||||||
|
if s.UserID == "" || s.UserID == uc.UserID {
|
||||||
|
picked = s
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if picked == nil {
|
||||||
|
WriteSuccess(c, http.StatusOK, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.Connected = true
|
||||||
|
ca := picked.ConnectedAt
|
||||||
|
ls := picked.LastHeartbeat
|
||||||
|
resp.ConnectedAt = &ca
|
||||||
|
resp.LastSeenAt = &ls
|
||||||
|
resp.DeviceID = picked.DeviceID
|
||||||
|
// AgentVersion 雛形未從 tunnel 讀回;B5 會在 tunnel handshake 時收集
|
||||||
|
WriteSuccess(c, http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenPrefix 截前 8 字元用於 log,避免完整 token 進日誌。
|
||||||
|
//
|
||||||
|
// 在 pairing.go 中重複實作(不複用 relay package 的同名函式)以避免跨層循環依賴;
|
||||||
|
// 行為與 relay.tokenPrefix 一致。
|
||||||
|
func tokenPrefix(t string) string {
|
||||||
|
if len(t) <= 8 {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return t[:8]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// B5 新增:List / Revoke tokens
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// PairingTokenListItem 是 GET /api/pairing/tokens 回應中的單筆 token。
|
||||||
|
//
|
||||||
|
// **注意**:不回 Plaintext — 那只能在建立時給一次。
|
||||||
|
type PairingTokenListItem struct {
|
||||||
|
TokenPrefix string `json:"token_prefix"` // 前 12 字元,例:`vAc_7f3c8e2a`
|
||||||
|
Kind string `json:"kind"`
|
||||||
|
DeviceID string `json:"device_id,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||||
|
UsedAt *time.Time `json:"used_at,omitempty"`
|
||||||
|
RevokedAt *time.Time `json:"revoked_at,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// pairingListTokensHandler 實作 GET /api/pairing/tokens。
|
||||||
|
//
|
||||||
|
// 回當前 user 的所有 token(含已使用 / 撤銷 / 過期),供 UI 顯示。
|
||||||
|
// **絕對不回 Plaintext** — 前端已在建立時保留那份 plaintext(只此一次)。
|
||||||
|
func pairingListTokensHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.PairingStore == nil {
|
||||||
|
WriteNotImplemented(c, "pairing store not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 0.7 security fix M1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 3*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
tokens, err := deps.PairingStore.List(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"list pairing tokens failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
out := make([]PairingTokenListItem, 0, len(tokens))
|
||||||
|
for _, t := range tokens {
|
||||||
|
item := PairingTokenListItem{
|
||||||
|
TokenPrefix: tokenPrefix12(t.Plaintext),
|
||||||
|
Kind: string(t.Kind),
|
||||||
|
DeviceID: t.DeviceID,
|
||||||
|
CreatedAt: t.CreatedAt,
|
||||||
|
ExpiresAt: t.ExpiresAt,
|
||||||
|
UsedAt: t.UsedAt,
|
||||||
|
RevokedAt: t.RevokedAt,
|
||||||
|
}
|
||||||
|
out = append(out, item)
|
||||||
|
}
|
||||||
|
WriteSuccess(c, http.StatusOK, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pairingRevokeTokenHandler 實作 DELETE /api/pairing/tokens/:token。
|
||||||
|
//
|
||||||
|
// 雛形:直接收 plaintext token 作為 path param(path 可被日誌記錄 — 雛形容忍;
|
||||||
|
// Phase 1 會改為回「token id」而非 plaintext,路徑不洩漏原文)。
|
||||||
|
//
|
||||||
|
// Revoke 成功回 204 No Content;不存在回 404 + NOT_FOUND。
|
||||||
|
func pairingRevokeTokenHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
if deps.PairingStore == nil {
|
||||||
|
WriteNotImplemented(c, "pairing store not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
token := c.Param("token")
|
||||||
|
if token == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"token param required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := deps.PairingStore.Revoke(ctx, token); err != nil {
|
||||||
|
if errors.Is(err, auth.ErrInvalidToken) {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "token not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"revoke token failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logOrDefault(deps.Logger).Info("pairing: token revoked",
|
||||||
|
"token_prefix", tokenPrefix(token),
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenPrefix12 截前 12 字元(`vAc_` + 8 hex),用於 list 顯示。
|
||||||
|
func tokenPrefix12(t string) string {
|
||||||
|
if len(t) <= 12 {
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
return t[:12]
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// AB11 新增:POST /api/pairing/exchange(public — 不走 AuthMiddleware)
|
||||||
|
// ==========================================================================
|
||||||
|
//
|
||||||
|
// 行為對齊 visiona-agent-tdd.md §4.3 + security.md §1.2:
|
||||||
|
// 1. agent 送 Pairing Token 過來
|
||||||
|
// 2. 雲端驗證(存在 / 未過期 / 未使用 / 未撤銷)
|
||||||
|
// 3. 產生 Session Token(vAs_ + 64 hex,90 天 TTL)
|
||||||
|
// 4. 把 Pairing Token 標為 used(一次性,無法再交換)
|
||||||
|
// 5. 回 { session_token, account, relay_url, expires_at }
|
||||||
|
//
|
||||||
|
// 雛形取捨:
|
||||||
|
// - remote-proxy 端目前只做 token 格式驗證(relay/server.go isAcceptableToken),
|
||||||
|
// **不會**實際查 SessionTokenStore。這是對齊 TDD「選項 A」的雛形設計;
|
||||||
|
// Phase 1 要新增 remote-proxy → api-server 的 `/internal/session-token/:token` 驗證。
|
||||||
|
// - Rate limit / token rotation / 真實 DB 都留給 Phase 1。
|
||||||
|
|
||||||
|
// defaultRelayPublicURL 是 relay_url 的雛形 placeholder,當 Deps.RelayPublicURL
|
||||||
|
// 未設定時用此值(讓 agent 至少能收到一個格式正確的 URL,實機請透過
|
||||||
|
// VISIONA_RELAY_PUBLIC_URL 覆寫)。
|
||||||
|
const defaultRelayPublicURL = "wss://relay.visionA.cloud"
|
||||||
|
|
||||||
|
// registerPairingPublicRoutes 註冊**不需要 auth**的 pairing endpoints。
|
||||||
|
//
|
||||||
|
// 目前只有 /api/pairing/exchange — agent 拿 Pairing Token 換 Session Token 時
|
||||||
|
// 本身還沒有登入身份,故不能套 AuthMiddleware。
|
||||||
|
// 呼叫方:NewRouter() 在 engine 層級直接註冊(不加 apiGroup)。
|
||||||
|
func registerPairingPublicRoutes(r gin.IRouter, deps Deps) {
|
||||||
|
r.POST("/api/pairing/exchange", pairingExchangeHandler(deps))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PairingExchangeRequest 是 POST /api/pairing/exchange 的 request body。
|
||||||
|
type PairingExchangeRequest struct {
|
||||||
|
PairingToken string `json:"pairing_token" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PairingExchangeResponse 是 POST /api/pairing/exchange 成功時的 data payload。
|
||||||
|
//
|
||||||
|
// 欄位對齊 visiona-agent-tdd.md §7.1:
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "session_token": "vAs_...",
|
||||||
|
// "account": "demo@visionA.local",
|
||||||
|
// "relay_url": "wss://relay.visionA.cloud",
|
||||||
|
// "expires_at": "2026-07-21T00:00:00Z"
|
||||||
|
// }
|
||||||
|
type PairingExchangeResponse struct {
|
||||||
|
SessionToken string `json:"session_token"`
|
||||||
|
Account string `json:"account"`
|
||||||
|
RelayURL string `json:"relay_url"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pairing exchange 專用錯誤碼(對齊 TDD §7.1 四種 case)。
|
||||||
|
//
|
||||||
|
// 注意:這些 code 走 ErrorBody.Error.Code 欄位,不是 envelope top-level code。
|
||||||
|
// 刻意與 security.md §1.2 的語意一致;前端 agent 可以 switch 對應 UI 文案。
|
||||||
|
const (
|
||||||
|
ErrCodeInvalidPairingToken = "INVALID_PAIRING_TOKEN"
|
||||||
|
ErrCodePairingTokenExpired = "PAIRING_TOKEN_EXPIRED"
|
||||||
|
ErrCodePairingTokenUsed = "PAIRING_TOKEN_USED"
|
||||||
|
ErrCodePairingTokenRevoked = "PAIRING_TOKEN_REVOKED"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pairingExchangeHandler 實作 Pairing → Session Token 交換。
|
||||||
|
//
|
||||||
|
// 失敗回應:
|
||||||
|
// - 400 VALIDATION_FAILED — body 缺 pairing_token
|
||||||
|
// - 401 INVALID_PAIRING_TOKEN — 格式錯 / 不存在
|
||||||
|
// - 401 PAIRING_TOKEN_EXPIRED — 過期
|
||||||
|
// - 401 PAIRING_TOKEN_USED — 已交換過
|
||||||
|
// - 401 PAIRING_TOKEN_REVOKED — 已撤銷
|
||||||
|
// - 500 INTERNAL_ERROR — Session Token 產生失敗
|
||||||
|
// - 501 NOT_IMPLEMENTED — store 未注入
|
||||||
|
//
|
||||||
|
// 安全提醒:回應內**絕對不包含** Pairing Token 原文;log 只印 prefix。
|
||||||
|
func pairingExchangeHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 雛形:store 任一缺失 → 501(避免 nil pointer)
|
||||||
|
if deps.PairingStore == nil || deps.SessionTokenStore == nil {
|
||||||
|
WriteNotImplemented(c, "pairing exchange store not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse body
|
||||||
|
var req PairingExchangeRequest
|
||||||
|
if err := c.ShouldBindJSON(&req); err != nil {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed,
|
||||||
|
"pairing_token is required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式驗證 — 不合格直接 401 INVALID_PAIRING_TOKEN(避免把格式錯誤漏進 store)
|
||||||
|
if !auth.IsValidPairingToken(req.PairingToken) {
|
||||||
|
logOrDefault(deps.Logger).Warn("pairing exchange: invalid token format",
|
||||||
|
"token_prefix", tokenPrefix(req.PairingToken),
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusUnauthorized, ErrCodeInvalidPairingToken,
|
||||||
|
"pairing token format invalid", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Validate — store 會判斷存在 / 未過期 / 未使用 / 未撤銷
|
||||||
|
info, err := deps.PairingStore.Validate(ctx, req.PairingToken)
|
||||||
|
if err != nil {
|
||||||
|
code, msg := mapPairingExchangeError(err)
|
||||||
|
logOrDefault(deps.Logger).Warn("pairing exchange: validate failed",
|
||||||
|
"code", code,
|
||||||
|
"token_prefix", tokenPrefix(req.PairingToken),
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusUnauthorized, code, msg, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate Session Token
|
||||||
|
plaintext, sessionInfo, err := deps.SessionTokenStore.Create(
|
||||||
|
ctx,
|
||||||
|
info.UserID,
|
||||||
|
info.DeviceID, // Pairing Token 雛形還沒綁 device_id,為空沒關係
|
||||||
|
info.TokenHash,
|
||||||
|
auth.SessionTokenTTL,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logOrDefault(deps.Logger).Error("pairing exchange: create session token failed",
|
||||||
|
"error", err,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"failed to create session token", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark pairing token as used。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security fix M3 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
// MarkUsed 失敗代表「pairing token 一次性」保證可能被破壞 — 同一個 token
|
||||||
|
// 可能再被 exchange 一次。改為 abort:撤銷剛產生的 session token、回 500,
|
||||||
|
// 而不是 silent log warn 繼續往前。
|
||||||
|
//
|
||||||
|
// 注意:deviceID 沿用 info.DeviceID(可能為空)。雛形 MarkUsed 對空字串
|
||||||
|
// 是安全的(它只是覆寫欄位)。
|
||||||
|
if err := deps.PairingStore.MarkUsed(ctx, req.PairingToken, info.DeviceID); err != nil {
|
||||||
|
// 嘗試 revoke 剛產生的 session token;revoke 自身失敗不再 retry,只 log。
|
||||||
|
revokeErr := deps.SessionTokenStore.Revoke(ctx, plaintext)
|
||||||
|
logOrDefault(deps.Logger).Error("pairing exchange: mark used failed; aborted",
|
||||||
|
"error", err,
|
||||||
|
"revoke_err", revokeErr,
|
||||||
|
"token_prefix", tokenPrefix(req.PairingToken),
|
||||||
|
"session_token_prefix", tokenPrefix(plaintext),
|
||||||
|
"device_id", info.DeviceID,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"pairing token mark-used failed; aborted", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 組 response
|
||||||
|
relayURL := deps.RelayPublicURL
|
||||||
|
if relayURL == "" {
|
||||||
|
relayURL = defaultRelayPublicURL
|
||||||
|
}
|
||||||
|
|
||||||
|
var expires time.Time
|
||||||
|
if sessionInfo.ExpiresAt != nil {
|
||||||
|
expires = *sessionInfo.ExpiresAt
|
||||||
|
} else {
|
||||||
|
expires = time.Now().UTC().Add(auth.SessionTokenTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 雛形 account — OIDC sub 通常是 UUID 不適合給人看,這裡用 userID + 固定 suffix
|
||||||
|
// 當 placeholder。Phase 1 接 DB 後可改回 session.Email(OIDC 已帶)。
|
||||||
|
account := info.UserID + "@visionA.local"
|
||||||
|
|
||||||
|
// 故意只 log token prefix,避免完整 session token 進日誌
|
||||||
|
logOrDefault(deps.Logger).Info("pairing exchange: success",
|
||||||
|
"user_id", info.UserID,
|
||||||
|
"device_id", info.DeviceID,
|
||||||
|
"session_token_prefix", tokenPrefix(plaintext),
|
||||||
|
"pairing_token_prefix", tokenPrefix(req.PairingToken),
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
WriteSuccess(c, http.StatusOK, PairingExchangeResponse{
|
||||||
|
SessionToken: plaintext,
|
||||||
|
Account: account,
|
||||||
|
RelayURL: relayURL,
|
||||||
|
ExpiresAt: expires,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapPairingExchangeError 把 PairingStore.Validate 回傳的 sentinel error 轉成
|
||||||
|
// 對應的 pairing exchange error code + 對使用者可見的訊息。
|
||||||
|
//
|
||||||
|
// 未匹配時 fallback 到 INVALID_PAIRING_TOKEN(避免洩漏內部錯誤細節)。
|
||||||
|
func mapPairingExchangeError(err error) (code, message string) {
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, auth.ErrTokenExpired):
|
||||||
|
return ErrCodePairingTokenExpired, "pairing token expired"
|
||||||
|
case errors.Is(err, auth.ErrTokenUsed):
|
||||||
|
return ErrCodePairingTokenUsed, "pairing token already used"
|
||||||
|
case errors.Is(err, auth.ErrTokenRevoked):
|
||||||
|
return ErrCodePairingTokenRevoked, "pairing token revoked"
|
||||||
|
case errors.Is(err, auth.ErrInvalidToken):
|
||||||
|
return ErrCodeInvalidPairingToken, "pairing token invalid"
|
||||||
|
default:
|
||||||
|
return ErrCodeInvalidPairingToken, "pairing token invalid"
|
||||||
|
}
|
||||||
|
}
|
||||||
323
visionA-backend/internal/api/pairing_test.go
Normal file
323
visionA-backend/internal/api/pairing_test.go
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestPairingCreateToken_OK 驗證能成功建 pairing token,且回傳格式合法。
|
||||||
|
func TestPairingCreateToken_OK(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(injectStaticUserContext("demo-user", ""))
|
||||||
|
g := r.Group("/api")
|
||||||
|
// Phase 0.7 security fix C1:移除 Deps.StaticUserID,改由 injectStaticUserContext 顯式注入。
|
||||||
|
registerPairingRoutes(g, Deps{
|
||||||
|
Logger: nil,
|
||||||
|
PairingStore: auth.NewInMemoryPairingStore(),
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/pairing/token", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code, "body: %s", w.Body.String())
|
||||||
|
|
||||||
|
var body SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||||
|
require.True(t, body.Success)
|
||||||
|
|
||||||
|
data := body.Data.(map[string]any)
|
||||||
|
tok, _ := data["token"].(string)
|
||||||
|
assert.True(t, strings.HasPrefix(tok, "vAc_"), "token 應為 pairing 格式:%s", tok)
|
||||||
|
assert.True(t, auth.IsValidPairingToken(tok), "token 應通過格式驗證")
|
||||||
|
assert.NotEmpty(t, data["expires_at"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingCreateToken_NoStore 驗證沒注入 PairingStore 時回 501。
|
||||||
|
func TestPairingCreateToken_NoStore(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerPairingRoutes(g, Deps{Logger: nil})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/pairing/token", nil))
|
||||||
|
assert.Equal(t, 501, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingStatus_NoSession 驗證沒 session 時回 connected=false。
|
||||||
|
func TestPairingStatus_NoSession(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerPairingRoutes(g, Deps{
|
||||||
|
SessionStore: &fakeSessionStore{},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/pairing/status", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var body SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||||
|
data := body.Data.(map[string]any)
|
||||||
|
assert.Equal(t, false, data["connected"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// AB11:POST /api/pairing/exchange 測試
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// setupExchangeRouter 建立一個只掛 exchange endpoint 的 minimal router。
|
||||||
|
//
|
||||||
|
// 重點:exchange **不走** AuthMiddleware,故不掛 AuthMiddleware。
|
||||||
|
// 這也反映了 production 的 NewRouter 實際行為(registerPairingPublicRoutes 在
|
||||||
|
// engine 層註冊,而不是 apiGroup)。
|
||||||
|
func setupExchangeRouter(t *testing.T, deps Deps) *gin.Engine {
|
||||||
|
t.Helper()
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
registerPairingPublicRoutes(r, deps)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// issuePairingToken 建一個合法 pairing token 供 exchange 測試用。
|
||||||
|
func issuePairingToken(t *testing.T, store auth.PairingStore, userID string, ttl time.Duration) string {
|
||||||
|
t.Helper()
|
||||||
|
plain, _, err := store.Create(context.Background(), userID, ttl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return plain
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_OK 驗證 happy path:拿合法 pairing token 換到 session token。
|
||||||
|
func TestPairingExchange_OK(t *testing.T) {
|
||||||
|
pairings := auth.NewInMemoryPairingStore()
|
||||||
|
sessions := auth.NewInMemorySessionTokenStore()
|
||||||
|
pairingTok := issuePairingToken(t, pairings, "demo-user", 15*time.Minute)
|
||||||
|
|
||||||
|
r := setupExchangeRouter(t, Deps{
|
||||||
|
PairingStore: pairings,
|
||||||
|
SessionTokenStore: sessions,
|
||||||
|
RelayPublicURL: "wss://relay.test.local",
|
||||||
|
})
|
||||||
|
|
||||||
|
body, _ := json.Marshal(PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code, "body: %s", w.Body.String())
|
||||||
|
|
||||||
|
var resp SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
require.True(t, resp.Success)
|
||||||
|
|
||||||
|
data := resp.Data.(map[string]any)
|
||||||
|
sessTok, _ := data["session_token"].(string)
|
||||||
|
assert.True(t, auth.IsValidSessionToken(sessTok), "session_token 應為合法 vAs_ 格式:%s", sessTok)
|
||||||
|
assert.Equal(t, "wss://relay.test.local", data["relay_url"])
|
||||||
|
assert.Equal(t, "demo-user@visionA.local", data["account"])
|
||||||
|
assert.NotEmpty(t, data["expires_at"])
|
||||||
|
|
||||||
|
// Session token 應能從 store 查到
|
||||||
|
_, err := sessions.Get(context.Background(), sessTok)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Pairing token 應被標為 used;再 exchange 一次應該失敗
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
req2 := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req2.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w2, req2)
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w2.Code)
|
||||||
|
assert.Contains(t, w2.Body.String(), ErrCodePairingTokenUsed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_InvalidFormat 驗證格式錯的 token 回 401 INVALID_PAIRING_TOKEN。
|
||||||
|
func TestPairingExchange_InvalidFormat(t *testing.T) {
|
||||||
|
r := setupExchangeRouter(t, Deps{
|
||||||
|
PairingStore: auth.NewInMemoryPairingStore(),
|
||||||
|
SessionTokenStore: auth.NewInMemorySessionTokenStore(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 格式錯(缺前綴)
|
||||||
|
body, _ := json.Marshal(PairingExchangeRequest{PairingToken: "not-a-real-token"})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeInvalidPairingToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_MissingField 驗證 body 沒 pairing_token 回 400 VALIDATION_FAILED。
|
||||||
|
func TestPairingExchange_MissingField(t *testing.T) {
|
||||||
|
r := setupExchangeRouter(t, Deps{
|
||||||
|
PairingStore: auth.NewInMemoryPairingStore(),
|
||||||
|
SessionTokenStore: auth.NewInMemorySessionTokenStore(),
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange",
|
||||||
|
strings.NewReader(`{}`))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeValidationFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_Unknown 驗證合法格式但 store 找不到的 token 回 401 INVALID_PAIRING_TOKEN。
|
||||||
|
func TestPairingExchange_Unknown(t *testing.T) {
|
||||||
|
r := setupExchangeRouter(t, Deps{
|
||||||
|
PairingStore: auth.NewInMemoryPairingStore(),
|
||||||
|
SessionTokenStore: auth.NewInMemorySessionTokenStore(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 格式合法但 store 沒存過
|
||||||
|
unknown, err := auth.GeneratePairingToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(PairingExchangeRequest{PairingToken: unknown})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeInvalidPairingToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_Expired 驗證過期 token 回 401 PAIRING_TOKEN_EXPIRED。
|
||||||
|
func TestPairingExchange_Expired(t *testing.T) {
|
||||||
|
pairings := auth.NewInMemoryPairingStore()
|
||||||
|
// TTL 1ns → 幾乎立刻過期
|
||||||
|
pairingTok := issuePairingToken(t, pairings, "demo-user", 1*time.Nanosecond)
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
r := setupExchangeRouter(t, Deps{
|
||||||
|
PairingStore: pairings,
|
||||||
|
SessionTokenStore: auth.NewInMemorySessionTokenStore(),
|
||||||
|
})
|
||||||
|
|
||||||
|
body, _ := json.Marshal(PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodePairingTokenExpired)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_Revoked 驗證撤銷 token 回 401 PAIRING_TOKEN_REVOKED。
|
||||||
|
func TestPairingExchange_Revoked(t *testing.T) {
|
||||||
|
pairings := auth.NewInMemoryPairingStore()
|
||||||
|
pairingTok := issuePairingToken(t, pairings, "demo-user", 15*time.Minute)
|
||||||
|
require.NoError(t, pairings.Revoke(context.Background(), pairingTok))
|
||||||
|
|
||||||
|
r := setupExchangeRouter(t, Deps{
|
||||||
|
PairingStore: pairings,
|
||||||
|
SessionTokenStore: auth.NewInMemorySessionTokenStore(),
|
||||||
|
})
|
||||||
|
|
||||||
|
body, _ := json.Marshal(PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodePairingTokenRevoked)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_NoStore 驗證 SessionTokenStore / PairingStore 缺失時回 501。
|
||||||
|
func TestPairingExchange_NoStore(t *testing.T) {
|
||||||
|
r := setupExchangeRouter(t, Deps{}) // 兩個 store 都 nil
|
||||||
|
|
||||||
|
body, _ := json.Marshal(PairingExchangeRequest{PairingToken: "vAc_" + strings.Repeat("0", 32)})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
assert.Equal(t, 501, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingExchange_DefaultRelayURL 驗證沒設 RelayPublicURL 時會 fallback 到 placeholder。
|
||||||
|
func TestPairingExchange_DefaultRelayURL(t *testing.T) {
|
||||||
|
pairings := auth.NewInMemoryPairingStore()
|
||||||
|
sessions := auth.NewInMemorySessionTokenStore()
|
||||||
|
pairingTok := issuePairingToken(t, pairings, "demo-user", 15*time.Minute)
|
||||||
|
|
||||||
|
r := setupExchangeRouter(t, Deps{
|
||||||
|
PairingStore: pairings,
|
||||||
|
SessionTokenStore: sessions,
|
||||||
|
// RelayPublicURL 刻意留空
|
||||||
|
})
|
||||||
|
|
||||||
|
body, _ := json.Marshal(PairingExchangeRequest{PairingToken: pairingTok})
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/pairing/exchange", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
require.Equal(t, http.StatusOK, w.Code, "body: %s", w.Body.String())
|
||||||
|
|
||||||
|
var resp SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
|
||||||
|
data := resp.Data.(map[string]any)
|
||||||
|
assert.Equal(t, defaultRelayPublicURL, data["relay_url"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPairingStatus_WithSession 驗證有 session 時回 connected=true + 對應欄位。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security fix M2:pairingStatusHandler 已改為 strict equality,
|
||||||
|
// 必須顯式注入 UserContext 才能拿到匹配 session(不再走「空 UserID 視為 match」捷徑)。
|
||||||
|
func TestPairingStatus_WithSession(t *testing.T) {
|
||||||
|
now := time.Now().UTC().Truncate(time.Second)
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(injectStaticUserContext("demo-user", ""))
|
||||||
|
g := r.Group("/api")
|
||||||
|
registerPairingRoutes(g, Deps{
|
||||||
|
SessionStore: &fakeSessionStore{
|
||||||
|
sessions: []*session.Summary{
|
||||||
|
{
|
||||||
|
Token: "vAc_a",
|
||||||
|
UserID: "demo-user",
|
||||||
|
DeviceID: "dev-1",
|
||||||
|
ConnectedAt: now.Add(-1 * time.Hour),
|
||||||
|
LastHeartbeat: now,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodGet, "/api/pairing/status", nil))
|
||||||
|
require.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
var body SuccessBody
|
||||||
|
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body))
|
||||||
|
data := body.Data.(map[string]any)
|
||||||
|
assert.Equal(t, true, data["connected"])
|
||||||
|
assert.Equal(t, "dev-1", data["device_id"])
|
||||||
|
assert.NotEmpty(t, data["connected_at"])
|
||||||
|
assert.NotEmpty(t, data["last_seen_at"])
|
||||||
|
}
|
||||||
255
visionA-backend/internal/api/proxy.go
Normal file
255
visionA-backend/internal/api/proxy.go
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
// proxy.go — 「把 gin 請求轉發到 local agent」的共用邏輯。
|
||||||
|
//
|
||||||
|
// 大量 device / camera / media / model load-to-device endpoint 都會走同一條路徑:
|
||||||
|
// 1. 從 UserContext 拿到當前使用者
|
||||||
|
// 2. 透過 SessionStore / ProxyClient 找到該使用者的 active session token
|
||||||
|
// 3. 用 Forwarder.ForwardHTTP 代理請求(body / headers / path 原樣送)
|
||||||
|
// 4. 把 response 原樣寫回 gin.ResponseWriter(支援 streaming)
|
||||||
|
//
|
||||||
|
// 把這段抽成 handler 產生器,讓 devices.go / camera.go 等只需宣告路徑即可。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultProxyRequestTimeout 是「非 streaming 型」proxy 請求的整體 timeout。
|
||||||
|
//
|
||||||
|
// 對 streaming 端點(MJPEG / SSE)不套用此 timeout — 我們靠 gin 的 ctx 取消機制
|
||||||
|
// 在 browser 關閉時順帶關 conn。300s 對 scan / flash(可能很慢)夠寬鬆。
|
||||||
|
const defaultProxyRequestTimeout = 300 * time.Second
|
||||||
|
|
||||||
|
// proxyOptions 控制 proxy handler 的細部行為。
|
||||||
|
type proxyOptions struct {
|
||||||
|
// streaming 若為 true 代表 response body 可能是長連線(MJPEG / SSE);
|
||||||
|
// 這種情況下我們不套 timeout、並對 gin.Writer.Flush 啟用 chunk 推送。
|
||||||
|
streaming bool
|
||||||
|
|
||||||
|
// rewritePath 可選:若非空,就把請求 path 改寫成這個值再送到 local agent。
|
||||||
|
// 雛形大多不需要(api-server 的路徑與 local agent 的路徑一致)。
|
||||||
|
rewritePath string
|
||||||
|
}
|
||||||
|
|
||||||
|
// newProxyHandler 產生一個 gin.HandlerFunc,會把當前請求透過 Forwarder 轉發到
|
||||||
|
// local agent(由 UserContext 對應的 active session 決定)。
|
||||||
|
//
|
||||||
|
// 用法:
|
||||||
|
//
|
||||||
|
// g.GET("/devices", newProxyHandler(deps, proxyOptions{}))
|
||||||
|
// g.GET("/camera/stream", newProxyHandler(deps, proxyOptions{streaming: true}))
|
||||||
|
func newProxyHandler(deps Deps, opts proxyOptions) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// 1. 檢查必要依賴
|
||||||
|
if deps.Forwarder == nil || deps.SessionStore == nil {
|
||||||
|
WriteNotImplemented(c, "forwarder/session store not configured")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 找當前使用者的 active session token
|
||||||
|
// Phase 0.7 security fix C1 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
// 移除 demo-user fallback:apiGroup 下所有 handler 都被 AuthMiddleware 保護,
|
||||||
|
// 拿不到 UserContext 代表 middleware 設定錯誤,回 500 比 silent fallback 安全。
|
||||||
|
uc, ok := UserContextFrom(c)
|
||||||
|
if !ok || uc.UserID == "" {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"missing user context (auth middleware misconfigured?)", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userID := uc.UserID
|
||||||
|
|
||||||
|
token, err := pickActiveSessionToken(c.Request.Context(), deps.SessionStore, userID, deps.Logger)
|
||||||
|
if err != nil {
|
||||||
|
writeTunnelError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 決定 rewrite path(可選)
|
||||||
|
outPath := c.Request.URL.Path
|
||||||
|
if opts.rewritePath != "" {
|
||||||
|
outPath = opts.rewritePath
|
||||||
|
}
|
||||||
|
if c.Request.URL.RawQuery != "" {
|
||||||
|
outPath += "?" + c.Request.URL.RawQuery
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 組出「打給 local agent」的 http.Request
|
||||||
|
ctx := c.Request.Context()
|
||||||
|
if !opts.streaming {
|
||||||
|
// 對非 streaming 端點加個總 timeout,避免 local agent hang 住
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithTimeout(ctx, defaultProxyRequestTimeout)
|
||||||
|
defer cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
outReq, err := http.NewRequestWithContext(ctx, c.Request.Method, outPath, c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"proxy: build upstream request: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 複製 headers;過濾掉 hop-by-hop(Forwarder 不會動,但避免重複)
|
||||||
|
copyProxyRequestHeaders(c.Request.Header, outReq.Header)
|
||||||
|
// Content-Length 要保留
|
||||||
|
if cl := c.Request.ContentLength; cl > 0 {
|
||||||
|
outReq.ContentLength = cl
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 呼叫 Forwarder
|
||||||
|
resp, err := deps.Forwarder.ForwardHTTP(ctx, token, outReq)
|
||||||
|
if err != nil {
|
||||||
|
writeTunnelError(c, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
// 6. 把 response 寫回 gin.Writer
|
||||||
|
writeProxyResponse(c, resp, opts.streaming)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pickActiveSessionToken 找出當前使用者在雲端的 active session token。
|
||||||
|
//
|
||||||
|
// 雛形邏輯(單一 user + 單一 agent):走 Store.List,過濾 userID 對得上的第一筆。
|
||||||
|
// OIDC 模式下 userID 是 Member Center 簽出的 OIDC sub(UUID),tunnel session 在
|
||||||
|
// pairing exchange 時被綁到同個 sub,因此能對上。
|
||||||
|
//
|
||||||
|
// 多 user / 多 device 階段(Phase 1)需要 store.ListByUser(userID) 原生介面,
|
||||||
|
// 見 session.Store TODO。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security audit M2 (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)
|
||||||
|
// **保留寬鬆比對待人工介入修復**:
|
||||||
|
// - 完整修法是「s.UserID != "" && s.UserID == userID」strict equality
|
||||||
|
// - 但 prototype 的 relay.NewLocalHandle (internal/relay/local_handle.go:31)
|
||||||
|
// 在 tunnel handshake 時不查 SessionTokenStore,所以 Summary.UserID 永遠為空
|
||||||
|
// - 改 strict 會讓所有 e2e proxy 鏈路全斷(TestE2E_FullFlow_PairingToForward 等)
|
||||||
|
// - 正解需 relay 端在 HandleTunnelConnect 時拿 token 查 SessionTokenStore
|
||||||
|
// 取得 user_id 並寫入 LocalHandle.summary.UserID(屬 Phase 1 follow-up)
|
||||||
|
//
|
||||||
|
// 暫保留寬鬆比對;C1/M1 handler-side strict UserContext 已優先處理 — 任何 request
|
||||||
|
// 進入此函式時 userID 必非空(handler 在前面已 abort 500),所以唯一仍寬鬆的條件是
|
||||||
|
// s.UserID == ""(relay-side 尚未 backfill)。
|
||||||
|
//
|
||||||
|
// logger 參數保留給未來觀測(list 失敗時 log warn),目前尚未使用;測試傳 nil 即可。
|
||||||
|
func pickActiveSessionToken(ctx context.Context, store session.Store, userID string, _ any) (string, error) {
|
||||||
|
listCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
summaries, err := store.List(listCtx)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("proxy: list sessions: %w", err)
|
||||||
|
}
|
||||||
|
if len(summaries) == 0 {
|
||||||
|
return "", session.ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range summaries {
|
||||||
|
// 寬鬆比對:handler 已確保 userID 非空(C1 strict mode);
|
||||||
|
// 暫接受 s.UserID == "" 直到 relay 端 backfill UserID(M2 待人工介入)。
|
||||||
|
if s.UserID == "" || s.UserID == userID {
|
||||||
|
return s.Token, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", session.ErrSessionNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeTunnelError 把 forwarder / store 的錯誤映射到統一的 API 錯誤格式。
|
||||||
|
//
|
||||||
|
// - ErrSessionNotFound / ErrSessionClosed → 502 TUNNEL_DISCONNECTED
|
||||||
|
// - 其他 → 502 TUNNEL_ERROR(本質上是 local agent 不可達)
|
||||||
|
func writeTunnelError(c *gin.Context, err error) {
|
||||||
|
if errors.Is(err, session.ErrSessionNotFound) || errors.Is(err, session.ErrSessionClosed) {
|
||||||
|
WriteError(c, http.StatusBadGateway, ErrCodeTunnelDisconnect,
|
||||||
|
"local agent 未連線或 tunnel 斷開", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusBadGateway, ErrCodeTunnelError,
|
||||||
|
"tunnel error: "+err.Error(), nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// copyProxyRequestHeaders 把 src 的 headers 複製到 dst,但略過 hop-by-hop。
|
||||||
|
//
|
||||||
|
// 對齊 RFC 7230 §6.1 hop-by-hop headers:
|
||||||
|
//
|
||||||
|
// Connection, Keep-Alive, Proxy-Authenticate, Proxy-Authorization,
|
||||||
|
// TE, Trailers, Transfer-Encoding, Upgrade
|
||||||
|
//
|
||||||
|
// 這些由 Forwarder / underlying conn 自動處理,不該 blind copy。
|
||||||
|
func copyProxyRequestHeaders(src, dst http.Header) {
|
||||||
|
for name, values := range src {
|
||||||
|
if isHopByHopHeader(name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Authorization header 雛形不必送(local agent 沒有對應的 auth 系統);
|
||||||
|
// 但保留其他 custom header(X-From-Api 等 test fixture 會用)
|
||||||
|
if strings.EqualFold(name, "Authorization") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, v := range values {
|
||||||
|
dst.Add(name, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isHopByHopHeader 回報 header 名稱是否為 hop-by-hop。
|
||||||
|
func isHopByHopHeader(name string) bool {
|
||||||
|
switch strings.ToLower(name) {
|
||||||
|
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
|
||||||
|
"te", "trailers", "transfer-encoding", "upgrade":
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeProxyResponse 把 upstream response 原樣寫回 gin。
|
||||||
|
//
|
||||||
|
// 支援 streaming:若 streaming=true 且 response 有 Flusher,每次 Read 後立即 Flush。
|
||||||
|
// 這讓 MJPEG / SSE 的 frame 能即時抵達 browser。
|
||||||
|
func writeProxyResponse(c *gin.Context, resp *http.Response, streaming bool) {
|
||||||
|
// 複製 headers(略過 hop-by-hop)
|
||||||
|
for name, values := range resp.Header {
|
||||||
|
if isHopByHopHeader(name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, v := range values {
|
||||||
|
c.Writer.Header().Add(name, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
if !streaming {
|
||||||
|
// 非 streaming:一口氣 copy 完
|
||||||
|
_, _ = io.Copy(c.Writer, resp.Body)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Streaming:邊讀邊 flush。buffer 大小 8KB,平衡延遲與 syscall 次數。
|
||||||
|
buf := make([]byte, 8*1024)
|
||||||
|
flusher, _ := c.Writer.(http.Flusher)
|
||||||
|
for {
|
||||||
|
n, rerr := resp.Body.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
if _, werr := c.Writer.Write(buf[:n]); werr != nil {
|
||||||
|
// browser 斷線 → 停止(conn 會在 resp.Body.Close 時關掉 upstream)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
// io.EOF 或連線結束都是正常
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
69
visionA-backend/internal/api/proxy_test.go
Normal file
69
visionA-backend/internal/api/proxy_test.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestNewProxyHandler_NoForwarder 驗證沒注入 Forwarder 時回 501。
|
||||||
|
func TestNewProxyHandler_NoForwarder(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
g := r.Group("/api")
|
||||||
|
g.POST("/devices/scan", newProxyHandler(Deps{}, proxyOptions{}))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/devices/scan", nil))
|
||||||
|
assert.Equal(t, http.StatusNotImplemented, w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewProxyHandler_TunnelDisconnected 驗證沒 session 時回 502 TUNNEL_DISCONNECTED。
|
||||||
|
//
|
||||||
|
// 這裡用 fakeSessionStore(List 回空)+ 非 nil forwarder 的「半個」 proxy handler;
|
||||||
|
// 因為 nil forwarder 的 path 會先 return 501(見上方 test)。我們用真實 forwarder
|
||||||
|
// 但不 dial — 直接在 pickActiveSessionToken 回 ErrSessionNotFound 就攔掉了。
|
||||||
|
//
|
||||||
|
// Phase 0.7 security fix C1:handler 強制要求 UserContext;用 injectStaticUserContext
|
||||||
|
// 顯式注入避免回 500(見 .autoflow/05-implementation/review/phase-0.7-security-audit.md)。
|
||||||
|
func TestNewProxyHandler_TunnelDisconnected(t *testing.T) {
|
||||||
|
r := gin.New()
|
||||||
|
r.Use(RequestIDMiddleware())
|
||||||
|
r.Use(injectStaticUserContext("demo-user", ""))
|
||||||
|
g := r.Group("/api")
|
||||||
|
g.POST("/devices/scan", newProxyHandler(Deps{
|
||||||
|
SessionStore: &fakeSessionStore{}, // List 回空
|
||||||
|
Forwarder: session.NewForwarder("http://localhost:0", nil),
|
||||||
|
}, proxyOptions{}))
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, httptest.NewRequest(http.MethodPost, "/api/devices/scan", nil))
|
||||||
|
assert.Equal(t, http.StatusBadGateway, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), ErrCodeTunnelDisconnect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPickActiveSessionToken 驗證 helper 回第一筆 match 的 session。
|
||||||
|
func TestPickActiveSessionToken(t *testing.T) {
|
||||||
|
store := &fakeSessionStore{
|
||||||
|
sessions: []*session.Summary{
|
||||||
|
{Token: "t-other", UserID: "other"},
|
||||||
|
{Token: "t-me", UserID: "demo-user"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tok, err := pickActiveSessionToken(context.Background(), store, "demo-user", nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "t-me", tok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPickActiveSessionToken_Empty 驗證沒 session 時回 ErrSessionNotFound。
|
||||||
|
func TestPickActiveSessionToken_Empty(t *testing.T) {
|
||||||
|
store := &fakeSessionStore{}
|
||||||
|
_, err := pickActiveSessionToken(context.Background(), store, "demo-user", nil)
|
||||||
|
assert.ErrorIs(t, err, session.ErrSessionNotFound)
|
||||||
|
}
|
||||||
134
visionA-backend/internal/api/storage.go
Normal file
134
visionA-backend/internal/api/storage.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
// storage.go — /storage/* 的假 presigned URL 代理(雛形 LocalFS 用)。
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// - 前端拿到 /api/models/init 回來的 upload_url(例:http://localhost:3721/storage/models/xxx.nef?expires=...&signature=...)
|
||||||
|
// - 直接對該 URL 發 PUT(body = 檔案內容)
|
||||||
|
// - 此 handler 驗簽 → 呼叫 Storage.Put 寫入 LocalFS
|
||||||
|
//
|
||||||
|
// GET 路徑對稱:驗簽 → Storage.Get → 串流回瀏覽器
|
||||||
|
//
|
||||||
|
// Phase 1 切換成 S3 後,整個 /storage/* handler 就可刪除
|
||||||
|
// (前端直接 PUT 到 S3 presigned URL,不經過 api-server)。
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerStorageRoutes 註冊 /storage/* proxy。**不在 /api/ 底下**,對齊 api-spec.md §10。
|
||||||
|
//
|
||||||
|
// 由 NewRouter 呼叫(不透過 AuthMiddleware — 因為已經用 HMAC 簽章控管存取)。
|
||||||
|
func registerStorageRoutes(r *gin.Engine, deps Deps) {
|
||||||
|
if deps.Storage == nil {
|
||||||
|
// 沒 storage 時就不註冊這條路由
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.GET("/storage/*filepath", storageGetHandler(deps))
|
||||||
|
r.PUT("/storage/*filepath", storagePutHandler(deps))
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyStorageSignature 從 query 抽 expires / signature 並呼叫 LocalFSStore.VerifySignature。
|
||||||
|
//
|
||||||
|
// Storage interface 本身沒有 VerifySignature 方法(那是 LocalFS 專用),
|
||||||
|
// 所以這裡用 type assertion 抓到 *LocalFSStore 再驗。
|
||||||
|
// Phase 1 S3 的 presigned URL 驗證由 S3 自己處理 — api-server 不會收到這些請求。
|
||||||
|
func verifyStorageSignature(c *gin.Context, deps Deps, method, key string) error {
|
||||||
|
ls, ok := deps.Storage.(*storage.LocalFSStore)
|
||||||
|
if !ok {
|
||||||
|
return storage.ErrInvalidSignature // 非 LocalFS 不應走這條 endpoint
|
||||||
|
}
|
||||||
|
expiresStr := c.Query("expires")
|
||||||
|
sig := c.Query("signature")
|
||||||
|
if expiresStr == "" || sig == "" {
|
||||||
|
return storage.ErrInvalidSignature
|
||||||
|
}
|
||||||
|
expires, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return storage.ErrInvalidSignature
|
||||||
|
}
|
||||||
|
return ls.VerifySignature(method, key, expires, sig)
|
||||||
|
}
|
||||||
|
|
||||||
|
// storageKeyFromPath 把 /storage/*filepath 的 filepath 截出來(gin 會帶前導 "/")。
|
||||||
|
func storageKeyFromPath(p string) string {
|
||||||
|
return strings.TrimPrefix(p, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// storageGetHandler 實作 GET /storage/*filepath。
|
||||||
|
//
|
||||||
|
// 驗簽 → Stat 取 size / mtime → Get 串流。
|
||||||
|
// 對 streaming-friendly:用 io.Copy 直接寫入 ResponseWriter。
|
||||||
|
func storageGetHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
key := storageKeyFromPath(c.Param("filepath"))
|
||||||
|
if key == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed, "empty key", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := verifyStorageSignature(c, deps, "GET", key); err != nil {
|
||||||
|
WriteError(c, http.StatusForbidden, ErrCodeInvalidSignature,
|
||||||
|
"invalid or expired signature", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reader, obj, err := deps.Storage.Get(c.Request.Context(), key)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, storage.ErrNotFound) {
|
||||||
|
WriteError(c, http.StatusNotFound, ErrCodeNotFound, "object not found", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"get storage failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
c.Writer.Header().Set("Content-Type", obj.ContentType)
|
||||||
|
c.Writer.Header().Set("Content-Length", strconv.FormatInt(obj.Size, 10))
|
||||||
|
c.Writer.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = io.Copy(c.Writer, reader)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// storagePutHandler 實作 PUT /storage/*filepath。
|
||||||
|
//
|
||||||
|
// 驗簽 → 讀 body → Storage.Put。
|
||||||
|
//
|
||||||
|
// 請求大小限制:雛形不在此強制(前端已經在 /api/models/init 被擋過 MaxUploadSizeMB);
|
||||||
|
// 若要守第二道,可在此檢查 c.Request.ContentLength。
|
||||||
|
func storagePutHandler(deps Deps) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
key := storageKeyFromPath(c.Param("filepath"))
|
||||||
|
if key == "" {
|
||||||
|
WriteError(c, http.StatusBadRequest, ErrCodeValidationFailed, "empty key", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := verifyStorageSignature(c, deps, "PUT", key); err != nil {
|
||||||
|
WriteError(c, http.StatusForbidden, ErrCodeInvalidSignature,
|
||||||
|
"invalid or expired signature", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 寫入 storage
|
||||||
|
if err := deps.Storage.Put(c.Request.Context(), key, c.Request.Body, c.Request.ContentLength, nil); err != nil {
|
||||||
|
WriteError(c, http.StatusInternalServerError, ErrCodeInternalError,
|
||||||
|
"put storage failed: "+err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logOrDefault(deps.Logger).Info("storage: put",
|
||||||
|
"key", key,
|
||||||
|
"size", c.Request.ContentLength,
|
||||||
|
"request_id", RequestIDFrom(c))
|
||||||
|
|
||||||
|
c.Status(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
85
visionA-backend/internal/api/stubs.go
Normal file
85
visionA-backend/internal/api/stubs.go
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// registerStubRoutes 註冊 B5 尚未實作 / Phase 1 才處理的 endpoint,一律回 501 NOT_IMPLEMENTED。
|
||||||
|
//
|
||||||
|
// **為什麼只留這些**:Auth / Pairing 補齊 / Devices / Models / GET /clusters /
|
||||||
|
// system/deps / /storage 都在 B5 補實作(見 auth.go / devices.go / models.go /
|
||||||
|
// clusters.go / storage.go 各檔)。這裡只剩:
|
||||||
|
// - Cloud 裝置記錄(非 tunnel 的 CRUD,Phase 1)
|
||||||
|
// - Clusters 寫入類(Phase 1)
|
||||||
|
// - Camera / Media(走 tunnel proxy;B5 先不實作以避免過度擴張,B7 補)
|
||||||
|
// - Converter(Phase 1)
|
||||||
|
// - WebSocket endpoints(B7 TODO — 需要 Hijack + WS relay)
|
||||||
|
//
|
||||||
|
// 讓前端對錯誤路徑能拿到 501 而非 404,減少除錯成本。
|
||||||
|
func registerStubRoutes(g *gin.RouterGroup, _ Deps) {
|
||||||
|
stub := func(hint string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
WriteNotImplemented(c, hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Cloud-side device records(非 tunnel) ---
|
||||||
|
g.GET("/cloud/devices", stub("cloud.devices.list — pending Phase 1"))
|
||||||
|
g.POST("/cloud/devices/:id/rename", stub("cloud.devices.rename — pending Phase 1"))
|
||||||
|
|
||||||
|
// --- Models 其餘 ---
|
||||||
|
// load-to-device 已在 models.go 註冊為 501 stub(為了讓 /api/models/:id/load-to-device 路徑註冊完整)。
|
||||||
|
|
||||||
|
// --- Clusters 寫入類 ---
|
||||||
|
g.POST("/clusters", stub("clusters.create — pending Phase 1"))
|
||||||
|
g.GET("/clusters/:id", stub("clusters.get — pending Phase 1"))
|
||||||
|
g.DELETE("/clusters/:id", stub("clusters.delete — pending Phase 1"))
|
||||||
|
g.POST("/clusters/:id/devices", stub("clusters.add-device — pending Phase 1"))
|
||||||
|
g.DELETE("/clusters/:id/devices/:deviceId", stub("clusters.remove-device — pending Phase 1"))
|
||||||
|
g.PUT("/clusters/:id/devices/:deviceId/weight", stub("clusters.set-weight — pending Phase 1"))
|
||||||
|
g.POST("/clusters/:id/flash", stub("clusters.flash — pending Phase 1"))
|
||||||
|
g.POST("/clusters/:id/inference/start", stub("clusters.inference.start — pending Phase 1"))
|
||||||
|
g.POST("/clusters/:id/inference/stop", stub("clusters.inference.stop — pending Phase 1"))
|
||||||
|
|
||||||
|
// --- Camera / Media(B7 補;走 tunnel proxy) ---
|
||||||
|
g.GET("/camera/list", stub("camera.list via tunnel — pending B7"))
|
||||||
|
g.POST("/camera/start", stub("camera.start via tunnel — pending B7"))
|
||||||
|
g.POST("/camera/stop", stub("camera.stop via tunnel — pending B7"))
|
||||||
|
g.GET("/camera/stream", stub("camera.stream MJPEG via tunnel — pending B7"))
|
||||||
|
g.POST("/media/upload/image", stub("media.upload.image — pending B7"))
|
||||||
|
g.POST("/media/upload/video", stub("media.upload.video — pending B7"))
|
||||||
|
g.POST("/media/upload/batch-images", stub("media.upload.batch — pending B7"))
|
||||||
|
g.GET("/media/batch-images/:index", stub("media.batch.get — pending B7"))
|
||||||
|
g.POST("/media/seek", stub("media.seek — pending B7"))
|
||||||
|
|
||||||
|
// --- Converter(Phase 1) ---
|
||||||
|
g.POST("/converter/jobs", stub("converter.submit — pending Phase 1"))
|
||||||
|
g.GET("/converter/jobs", stub("converter.list — pending Phase 1"))
|
||||||
|
g.GET("/converter/jobs/:id", stub("converter.get — pending Phase 1"))
|
||||||
|
g.GET("/converter/jobs/:id/download", stub("converter.download — pending Phase 1"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// registerWebSocketStubs 註冊 /ws/* 的 stub。WebSocket proxy 在 B5 雛形不實作,
|
||||||
|
// 留 501 讓前端能收到明確錯誤,由 B7 補齊。
|
||||||
|
//
|
||||||
|
// 為什麼不做 WS proxy:實作 WS relay 需要在 api-server 端做 Hijack、雙向 io.Copy,
|
||||||
|
// 而且 Forwarder.ForwardWebSocket 尚未實作(見 forwarder.go §ForwardWebSocket)。
|
||||||
|
// 加這條路徑會顯著擴張 B5 範圍;按 prompt 指示先留 TODO。
|
||||||
|
//
|
||||||
|
// 注意:ws endpoint 在 /ws 而非 /api/ws,所以由 NewRouter 直接註冊而非 apiGroup。
|
||||||
|
func registerWebSocketStubs(r *gin.Engine) {
|
||||||
|
stub := func(hint string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
WriteNotImplemented(c, hint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 用 GET(WebSocket upgrade 的初始 HTTP request)
|
||||||
|
r.GET("/ws/devices/events", stub("ws.devices.events — pending B7"))
|
||||||
|
r.GET("/ws/devices/:id/flash-progress", stub("ws.flash-progress — pending B7"))
|
||||||
|
r.GET("/ws/devices/:id/inference", stub("ws.inference — pending B7"))
|
||||||
|
r.GET("/ws/server-logs", stub("ws.server-logs — pending B7"))
|
||||||
|
r.GET("/ws/system", stub("ws.system — pending B7"))
|
||||||
|
r.GET("/ws/clusters/:id/inference", stub("ws.clusters.inference — pending B7"))
|
||||||
|
r.GET("/ws/clusters/:id/flash-progress", stub("ws.clusters.flash — pending B7"))
|
||||||
|
r.GET("/ws/pairing/status", stub("ws.pairing.status — pending B7"))
|
||||||
|
}
|
||||||
33
visionA-backend/internal/api/test_helpers_test.go
Normal file
33
visionA-backend/internal/api/test_helpers_test.go
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
// test_helpers_test.go — internal/api 套件 unit test 共用 helper。
|
||||||
|
//
|
||||||
|
// OB5 起 AuthMiddleware 強制走 OIDC(cookie + SessionManager),
|
||||||
|
// 但很多既有 unit test 並不關心 auth 細節 — 它們關心的是「假設 user 已登入,
|
||||||
|
// 那該 handler 行為對不對」。為了讓這類測試不被 auth 細節拖累,
|
||||||
|
// 提供一個「跳過 AuthMiddleware、直接塞 UserContext」的 middleware shim。
|
||||||
|
//
|
||||||
|
// 完整的 OIDC 認證流程測試見:
|
||||||
|
// - oidc_auth_test.go(unit test,含 AuthMiddleware 行為)
|
||||||
|
// - cmd/api-server/oidc_e2e_test.go(end-to-end)
|
||||||
|
|
||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// injectStaticUserContext 是 unit test 用的 fake AuthMiddleware:
|
||||||
|
// 直接把指定的 userID / email 塞進 gin.Context,跳過 cookie / session 邏輯。
|
||||||
|
//
|
||||||
|
// 用途:測試 handler 在「假設 user 已登入」前提下的行為(list / create 等業務邏輯)。
|
||||||
|
// 不可用於:驗證 AuthMiddleware 自身行為 — 那要走真 OIDC flow。
|
||||||
|
func injectStaticUserContext(userID, email string) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
c.Set(ctxKeyUserContext, &auth.UserContext{
|
||||||
|
UserID: userID,
|
||||||
|
Email: email,
|
||||||
|
})
|
||||||
|
c.Next()
|
||||||
|
}
|
||||||
|
}
|
||||||
232
visionA-backend/internal/auth/auth.go
Normal file
232
visionA-backend/internal/auth/auth.go
Normal file
@ -0,0 +1,232 @@
|
|||||||
|
// Package auth 定義 visionA-backend 的雙層 Auth 介面(AuthService / AuthProvider)
|
||||||
|
// 以及 Pairing / Session Token 的型別與 Store。
|
||||||
|
//
|
||||||
|
// 介面設計對齊 TDD §2.2、security.md §2.0 與 PRD interface-contracts.md §8.2。
|
||||||
|
//
|
||||||
|
// 從 Phase 0.6(OB5)起,唯一的認證路徑是 OIDC + cookie session(見 internal/oidc/
|
||||||
|
// 與 internal/usersession/),因此本 package 不再提供 AuthProvider / AuthService 的
|
||||||
|
// 內建實作。介面本身仍保留,供未來新增備援 provider(例:Phase 1 接 backup local auth、
|
||||||
|
// service-to-service token)時直接套用,不必重新發明 contract。
|
||||||
|
//
|
||||||
|
// 本 package 仍負責提供:
|
||||||
|
// - UserContext 等領域型別
|
||||||
|
// - PairingToken / SessionToken 結構與其 Store 介面
|
||||||
|
// - PairingStore / SessionTokenStore 的 in-memory 實作
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// 錯誤型別(公開 sentinel errors,便於 caller 用 errors.Is 比對)
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotImplemented 用於雛形 stub,表示此功能 Phase 0 尚未實作。
|
||||||
|
ErrNotImplemented = errors.New("auth: not implemented in phase 0")
|
||||||
|
|
||||||
|
// ErrInvalidToken 表示 token 格式錯誤、過期、或不被此 provider 認識。
|
||||||
|
ErrInvalidToken = errors.New("auth: invalid token")
|
||||||
|
|
||||||
|
// ErrTokenExpired 表示 token 過了 ExpiresAt。
|
||||||
|
ErrTokenExpired = errors.New("auth: token expired")
|
||||||
|
|
||||||
|
// ErrTokenUsed 表示一次性 token(pairing)已經被消費。
|
||||||
|
ErrTokenUsed = errors.New("auth: token already used")
|
||||||
|
|
||||||
|
// ErrTokenRevoked 表示 token 已被使用者或系統撤銷。
|
||||||
|
ErrTokenRevoked = errors.New("auth: token revoked")
|
||||||
|
|
||||||
|
// ErrInvalidCredentials 表示 email / password 比對失敗(Phase 1 實作)。
|
||||||
|
ErrInvalidCredentials = errors.New("auth: invalid credentials")
|
||||||
|
|
||||||
|
// ErrUserNotFound 表示查詢的 user 不存在。
|
||||||
|
ErrUserNotFound = errors.New("auth: user not found")
|
||||||
|
|
||||||
|
// ErrUserAlreadyExists 表示註冊時 email 已存在(Phase 1 實作)。
|
||||||
|
ErrUserAlreadyExists = errors.New("auth: user already exists")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Domain types
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// User 是 Auth 系統觀察到的使用者。雛形固定為 demo-user;Phase 1 對接真實 DB。
|
||||||
|
//
|
||||||
|
// 註:完整 User struct 定義於 database.md §2.1;這裡保留 auth 層所需欄位即可。
|
||||||
|
type User struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserContext 是從 request 解析出來、後續 handler 可信賴的使用者身分資訊。
|
||||||
|
//
|
||||||
|
// Middleware 層(AuthService.Authenticate)負責產生此 context;
|
||||||
|
// Handler 不需知道 token 從哪來,只需讀 UserContext。
|
||||||
|
type UserContext struct {
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
Email string `json:"email,omitempty"`
|
||||||
|
Roles []string `json:"roles,omitempty"`
|
||||||
|
OrgID string `json:"orgId,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginRequest 是 Login 的輸入參數。
|
||||||
|
type LoginRequest struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginResult 是 Login 成功後回傳的完整資訊。
|
||||||
|
type LoginResult struct {
|
||||||
|
User *User `json:"user"`
|
||||||
|
AccessToken string `json:"accessToken"`
|
||||||
|
RefreshToken string `json:"refreshToken,omitempty"`
|
||||||
|
ExpiresAt time.Time `json:"expiresAt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterRequest 是 Register 的輸入參數。
|
||||||
|
type RegisterRequest struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Token types
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// TokenKind 區分 token 的生命週期類型。
|
||||||
|
type TokenKind string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// KindPairing 是短期一次性 token(15 min TTL),用於首次 agent 配對。
|
||||||
|
KindPairing TokenKind = "pairing"
|
||||||
|
|
||||||
|
// KindSession 是長期可撤銷 token(90 天 TTL),agent 升級後使用。
|
||||||
|
KindSession TokenKind = "session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PairingToken 代表一個已發行(尚未消費)的 pairing token 紀錄。
|
||||||
|
//
|
||||||
|
// 格式:vAc_ + 32 hex(共 36 字元);見 security.md §1.3。
|
||||||
|
// DB 僅存 TokenHash(sha256 plaintext),原文 token 僅在建立時回傳一次。
|
||||||
|
//
|
||||||
|
// 雛形 InMemoryPairingStore 存的是明文 token 作為 key,Phase 1 改為 hash。
|
||||||
|
type PairingToken struct {
|
||||||
|
// Plaintext 是原文 token(僅在建立時回傳給 caller;儲存層請改存 hash)。
|
||||||
|
Plaintext string `json:"-"`
|
||||||
|
// TokenHash 是 sha256(Plaintext) 的 hex 表示,DB 實際 PK。
|
||||||
|
TokenHash string `json:"-"`
|
||||||
|
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
DeviceID string `json:"deviceId,omitempty"`
|
||||||
|
|
||||||
|
Kind TokenKind `json:"kind"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
|
||||||
|
UsedAt *time.Time `json:"usedAt,omitempty"`
|
||||||
|
RevokedAt *time.Time `json:"revokedAt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired 回報此 token 是否已過 ExpiresAt。
|
||||||
|
// ExpiresAt 為 nil 代表永不過期(僅 Phase 1 的 session token 可能如此設定)。
|
||||||
|
func (t *PairingToken) IsExpired(now time.Time) bool {
|
||||||
|
if t.ExpiresAt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return now.After(*t.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsUsed 回報此 token 是否已被消費(一次性 pairing token 用)。
|
||||||
|
func (t *PairingToken) IsUsed() bool {
|
||||||
|
return t.UsedAt != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRevoked 回報此 token 是否已撤銷。
|
||||||
|
func (t *PairingToken) IsRevoked() bool {
|
||||||
|
return t.RevokedAt != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionToken 代表升級後的長期 tunnel session token(Phase 1 使用)。
|
||||||
|
//
|
||||||
|
// 格式:vAs_ + 64 hex(共 68 字元);見 security.md §1.3。
|
||||||
|
// 雛形階段以單階段 pairing token 代替,故 SessionToken struct 目前主要作為型別佔位。
|
||||||
|
type SessionToken struct {
|
||||||
|
Plaintext string `json:"-"`
|
||||||
|
TokenHash string `json:"-"`
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
DeviceID string `json:"deviceId"`
|
||||||
|
ParentTokenHash string `json:"-"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
ExpiresAt *time.Time `json:"expiresAt,omitempty"`
|
||||||
|
RevokedAt *time.Time `json:"revokedAt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Interfaces
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// AuthService 是 middleware 層介面:每個 HTTP request 進來時由它解析身分。
|
||||||
|
//
|
||||||
|
// 實作時通常 wrap 一個 AuthProvider 或直接讀 cookie / header。
|
||||||
|
// 雛形 StaticAuthService 永遠回 demo-user,方便前端開發時不需要真正登入。
|
||||||
|
type AuthService interface {
|
||||||
|
// Authenticate 從 HTTP request 解析出 UserContext。
|
||||||
|
// 若無法認證,回傳具體錯誤(例:ErrInvalidToken);middleware 應將其轉為 401。
|
||||||
|
Authenticate(ctx context.Context, r *http.Request) (*UserContext, error)
|
||||||
|
|
||||||
|
// Authorize 判斷此 UserContext 是否有權對 resource 做 action。
|
||||||
|
// 雛形回 nil(全放);Phase 1 以 RBAC 實作。
|
||||||
|
Authorize(ctx context.Context, uc *UserContext, resource, action string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthProvider 是 handler 層介面:處理登入 / 註冊 / 登出 / token 驗證等明確動作。
|
||||||
|
//
|
||||||
|
// 此介面對齊 PRD interface-contracts.md §8.2。雛形以 StaticAuthProvider 填入,
|
||||||
|
// Phase 1 換為 JWTAuthProvider(綁 DB + JWT 簽章)不影響呼叫端。
|
||||||
|
type AuthProvider interface {
|
||||||
|
Register(ctx context.Context, req *RegisterRequest) (*User, error)
|
||||||
|
Login(ctx context.Context, req *LoginRequest) (*LoginResult, error)
|
||||||
|
Logout(ctx context.Context, token string) error
|
||||||
|
ValidateToken(ctx context.Context, token string) (*UserContext, error)
|
||||||
|
GetUser(ctx context.Context, userID string) (*User, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PairingStore 管理 pairing token 的生命週期。
|
||||||
|
//
|
||||||
|
// Phase 0 使用 InMemoryPairingStore(map + mutex + TTL 清理);
|
||||||
|
// Phase 1 改為 PostgresPairingStore 並加入兩階段升級(pairing → session)。
|
||||||
|
//
|
||||||
|
// 注意一次性使用的語意:MarkUsed 後 Validate 必須失敗。
|
||||||
|
type PairingStore interface {
|
||||||
|
// Create 產生並保存一個新的 pairing token。
|
||||||
|
// plaintext 為原文 token(caller 只此一次能拿到),info 為儲存層表示,err 為產生錯誤。
|
||||||
|
Create(ctx context.Context, userID string, ttl time.Duration) (plaintext string, info *PairingToken, err error)
|
||||||
|
|
||||||
|
// Validate 檢查 token 是否有效(存在、未過期、未被使用、未被撤銷)。
|
||||||
|
// 驗證通過回傳 token 資訊,否則回具體錯誤(ErrInvalidToken / ErrTokenExpired / ...)。
|
||||||
|
Validate(ctx context.Context, token string) (*PairingToken, error)
|
||||||
|
|
||||||
|
// MarkUsed 標記一次性 token 為已使用;之後 Validate 必須失敗。
|
||||||
|
MarkUsed(ctx context.Context, token string, deviceID string) error
|
||||||
|
|
||||||
|
// Revoke 撤銷一個 token(使用者操作或系統判定)。
|
||||||
|
Revoke(ctx context.Context, token string) error
|
||||||
|
|
||||||
|
// List 列出某使用者的所有 pairing token(含已使用 / 已撤銷,供 UI 顯示)。
|
||||||
|
List(ctx context.Context, userID string) ([]*PairingToken, error)
|
||||||
|
|
||||||
|
// CleanupExpired 清除超過 ExpiresAt 的 token;
|
||||||
|
// 由 background goroutine 週期性呼叫(建議每分鐘)。
|
||||||
|
// 回傳被移除的數量,便於觀測。
|
||||||
|
CleanupExpired(ctx context.Context, now time.Time) (removed int, err error)
|
||||||
|
}
|
||||||
159
visionA-backend/internal/auth/inmemory_pairing_store.go
Normal file
159
visionA-backend/internal/auth/inmemory_pairing_store.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InMemoryPairingStore 是 PairingStore 的記憶體實作,用於 Phase 0 雛形與單元測試。
|
||||||
|
//
|
||||||
|
// 設計要點:
|
||||||
|
// - 以 plaintext token 為 map key(雛形圖簡;Phase 1 的 PostgresPairingStore 會改存 hash)
|
||||||
|
// - sync.RWMutex 確保並發安全
|
||||||
|
// - 一次性語意:MarkUsed 後 Validate 會回 ErrTokenUsed
|
||||||
|
// - TTL 語意:超過 ExpiresAt 後 Validate 回 ErrTokenExpired;CleanupExpired 會移除
|
||||||
|
type InMemoryPairingStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
tokens map[string]*PairingToken // key = plaintext token
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInMemoryPairingStore 建立一個空的記憶體 PairingStore。
|
||||||
|
func NewInMemoryPairingStore() *InMemoryPairingStore {
|
||||||
|
return &InMemoryPairingStore{
|
||||||
|
tokens: make(map[string]*PairingToken),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 產生並保存一個新 pairing token。
|
||||||
|
//
|
||||||
|
// ttl 為相對存活時間;內部以目前時間 + ttl 算出 ExpiresAt。
|
||||||
|
// 若 ttl <= 0,則 ExpiresAt 保持 nil(永不過期;僅測試 / Phase 1 特殊情境使用)。
|
||||||
|
func (s *InMemoryPairingStore) Create(
|
||||||
|
ctx context.Context, userID string, ttl time.Duration,
|
||||||
|
) (string, *PairingToken, error) {
|
||||||
|
plaintext, err := GeneratePairingToken()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
info := &PairingToken{
|
||||||
|
Plaintext: plaintext,
|
||||||
|
TokenHash: HashToken(plaintext),
|
||||||
|
UserID: userID,
|
||||||
|
Kind: KindPairing,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
if ttl > 0 {
|
||||||
|
expires := now.Add(ttl)
|
||||||
|
info.ExpiresAt = &expires
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.tokens[plaintext] = info
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
// 回傳的 info 給 caller 用(不含 Plaintext 避免誤寫入 log)。
|
||||||
|
// 但為了讓 caller 能立刻傳給前端顯示一次,Plaintext 保留。
|
||||||
|
// 呼叫方有責任不記錄 info.Plaintext 到持久化日誌。
|
||||||
|
return plaintext, info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 檢查 token 是否存在且可用(未過期、未消費、未撤銷)。
|
||||||
|
func (s *InMemoryPairingStore) Validate(ctx context.Context, token string) (*PairingToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
info, ok := s.tokens[token]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
if info.IsRevoked() {
|
||||||
|
return nil, ErrTokenRevoked
|
||||||
|
}
|
||||||
|
if info.IsUsed() {
|
||||||
|
return nil, ErrTokenUsed
|
||||||
|
}
|
||||||
|
if info.IsExpired(time.Now().UTC()) {
|
||||||
|
return nil, ErrTokenExpired
|
||||||
|
}
|
||||||
|
// 回傳 copy 避免 caller 誤改內部狀態(map value 是 pointer,複製 struct)。
|
||||||
|
cp := *info
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarkUsed 將 token 標記為已消費,並綁定 deviceID。
|
||||||
|
//
|
||||||
|
// 若 token 不存在回 ErrInvalidToken;若已標記過則為 no-op(冪等)。
|
||||||
|
func (s *InMemoryPairingStore) MarkUsed(ctx context.Context, token, deviceID string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
info, ok := s.tokens[token]
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidToken
|
||||||
|
}
|
||||||
|
if info.UsedAt != nil {
|
||||||
|
// 已使用 — 冪等回 nil,但不覆寫 deviceID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
info.UsedAt = &now
|
||||||
|
info.DeviceID = deviceID
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke 撤銷一個 token(Validate 後會回 ErrTokenRevoked)。
|
||||||
|
func (s *InMemoryPairingStore) Revoke(ctx context.Context, token string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
info, ok := s.tokens[token]
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidToken
|
||||||
|
}
|
||||||
|
if info.RevokedAt != nil {
|
||||||
|
return nil // 冪等
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
info.RevokedAt = &now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 回傳指定 user 的所有 pairing token(含已使用 / 撤銷)。
|
||||||
|
//
|
||||||
|
// 注意:回傳的 slice 為 copy,但 Plaintext 欄位也被複製 — Caller 應避免記錄。
|
||||||
|
func (s *InMemoryPairingStore) List(ctx context.Context, userID string) ([]*PairingToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
out := make([]*PairingToken, 0)
|
||||||
|
for _, info := range s.tokens {
|
||||||
|
if info.UserID == userID {
|
||||||
|
cp := *info
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpired 移除所有已過 ExpiresAt 的 token;回傳移除數量。
|
||||||
|
//
|
||||||
|
// 通常由 background goroutine 週期性呼叫(例:每 1 分鐘)。
|
||||||
|
func (s *InMemoryPairingStore) CleanupExpired(ctx context.Context, now time.Time) (int, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
removed := 0
|
||||||
|
for k, info := range s.tokens {
|
||||||
|
if info.IsExpired(now) {
|
||||||
|
delete(s.tokens, k)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return removed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 InMemoryPairingStore 實作 PairingStore。
|
||||||
|
var _ PairingStore = (*InMemoryPairingStore)(nil)
|
||||||
132
visionA-backend/internal/auth/inmemory_pairing_store_test.go
Normal file
132
visionA-backend/internal/auth/inmemory_pairing_store_test.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInMemoryPairingStore_CreateAndValidate(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryPairingStore()
|
||||||
|
|
||||||
|
plain, info, err := s.Create(ctx, "user-1", 15*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, plain)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
|
||||||
|
assert.True(t, IsValidPairingToken(plain))
|
||||||
|
assert.Equal(t, "user-1", info.UserID)
|
||||||
|
assert.Equal(t, KindPairing, info.Kind)
|
||||||
|
assert.NotNil(t, info.ExpiresAt)
|
||||||
|
assert.Nil(t, info.UsedAt)
|
||||||
|
|
||||||
|
got, err := s.Validate(ctx, plain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "user-1", got.UserID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryPairingStore_Validate_UnknownToken(t *testing.T) {
|
||||||
|
s := NewInMemoryPairingStore()
|
||||||
|
_, err := s.Validate(context.Background(), "vAc_unknown0000000000000000000000")
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryPairingStore_MarkUsed_IsOneTime(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryPairingStore()
|
||||||
|
|
||||||
|
plain, _, err := s.Create(ctx, "user-1", 15*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.MarkUsed(ctx, plain, "device-1"))
|
||||||
|
|
||||||
|
// Validate 必須失敗(一次性 token 已消費)。
|
||||||
|
_, err = s.Validate(ctx, plain)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenUsed)
|
||||||
|
|
||||||
|
// 再次 MarkUsed 應為 no-op(冪等)。
|
||||||
|
assert.NoError(t, s.MarkUsed(ctx, plain, "another-device"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryPairingStore_Revoke(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryPairingStore()
|
||||||
|
|
||||||
|
plain, _, err := s.Create(ctx, "user-1", 15*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, s.Revoke(ctx, plain))
|
||||||
|
|
||||||
|
_, err = s.Validate(ctx, plain)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenRevoked)
|
||||||
|
|
||||||
|
// 撤銷不存在的 token → ErrInvalidToken
|
||||||
|
assert.ErrorIs(t, s.Revoke(ctx, "vAc_abcdef00000000000000000000000000"), ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryPairingStore_CleanupExpired(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryPairingStore()
|
||||||
|
|
||||||
|
// 產生一個已過期的 token(ttl = 1ms)
|
||||||
|
expired, _, err := s.Create(ctx, "user-1", 1*time.Millisecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 另一個尚未過期
|
||||||
|
fresh, _, err := s.Create(ctx, "user-1", 1*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 等 10ms 確保第一個過期
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
removed, err := s.CleanupExpired(ctx, time.Now().UTC())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, removed)
|
||||||
|
|
||||||
|
_, err = s.Validate(ctx, expired)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidToken, "過期的 token 應被清掉")
|
||||||
|
|
||||||
|
_, err = s.Validate(ctx, fresh)
|
||||||
|
assert.NoError(t, err, "未過期的 token 不應被清")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryPairingStore_List_ByUser(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryPairingStore()
|
||||||
|
|
||||||
|
_, _, err := s.Create(ctx, "user-A", time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, _, err = s.Create(ctx, "user-A", time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, _, err = s.Create(ctx, "user-B", time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
listA, err := s.List(ctx, "user-A")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, listA, 2)
|
||||||
|
|
||||||
|
listB, err := s.List(ctx, "user-B")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, listB, 1)
|
||||||
|
|
||||||
|
listNone, err := s.List(ctx, "user-X")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, listNone)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryPairingStore_Validate_Expired(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryPairingStore()
|
||||||
|
|
||||||
|
plain, _, err := s.Create(ctx, "user-1", 1*time.Millisecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
|
||||||
|
_, err = s.Validate(ctx, plain)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenExpired)
|
||||||
|
}
|
||||||
160
visionA-backend/internal/auth/session_token.go
Normal file
160
visionA-backend/internal/auth/session_token.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// SessionTokenStore
|
||||||
|
// ==========================================================================
|
||||||
|
//
|
||||||
|
// 對齊 security.md §1.3 / visiona-agent-tdd.md §4.3:
|
||||||
|
// - Pairing Token(vAc_ + 32 hex)15 min TTL,一次性。
|
||||||
|
// - Session Token(vAs_ + 64 hex)90 天 TTL,長期可撤銷。
|
||||||
|
//
|
||||||
|
// 本 store 負責「Pairing → Session」交換後發出的 Session Token 生命週期管理。
|
||||||
|
// 雛形(Phase 0)以 in-memory map 持有;Phase 1 換為 Postgres 時維持介面。
|
||||||
|
//
|
||||||
|
// 注意:雛形 remote-proxy 目前只做 token 格式驗證(見 relay/server.go
|
||||||
|
// isAcceptableToken),**不會**實際查 SessionTokenStore。這是刻意的雛形取捨,
|
||||||
|
// 對應 visiona-agent-tdd.md 的「選項 A」。Phase 1 會新增
|
||||||
|
// `GET /internal/session-token/:token` 讓 remote-proxy 拉驗證。
|
||||||
|
|
||||||
|
// SessionTokenTTL 是 Session Token 的預設存活時間(對齊 security.md §1.3)。
|
||||||
|
const SessionTokenTTL = 90 * 24 * time.Hour
|
||||||
|
|
||||||
|
// SessionTokenStore 管理 Session Token 的生命週期。
|
||||||
|
//
|
||||||
|
// 實作必須是 goroutine-safe;雛形使用 InMemorySessionTokenStore。
|
||||||
|
type SessionTokenStore interface {
|
||||||
|
// Create 產生並保存一個新的 Session Token。
|
||||||
|
//
|
||||||
|
// ttl 為相對存活時間;若 <= 0 視為「無過期時間」。
|
||||||
|
// plaintext 為原文 token(caller 只此一次能拿到)。
|
||||||
|
Create(ctx context.Context, userID, deviceID, parentTokenHash string, ttl time.Duration) (plaintext string, info *SessionToken, err error)
|
||||||
|
|
||||||
|
// Get 依 plaintext 查詢 Session Token;token 不存在回 ErrInvalidToken,
|
||||||
|
// 過期回 ErrTokenExpired,已撤銷回 ErrTokenRevoked。
|
||||||
|
//
|
||||||
|
// 回傳的 SessionToken 為 copy,caller 不可直接改內部狀態。
|
||||||
|
Get(ctx context.Context, plaintext string) (*SessionToken, error)
|
||||||
|
|
||||||
|
// Revoke 撤銷一個 Session Token;之後 Get 會回 ErrTokenRevoked。
|
||||||
|
// 若 token 不存在回 ErrInvalidToken;已撤銷為冪等(回 nil)。
|
||||||
|
Revoke(ctx context.Context, plaintext string) error
|
||||||
|
|
||||||
|
// CleanupExpired 移除所有已過期的 token,回傳移除數量。
|
||||||
|
// 由 background goroutine 週期性呼叫;雛形暫無呼叫處。
|
||||||
|
CleanupExpired(ctx context.Context, now time.Time) (removed int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// InMemorySessionTokenStore
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// InMemorySessionTokenStore 是 SessionTokenStore 的雛形記憶體實作。
|
||||||
|
//
|
||||||
|
// 設計要點(刻意對齊 InMemoryPairingStore 風格):
|
||||||
|
// - 以 plaintext token 為 map key(Phase 1 改 hash)
|
||||||
|
// - sync.RWMutex 保護並發存取
|
||||||
|
// - ExpiresAt 為 nil 代表永不過期
|
||||||
|
type InMemorySessionTokenStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
tokens map[string]*SessionToken // key = plaintext token
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInMemorySessionTokenStore 建立一個空的記憶體 SessionTokenStore。
|
||||||
|
func NewInMemorySessionTokenStore() *InMemorySessionTokenStore {
|
||||||
|
return &InMemorySessionTokenStore{
|
||||||
|
tokens: make(map[string]*SessionToken),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 產生並保存一個新 Session Token。
|
||||||
|
//
|
||||||
|
// parentTokenHash 為升級來源(通常是 Pairing Token 的 hash),方便 Phase 1
|
||||||
|
// 做稽核追蹤;雛形 caller 傳空字串也可以。
|
||||||
|
func (s *InMemorySessionTokenStore) Create(
|
||||||
|
ctx context.Context, userID, deviceID, parentTokenHash string, ttl time.Duration,
|
||||||
|
) (string, *SessionToken, error) {
|
||||||
|
plaintext, err := GenerateSessionToken()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
info := &SessionToken{
|
||||||
|
Plaintext: plaintext,
|
||||||
|
TokenHash: HashToken(plaintext),
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
ParentTokenHash: parentTokenHash,
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
if ttl > 0 {
|
||||||
|
expires := now.Add(ttl)
|
||||||
|
info.ExpiresAt = &expires
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.tokens[plaintext] = info
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
|
return plaintext, info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 查詢 Session Token;回傳前會檢查過期 / 撤銷狀態。
|
||||||
|
func (s *InMemorySessionTokenStore) Get(ctx context.Context, plaintext string) (*SessionToken, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
info, ok := s.tokens[plaintext]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrInvalidToken
|
||||||
|
}
|
||||||
|
if info.RevokedAt != nil {
|
||||||
|
return nil, ErrTokenRevoked
|
||||||
|
}
|
||||||
|
if info.ExpiresAt != nil && time.Now().UTC().After(*info.ExpiresAt) {
|
||||||
|
return nil, ErrTokenExpired
|
||||||
|
}
|
||||||
|
cp := *info
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Revoke 撤銷 Session Token;之後 Get 會回 ErrTokenRevoked。
|
||||||
|
func (s *InMemorySessionTokenStore) Revoke(ctx context.Context, plaintext string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
info, ok := s.tokens[plaintext]
|
||||||
|
if !ok {
|
||||||
|
return ErrInvalidToken
|
||||||
|
}
|
||||||
|
if info.RevokedAt != nil {
|
||||||
|
return nil // 冪等
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
info.RevokedAt = &now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpired 移除所有已過期(ExpiresAt < now)的 token。
|
||||||
|
func (s *InMemorySessionTokenStore) CleanupExpired(ctx context.Context, now time.Time) (int, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
removed := 0
|
||||||
|
for k, info := range s.tokens {
|
||||||
|
if info.ExpiresAt != nil && now.After(*info.ExpiresAt) {
|
||||||
|
delete(s.tokens, k)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return removed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 InMemorySessionTokenStore 實作 SessionTokenStore。
|
||||||
|
var _ SessionTokenStore = (*InMemorySessionTokenStore)(nil)
|
||||||
109
visionA-backend/internal/auth/session_token_test.go
Normal file
109
visionA-backend/internal/auth/session_token_test.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestInMemorySessionTokenStore_CreateAndGet 驗證一次完整的建立 → 查詢循環。
|
||||||
|
func TestInMemorySessionTokenStore_CreateAndGet(t *testing.T) {
|
||||||
|
s := NewInMemorySessionTokenStore()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
plain, info, err := s.Create(ctx, "user-1", "dev-1", "parent-hash", SessionTokenTTL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, IsValidSessionToken(plain), "產出 token 應通過格式驗證:%s", plain)
|
||||||
|
require.NotNil(t, info)
|
||||||
|
assert.Equal(t, "user-1", info.UserID)
|
||||||
|
assert.Equal(t, "dev-1", info.DeviceID)
|
||||||
|
assert.Equal(t, "parent-hash", info.ParentTokenHash)
|
||||||
|
require.NotNil(t, info.ExpiresAt)
|
||||||
|
assert.WithinDuration(t, time.Now().UTC().Add(SessionTokenTTL), *info.ExpiresAt, 2*time.Second)
|
||||||
|
|
||||||
|
got, err := s.Get(ctx, plain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "user-1", got.UserID)
|
||||||
|
assert.Equal(t, info.TokenHash, got.TokenHash)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInMemorySessionTokenStore_Get_NotFound 驗證查詢不存在 token 回 ErrInvalidToken。
|
||||||
|
func TestInMemorySessionTokenStore_Get_NotFound(t *testing.T) {
|
||||||
|
s := NewInMemorySessionTokenStore()
|
||||||
|
_, err := s.Get(context.Background(), "vAs_deadbeef")
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInMemorySessionTokenStore_Get_Expired 驗證過期 token 回 ErrTokenExpired。
|
||||||
|
func TestInMemorySessionTokenStore_Get_Expired(t *testing.T) {
|
||||||
|
s := NewInMemorySessionTokenStore()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// TTL 設 1ns 確保立即過期
|
||||||
|
plain, _, err := s.Create(ctx, "u", "d", "", 1*time.Nanosecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
_, err = s.Get(ctx, plain)
|
||||||
|
assert.True(t, errors.Is(err, ErrTokenExpired), "應回 ErrTokenExpired,實際:%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInMemorySessionTokenStore_Revoke 驗證撤銷後 Get 回 ErrTokenRevoked。
|
||||||
|
func TestInMemorySessionTokenStore_Revoke(t *testing.T) {
|
||||||
|
s := NewInMemorySessionTokenStore()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
plain, _, err := s.Create(ctx, "u", "d", "", SessionTokenTTL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, s.Revoke(ctx, plain))
|
||||||
|
|
||||||
|
_, err = s.Get(ctx, plain)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenRevoked)
|
||||||
|
|
||||||
|
// 冪等:再撤一次不該報錯
|
||||||
|
assert.NoError(t, s.Revoke(ctx, plain))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInMemorySessionTokenStore_Revoke_NotFound 驗證撤銷不存在 token 回 ErrInvalidToken。
|
||||||
|
func TestInMemorySessionTokenStore_Revoke_NotFound(t *testing.T) {
|
||||||
|
s := NewInMemorySessionTokenStore()
|
||||||
|
err := s.Revoke(context.Background(), "vAs_nope")
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInMemorySessionTokenStore_CleanupExpired 驗證過期 token 會被清掉。
|
||||||
|
func TestInMemorySessionTokenStore_CleanupExpired(t *testing.T) {
|
||||||
|
s := NewInMemorySessionTokenStore()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// 一個會過期、一個長效
|
||||||
|
expiredTok, _, err := s.Create(ctx, "u1", "d1", "", 1*time.Nanosecond)
|
||||||
|
require.NoError(t, err)
|
||||||
|
freshTok, _, err := s.Create(ctx, "u2", "d2", "", SessionTokenTTL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
removed, err := s.CleanupExpired(ctx, time.Now().UTC())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, removed)
|
||||||
|
|
||||||
|
// 過期的應查不到
|
||||||
|
_, err = s.Get(ctx, expiredTok)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidToken)
|
||||||
|
|
||||||
|
// 新鮮的仍在
|
||||||
|
_, err = s.Get(ctx, freshTok)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInMemorySessionTokenStore_NeverExpires 驗證 ttl <= 0 時 ExpiresAt 為 nil。
|
||||||
|
func TestInMemorySessionTokenStore_NeverExpires(t *testing.T) {
|
||||||
|
s := NewInMemorySessionTokenStore()
|
||||||
|
_, info, err := s.Create(context.Background(), "u", "d", "", 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Nil(t, info.ExpiresAt, "ttl=0 時 ExpiresAt 應為 nil")
|
||||||
|
}
|
||||||
66
visionA-backend/internal/auth/token.go
Normal file
66
visionA-backend/internal/auth/token.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"regexp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Token prefix 常數(security.md §1.3)。
|
||||||
|
const (
|
||||||
|
// PairingTokenPrefix 是 pairing token 的固定前綴。
|
||||||
|
PairingTokenPrefix = "vAc_"
|
||||||
|
// SessionTokenPrefix 是 session token 的固定前綴。
|
||||||
|
SessionTokenPrefix = "vAs_"
|
||||||
|
|
||||||
|
// PairingTokenHexLen 是 pairing token 底 hex 字串的字元數(32 chars = 16 bytes)。
|
||||||
|
PairingTokenHexLen = 32
|
||||||
|
// SessionTokenHexLen 是 session token 底 hex 字串的字元數(64 chars = 32 bytes)。
|
||||||
|
SessionTokenHexLen = 64
|
||||||
|
)
|
||||||
|
|
||||||
|
// pairingTokenRegex 驗證 vAc_ + 32 小寫 hex 的完整格式。
|
||||||
|
var pairingTokenRegex = regexp.MustCompile(`^vAc_[0-9a-f]{32}$`)
|
||||||
|
|
||||||
|
// sessionTokenRegex 驗證 vAs_ + 64 小寫 hex 的完整格式。
|
||||||
|
var sessionTokenRegex = regexp.MustCompile(`^vAs_[0-9a-f]{64}$`)
|
||||||
|
|
||||||
|
// GeneratePairingToken 產生一個符合 `vAc_[0-9a-f]{32}` 格式的 pairing token。
|
||||||
|
//
|
||||||
|
// 來源:crypto/rand.Read 16 bytes → hex 編碼。
|
||||||
|
// 失敗時回傳 err(通常僅在系統熵耗盡時發生)。
|
||||||
|
func GeneratePairingToken() (string, error) {
|
||||||
|
b := make([]byte, PairingTokenHexLen/2) // 16 bytes = 32 hex chars
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return PairingTokenPrefix + hex.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateSessionToken 產生一個符合 `vAs_[0-9a-f]{64}` 格式的 session token。
|
||||||
|
//
|
||||||
|
// 來源:crypto/rand.Read 32 bytes → hex 編碼。
|
||||||
|
func GenerateSessionToken() (string, error) {
|
||||||
|
b := make([]byte, SessionTokenHexLen/2) // 32 bytes = 64 hex chars
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return SessionTokenPrefix + hex.EncodeToString(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidPairingToken 驗證 token 字串是否符合 pairing token 格式。
|
||||||
|
func IsValidPairingToken(token string) bool {
|
||||||
|
return pairingTokenRegex.MatchString(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidSessionToken 驗證 token 字串是否符合 session token 格式。
|
||||||
|
func IsValidSessionToken(token string) bool {
|
||||||
|
return sessionTokenRegex.MatchString(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashToken 計算 token 的 sha256 hex 字串,供 DB 儲存用(永遠不存明文)。
|
||||||
|
func HashToken(token string) string {
|
||||||
|
h := sha256.Sum256([]byte(token))
|
||||||
|
return hex.EncodeToString(h[:])
|
||||||
|
}
|
||||||
73
visionA-backend/internal/auth/token_test.go
Normal file
73
visionA-backend/internal/auth/token_test.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGeneratePairingToken_Format(t *testing.T) {
|
||||||
|
tok, err := GeneratePairingToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, IsValidPairingToken(tok), "產生的 token 應符合 pairing 正則:got %q", tok)
|
||||||
|
assert.Len(t, tok, len(PairingTokenPrefix)+PairingTokenHexLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateSessionToken_Format(t *testing.T) {
|
||||||
|
tok, err := GenerateSessionToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.True(t, IsValidSessionToken(tok), "產生的 token 應符合 session 正則:got %q", tok)
|
||||||
|
assert.Len(t, tok, len(SessionTokenPrefix)+SessionTokenHexLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGeneratePairingToken_Unique(t *testing.T) {
|
||||||
|
// 產生 100 次應不會碰撞(熵極低)。
|
||||||
|
seen := make(map[string]struct{}, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
tok, err := GeneratePairingToken()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, dup := seen[tok]
|
||||||
|
require.False(t, dup, "pairing token 不應重複產生:%s", tok)
|
||||||
|
seen[tok] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidPairingToken(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
valid bool
|
||||||
|
}{
|
||||||
|
{"正確格式", "vAc_0123456789abcdef0123456789abcdef", true},
|
||||||
|
{"大寫 hex 不允許", "vAc_0123456789ABCDEF0123456789ABCDEF", false},
|
||||||
|
{"前綴錯誤", "vAs_0123456789abcdef0123456789abcdef", false},
|
||||||
|
{"長度太短", "vAc_0123456789abcdef", false},
|
||||||
|
{"長度太長", "vAc_0123456789abcdef0123456789abcdef00", false},
|
||||||
|
{"含非 hex 字元", "vAc_0123456789abcdef0123456789abcdeZ0", false},
|
||||||
|
{"空字串", "", false},
|
||||||
|
{"僅前綴", "vAc_", false},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.valid, IsValidPairingToken(tc.token))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsValidSessionToken(t *testing.T) {
|
||||||
|
goodSession := "vAs_" + "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"
|
||||||
|
assert.True(t, IsValidSessionToken(goodSession))
|
||||||
|
assert.False(t, IsValidSessionToken("vAs_short"))
|
||||||
|
assert.False(t, IsValidSessionToken("vAc_0123456789abcdef0123456789abcdef"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHashToken_Deterministic(t *testing.T) {
|
||||||
|
h1 := HashToken("vAc_abc")
|
||||||
|
h2 := HashToken("vAc_abc")
|
||||||
|
assert.Equal(t, h1, h2, "同樣輸入應得同樣 hash")
|
||||||
|
assert.NotEqual(t, HashToken("vAc_abc"), HashToken("vAc_def"))
|
||||||
|
assert.Len(t, h1, 64, "sha256 hex 長度應為 64")
|
||||||
|
}
|
||||||
0
visionA-backend/internal/cluster/.gitkeep
Normal file
0
visionA-backend/internal/cluster/.gitkeep
Normal file
18
visionA-backend/internal/cluster/TODO.md
Normal file
18
visionA-backend/internal/cluster/TODO.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# cluster package — 待補項目
|
||||||
|
|
||||||
|
雛形 B3 僅複製 POC `types.go`(去除 driver 相依)。以下 POC 檔案尚未搬過來:
|
||||||
|
|
||||||
|
| POC 檔案 | 狀態 | 說明 |
|
||||||
|
|---------|------|------|
|
||||||
|
| `dispatcher.go` | TODO | Weighted Round-Robin dispatcher;依賴 `driver.DeviceDriver` interface |
|
||||||
|
| `manager.go` | TODO | 叢集生命週期管理(Add/Remove/ModelUpdate);依賴 device / driver |
|
||||||
|
| `pipeline.go` | TODO | 推論 pipeline(結果 merge / order);依賴 `driver.InferenceResult` |
|
||||||
|
|
||||||
|
## 需要 B5 / B6 討論的選項
|
||||||
|
|
||||||
|
- 選項 A:把 `driver.DeviceDriver` 抽成 interface 搬進 `internal/device/driver.go`
|
||||||
|
- 選項 B:cluster 改為純 pass-through — 雲端只管「叢集定義」,dispatch / pipeline
|
||||||
|
完全交給 local agent;雲端只保留 CRUD + 狀態聚合
|
||||||
|
- 選項 C:雲端完全不做 dispatch,clusters API 僅提供 CRUD(最輕量)
|
||||||
|
|
||||||
|
**建議預設 C,等 PM / Architect 在 B5 前確認**(對齊 design-doc 的「雲端不做業務邏輯」原則)。
|
||||||
79
visionA-backend/internal/cluster/types.go
Normal file
79
visionA-backend/internal/cluster/types.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
// Package cluster 提供多裝置平行推論的叢集(cluster)資料結構。
|
||||||
|
//
|
||||||
|
// ⚠️ 雛形階段(Phase 0):此 package 僅包含 types.go 的 domain 資料結構,
|
||||||
|
// 方便後續 B5 的 `/api/clusters/*` handler 使用;
|
||||||
|
// POC 的 Dispatcher / Manager / Pipeline 實作因深度依賴 `driver.DeviceDriver`
|
||||||
|
// 這類 POC 端特有型別,搬過來會破壞 `internal/device` 的乾淨 domain model,
|
||||||
|
// 故先不搬,留 TODO 待 B5 視需求決定:
|
||||||
|
// - 選項 A:把 driver 抽成 interface 搬過來
|
||||||
|
// - 選項 B:cluster 改為「pass-through 到 local agent」— 雲端只管叢集定義,
|
||||||
|
// 實際分派與 pipeline 由 local agent 自行完成
|
||||||
|
// - 選項 C:雲端只保留叢集 CRUD,不做 dispatch(最輕量)
|
||||||
|
//
|
||||||
|
// 來源:POC `edge-ai-platform/server/internal/cluster/types.go`(去除 driver.InferenceResult 相依)。
|
||||||
|
package cluster
|
||||||
|
|
||||||
|
// Default dispatch weights per chip type。
|
||||||
|
const (
|
||||||
|
DefaultWeightKL720 = 3
|
||||||
|
DefaultWeightKL520 = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// MaxClusterSize 單一叢集允許的最大裝置數。
|
||||||
|
const MaxClusterSize = 8
|
||||||
|
|
||||||
|
// ClusterStatus 叢集當前狀態。
|
||||||
|
type ClusterStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ClusterIdle ClusterStatus = "idle"
|
||||||
|
ClusterInferencing ClusterStatus = "inferencing"
|
||||||
|
ClusterDegraded ClusterStatus = "degraded"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MemberStatus 單個裝置在叢集內的狀態。
|
||||||
|
type MemberStatus string
|
||||||
|
|
||||||
|
const (
|
||||||
|
MemberActive MemberStatus = "active"
|
||||||
|
MemberDegraded MemberStatus = "degraded"
|
||||||
|
MemberRemoved MemberStatus = "removed"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceMember 參與叢集的一個裝置。
|
||||||
|
type DeviceMember struct {
|
||||||
|
DeviceID string `json:"deviceId"`
|
||||||
|
Weight int `json:"weight"`
|
||||||
|
Status MemberStatus `json:"status"`
|
||||||
|
DeviceName string `json:"deviceName,omitempty"`
|
||||||
|
DeviceType string `json:"deviceType,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cluster 叢集主體。
|
||||||
|
type Cluster struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Devices []DeviceMember `json:"devices"`
|
||||||
|
ModelID string `json:"modelId,omitempty"`
|
||||||
|
Status ClusterStatus `json:"status"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClusterFlashProgress 叢集 flash 進度回報(單一裝置)。
|
||||||
|
type ClusterFlashProgress struct {
|
||||||
|
DeviceID string `json:"deviceId"`
|
||||||
|
Percent int `json:"percent"`
|
||||||
|
Stage string `json:"stage"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Error string `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClusterResult 叢集推論結果的薄包裝;完整 schema(含原始推論 metadata)
|
||||||
|
// 待 Phase 1 推論 pipeline 重建時補齊。雛形保留 placeholder 讓 API schema 可對齊。
|
||||||
|
//
|
||||||
|
// TODO(B5+):補 InferenceResult 欄位(對應 POC `driver.InferenceResult`)。
|
||||||
|
type ClusterResult struct {
|
||||||
|
ClusterID string `json:"clusterId"`
|
||||||
|
FrameIndex int64 `json:"frameIndex"`
|
||||||
|
// Payload 為雛形佔位 — 真實欄位待補。
|
||||||
|
Payload map[string]any `json:"payload,omitempty"`
|
||||||
|
}
|
||||||
258
visionA-backend/internal/config/config.go
Normal file
258
visionA-backend/internal/config/config.go
Normal file
@ -0,0 +1,258 @@
|
|||||||
|
// Package config 定義 visionA-backend 的組態結構,對齊 TDD §2.10。
|
||||||
|
//
|
||||||
|
// 雛形遵循 12-Factor App:所有可變設定皆透過環境變數注入,不寫死在程式碼裡。
|
||||||
|
// `api-server` 與 `remote-proxy` 共享同一份 Config;各自只消費自己需要的欄位。
|
||||||
|
package config
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// Config 是整個 visionA-backend 的環境設定。
|
||||||
|
//
|
||||||
|
// 所有欄位皆由 Load() 從環境變數讀取並套用預設值。
|
||||||
|
// 欄位命名對齊 TDD §2.10;新增欄位時請同步更新 `.env.example`(待 B6)。
|
||||||
|
type Config struct {
|
||||||
|
Server ServerConfig
|
||||||
|
Session SessionConfig
|
||||||
|
Auth AuthConfig
|
||||||
|
OIDC OIDCConfig
|
||||||
|
UserSession UserSessionConfig
|
||||||
|
Storage StorageConfig
|
||||||
|
Model ModelConfig
|
||||||
|
Tunnel TunnelConfig
|
||||||
|
Logger LoggerConfig
|
||||||
|
CORS CORSConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig 控制 HTTP listener 的位址與埠號。
|
||||||
|
//
|
||||||
|
// api-server 端使用 Port 提供 REST / WebSocket;
|
||||||
|
// remote-proxy 端使用 TunnelPort(面向 local agent)與 InternalPort(面向 api-server)。
|
||||||
|
//
|
||||||
|
// Port 預設為 3721 — 對齊 local-tool 的 base URL,這樣 local-tool 前端切到雲端版時
|
||||||
|
// base URL 可以維持一致,降低前端的 dev 流程切換成本(B4 決定)。
|
||||||
|
type ServerConfig struct {
|
||||||
|
Host string // VISIONA_HOST,預設 "0.0.0.0"
|
||||||
|
Port int // VISIONA_API_PORT,預設 3721(對齊 local-tool)
|
||||||
|
TunnelPort int // VISIONA_TUNNEL_PORT,預設 3800
|
||||||
|
InternalPort int // VISIONA_PROXY_INTERNAL_PORT,預設 3801
|
||||||
|
|
||||||
|
// RelayPublicURL 是 agent 連 tunnel 用的對外可達 URL(通常是 wss://.../tunnel/connect
|
||||||
|
// 的 origin 部分,例:wss://relay.visionA.cloud)。
|
||||||
|
// AB11 新增:/api/pairing/exchange 會把這個值回傳給 agent。
|
||||||
|
// 雛形預設為空 — handler 會 fallback 到 placeholder `wss://relay.visionA.cloud`。
|
||||||
|
RelayPublicURL string
|
||||||
|
|
||||||
|
// SeedDemoData 控制 api-server 啟動時是否塞入示範用 device + model + pairing token。
|
||||||
|
// 預設 false;本機開發或 demo 時可設 VISIONA_SEED_DEMO_DATA=true 開啟,
|
||||||
|
// 方便前端不必跑完整 pairing 流程就能看到資料。
|
||||||
|
SeedDemoData bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SessionConfig 控制 SessionStore 的實作選擇與連線資訊。
|
||||||
|
//
|
||||||
|
// Backend:
|
||||||
|
// - "inmemory" — remote-proxy 端持有 yamux session 的唯一來源
|
||||||
|
// - "proxy-client" — api-server 端透過 internal HTTP 查詢 remote-proxy
|
||||||
|
type SessionConfig struct {
|
||||||
|
Backend string // VISIONA_SESSION_BACKEND,預設 "inmemory"
|
||||||
|
ProxyInternalURL string // VISIONA_PROXY_INTERNAL_URL,預設 "http://localhost:3801"
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthConfig 控制雛形專用的 user fallback 與 pairing token。
|
||||||
|
//
|
||||||
|
// OB5(2026-04-26)起認證走 OIDC(OIDCConfig);
|
||||||
|
// Phase 0.7(2026-05-01)security audit 移除了 api.Deps.StaticUserID handler fallback
|
||||||
|
// (見 .autoflow/05-implementation/review/phase-0.7-security-audit.md C1)。
|
||||||
|
// 此處的 StaticUserID 欄位**僅供 dev seed(VISIONA_SEED_DEMO_DATA=true)與 unit test
|
||||||
|
// fixture 讀取使用**,不再注入 api.Deps、不影響 stage/prod 認證行為。
|
||||||
|
type AuthConfig struct {
|
||||||
|
// StaticUserID — Deprecated for routing/auth use. 僅供 dev seed / unit test。
|
||||||
|
// 見 internal/api/api.go 的 Deps 註解;stage/prod 留空無影響。
|
||||||
|
StaticUserID string // VISIONA_STATIC_USER_ID,預設 "demo-user"(dev seed only)
|
||||||
|
PairingToken string // VISIONA_PAIRING_TOKEN,格式必須為 vAc_ + 32 hex
|
||||||
|
SigningSecret string // VISIONA_STORAGE_SIGNING_SECRET,presigned URL HMAC secret
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCConfig 控制 OpenID Connect 登入流程(BFF 模式)。
|
||||||
|
//
|
||||||
|
// 對齊 oidc-tdd.md §13.1 + ADR-010 + ADR-011 + ADR-013。
|
||||||
|
// OB5 起 OIDC 是唯一認證路徑;A1 起支援 public PKCE-only client。
|
||||||
|
type OIDCConfig struct {
|
||||||
|
// IssuerURL 是 OIDC IdP 的 issuer(不帶結尾斜線),例如
|
||||||
|
// dev: http://localhost:5050
|
||||||
|
// prod: https://members.innovedus.com
|
||||||
|
// 對齊 VISIONA_OIDC_ISSUER_URL。
|
||||||
|
IssuerURL string
|
||||||
|
|
||||||
|
// ClientID 是 visionA 在 IdP 註冊的 OAuth client id(confidential 或 public 皆可)。
|
||||||
|
// 對齊 VISIONA_OIDC_CLIENT_ID。
|
||||||
|
ClientID string
|
||||||
|
|
||||||
|
// ClientSecret 為**選填**(A1, 2026-05-01):
|
||||||
|
// - 有值 → confidential client mode(client_secret + PKCE 雙保險)
|
||||||
|
// - 留空 → PKCE-only public client mode(純依靠 PKCE 防 code interception)
|
||||||
|
// 兩種 mode 由 IdP 決定,visionA-backend 都支援(見 ADR-013)。
|
||||||
|
// **禁止 commit 進 repo**;對齊 VISIONA_OIDC_CLIENT_SECRET。
|
||||||
|
ClientSecret string
|
||||||
|
|
||||||
|
// RedirectURL 是 visionA-backend 的 callback URL,必須與 IdP 註冊值完全一致。
|
||||||
|
// dev: http://localhost:3721/api/auth/callback
|
||||||
|
// prod: https://api.visiona.cloud/api/auth/callback
|
||||||
|
// 對齊 VISIONA_OIDC_REDIRECT_URL。
|
||||||
|
RedirectURL string
|
||||||
|
|
||||||
|
// PostLoginURL 是 callback 完成後 302 回 frontend 的 base URL。
|
||||||
|
// dev: http://localhost:3000
|
||||||
|
// prod: https://app.visiona.cloud
|
||||||
|
// 對齊 VISIONA_FRONTEND_URL(沿用 oidc-tdd.md §13.1 命名)。
|
||||||
|
PostLoginURL string
|
||||||
|
|
||||||
|
// ServiceClientID 是「visionA-backend 以服務身份呼叫 MC API」用的 client id,
|
||||||
|
// 預留給未來 client_credentials grant flow(例如查詢使用者組織、推送通知等)。
|
||||||
|
//
|
||||||
|
// **A1 階段不啟用**:Validate() 不檢查、main.go 不 wire;只先把 config 鉤子留好,
|
||||||
|
// 之後接時不必再改 OIDCConfig schema。對齊 VISIONA_OIDC_SERVICE_CLIENT_ID。
|
||||||
|
ServiceClientID string
|
||||||
|
|
||||||
|
// ServiceClientSecret 是 service client(client_credentials grant)的 secret。
|
||||||
|
// 與 ServiceClientID 配對使用;同樣 A1 階段不啟用、Validate() 不檢查。
|
||||||
|
// **禁止 commit 進 repo**;對齊 VISIONA_OIDC_SERVICE_CLIENT_SECRET。
|
||||||
|
ServiceClientSecret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserSessionConfig 控制 OIDC 登入後在 browser 端建立的 cookie session。
|
||||||
|
//
|
||||||
|
// 注意:與既有 SessionConfig(tunnel session 用)刻意分開,避免命名混淆。
|
||||||
|
// 對齊 oidc-tdd.md §5、§13.1。
|
||||||
|
type UserSessionConfig struct {
|
||||||
|
// Secret 是 cookie HMAC-SHA256 簽章金鑰;應為至少 32 byte 隨機字串。
|
||||||
|
// 對齊 VISIONA_SESSION_SECRET。
|
||||||
|
Secret string
|
||||||
|
|
||||||
|
// CookieName 預設 "visiona_session"。
|
||||||
|
// 對齊 VISIONA_SESSION_COOKIE_NAME。
|
||||||
|
CookieName string
|
||||||
|
|
||||||
|
// CookieDomain:dev 留空(host-only cookie),prod 設 ".visiona.cloud"。
|
||||||
|
// 對齊 VISIONA_SESSION_COOKIE_DOMAIN。
|
||||||
|
CookieDomain string
|
||||||
|
|
||||||
|
// CookieSecure 控制 Secure flag。dev=false(http),prod=true(https)。
|
||||||
|
// 對齊 VISIONA_SESSION_COOKIE_SECURE。
|
||||||
|
CookieSecure bool
|
||||||
|
|
||||||
|
// AbsoluteTTL 是 session 的最長存活時間(從 Create 起算)。預設 168h(7 天)。
|
||||||
|
// 對齊 VISIONA_SESSION_ABSOLUTE_TTL。
|
||||||
|
AbsoluteTTL time.Duration
|
||||||
|
|
||||||
|
// IdleTTL 是 session 的閒置存活時間(從 LastSeenAt 起算)。預設 24h。
|
||||||
|
// 對齊 VISIONA_SESSION_IDLE_TTL。
|
||||||
|
IdleTTL time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorageConfig 控制儲存層實作(LocalFS / S3)與路徑。
|
||||||
|
type StorageConfig struct {
|
||||||
|
Backend string // VISIONA_STORAGE_BACKEND,預設 "localfs"
|
||||||
|
RootDir string // VISIONA_STORAGE_LOCALFS_ROOT,預設 "./data/storage"
|
||||||
|
BaseURL string // VISIONA_STORAGE_LOCALFS_BASE_URL,預設 "http://localhost:3721/storage"(對齊 api-server port)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelConfig 針對模型資源的驗證限制(大小等)。
|
||||||
|
type ModelConfig struct {
|
||||||
|
// MaxSizeMB 是允許上傳的單一模型檔案大小上限(MB)。
|
||||||
|
// PRD §8.4 規範 Phase 0 為 100 MB;可由 VISIONA_MODEL_MAX_SIZE_MB 覆寫。
|
||||||
|
MaxSizeMB int
|
||||||
|
}
|
||||||
|
|
||||||
|
// TunnelConfig 控制 tunnel 心跳與掉線判定閾值,對齊 tunnel.md §4.2。
|
||||||
|
type TunnelConfig struct {
|
||||||
|
// HeartbeatInterval 為 yamux KeepAliveInterval 值。預設 10s。
|
||||||
|
HeartbeatInterval time.Duration
|
||||||
|
// IdleTimeout 為判定對端失聯的時間。預設 30s(= 3 次心跳未回)。
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoggerConfig 控制結構化 logger 的輸出等級。
|
||||||
|
type LoggerConfig struct {
|
||||||
|
Level string // VISIONA_LOG_LEVEL:debug / info / warn / error,預設 "info"
|
||||||
|
}
|
||||||
|
|
||||||
|
// CORSConfig 控制 api-server 對瀏覽器的 CORS 白名單。
|
||||||
|
//
|
||||||
|
// AllowedOrigins 為逗號分隔字串解析後的 slice;
|
||||||
|
// 空時 api.Deps.validate() 會 fallback 到 http://localhost:3000(前端 dev server)。
|
||||||
|
type CORSConfig struct {
|
||||||
|
AllowedOrigins []string // VISIONA_CORS_ALLOWED_ORIGINS(逗號分隔)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate 在 Load() 之後檢查交叉依賴與必填欄位。
|
||||||
|
//
|
||||||
|
// OB5 起 OIDC 是唯一認證路徑,所有 OIDC 必填欄位永遠都要非空:
|
||||||
|
// - IssuerURL / ClientID / RedirectURL / PostLoginURL
|
||||||
|
// - UserSession.Secret(cookie HMAC 簽章)
|
||||||
|
//
|
||||||
|
// ClientSecret 為**選填**(A1, 2026-05-01):
|
||||||
|
// - 有值 → confidential client mode(標準 OAuth + PKCE 雙保險)
|
||||||
|
// - 留空 → PKCE-only public client mode(依靠 PKCE 防 code interception)
|
||||||
|
//
|
||||||
|
// 兩種 mode 由 IdP 決定,visionA 都支援(見 ADR-013、oidc-tdd.md §13.1)。
|
||||||
|
//
|
||||||
|
// ServiceClientID / ServiceClientSecret 為 client_credentials grant 預留欄位,
|
||||||
|
// A1 階段不啟用、不檢查;之後若接服務間 API 呼叫再補 Validate。
|
||||||
|
//
|
||||||
|
// 缺任何**必填**項 → 回 *MissingEnvError,main.go 啟動時 fatal log 退出。
|
||||||
|
// 維持單一 error 而非列表 — caller 只是 fail-fast 紀錄,不需要結構化處理。
|
||||||
|
func (c *Config) Validate() error {
|
||||||
|
missing := make([]string, 0, 5)
|
||||||
|
if c.OIDC.IssuerURL == "" {
|
||||||
|
missing = append(missing, "VISIONA_OIDC_ISSUER_URL")
|
||||||
|
}
|
||||||
|
if c.OIDC.ClientID == "" {
|
||||||
|
missing = append(missing, "VISIONA_OIDC_CLIENT_ID")
|
||||||
|
}
|
||||||
|
// ClientSecret 為選填(public PKCE-only client 留空)— 不檢查。
|
||||||
|
if c.OIDC.RedirectURL == "" {
|
||||||
|
missing = append(missing, "VISIONA_OIDC_REDIRECT_URL")
|
||||||
|
}
|
||||||
|
if c.OIDC.PostLoginURL == "" {
|
||||||
|
missing = append(missing, "VISIONA_FRONTEND_URL")
|
||||||
|
}
|
||||||
|
if c.UserSession.Secret == "" {
|
||||||
|
missing = append(missing, "VISIONA_SESSION_SECRET")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(missing) > 0 {
|
||||||
|
return &MissingEnvError{Vars: missing}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MissingEnvError 表示 OIDC 必填環境變數缺少(OB5 起永遠檢查)。
|
||||||
|
type MissingEnvError struct {
|
||||||
|
Vars []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *MissingEnvError) Error() string {
|
||||||
|
return "config: OIDC enabled but required env vars are missing: " + joinStrings(e.Vars, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// joinStrings 是 strings.Join 的本地版本,避免單純為了 join 引入 strings package。
|
||||||
|
func joinStrings(parts []string, sep string) string {
|
||||||
|
switch len(parts) {
|
||||||
|
case 0:
|
||||||
|
return ""
|
||||||
|
case 1:
|
||||||
|
return parts[0]
|
||||||
|
}
|
||||||
|
n := len(sep) * (len(parts) - 1)
|
||||||
|
for _, p := range parts {
|
||||||
|
n += len(p)
|
||||||
|
}
|
||||||
|
out := make([]byte, 0, n)
|
||||||
|
out = append(out, parts[0]...)
|
||||||
|
for _, p := range parts[1:] {
|
||||||
|
out = append(out, sep...)
|
||||||
|
out = append(out, p...)
|
||||||
|
}
|
||||||
|
return string(out)
|
||||||
|
}
|
||||||
131
visionA-backend/internal/config/load.go
Normal file
131
visionA-backend/internal/config/load.go
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Load 從環境變數讀取並組出一個 Config。
|
||||||
|
//
|
||||||
|
// 所有欄位皆有預設值(雛形便利),因此 Load 不會回傳 error;
|
||||||
|
// 未來加入必填欄位時(例如 Phase 1 的 DB URL),應改為回傳 error。
|
||||||
|
func Load() *Config {
|
||||||
|
return &Config{
|
||||||
|
Server: ServerConfig{
|
||||||
|
Host: getEnvString("VISIONA_HOST", "0.0.0.0"),
|
||||||
|
Port: getEnvInt("VISIONA_API_PORT", 3721),
|
||||||
|
TunnelPort: getEnvInt("VISIONA_TUNNEL_PORT", 3800),
|
||||||
|
InternalPort: getEnvInt("VISIONA_PROXY_INTERNAL_PORT", 3801),
|
||||||
|
RelayPublicURL: getEnvString("VISIONA_RELAY_PUBLIC_URL", ""),
|
||||||
|
SeedDemoData: getEnvBool("VISIONA_SEED_DEMO_DATA", false),
|
||||||
|
},
|
||||||
|
Session: SessionConfig{
|
||||||
|
Backend: getEnvString("VISIONA_SESSION_BACKEND", "inmemory"),
|
||||||
|
ProxyInternalURL: getEnvString("VISIONA_PROXY_INTERNAL_URL", "http://localhost:3801"),
|
||||||
|
},
|
||||||
|
Auth: AuthConfig{
|
||||||
|
// Phase 0.7 security fix C1:VISIONA_STATIC_USER_ID 僅供 dev seed / unit test 用,
|
||||||
|
// stage/prod 留空無影響;不再注入 api.Deps(見 internal/api/api.go Deps 註解)。
|
||||||
|
StaticUserID: getEnvString("VISIONA_STATIC_USER_ID", "demo-user"),
|
||||||
|
PairingToken: getEnvString("VISIONA_PAIRING_TOKEN", ""),
|
||||||
|
SigningSecret: getEnvString("VISIONA_STORAGE_SIGNING_SECRET", "dev-signing-secret-do-not-use-in-prod"),
|
||||||
|
},
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
IssuerURL: getEnvString("VISIONA_OIDC_ISSUER_URL", ""),
|
||||||
|
ClientID: getEnvString("VISIONA_OIDC_CLIENT_ID", ""),
|
||||||
|
ClientSecret: getEnvString("VISIONA_OIDC_CLIENT_SECRET", ""),
|
||||||
|
RedirectURL: getEnvString("VISIONA_OIDC_REDIRECT_URL", ""),
|
||||||
|
PostLoginURL: getEnvString("VISIONA_FRONTEND_URL", ""),
|
||||||
|
// A1:client_credentials grant 預留欄位,留空表「不啟用 service client」。
|
||||||
|
ServiceClientID: getEnvString("VISIONA_OIDC_SERVICE_CLIENT_ID", ""),
|
||||||
|
ServiceClientSecret: getEnvString("VISIONA_OIDC_SERVICE_CLIENT_SECRET", ""),
|
||||||
|
},
|
||||||
|
UserSession: UserSessionConfig{
|
||||||
|
Secret: getEnvString("VISIONA_SESSION_SECRET", ""),
|
||||||
|
CookieName: getEnvString("VISIONA_SESSION_COOKIE_NAME", "visiona_session"),
|
||||||
|
CookieDomain: getEnvString("VISIONA_SESSION_COOKIE_DOMAIN", ""),
|
||||||
|
CookieSecure: getEnvBool("VISIONA_SESSION_COOKIE_SECURE", false),
|
||||||
|
AbsoluteTTL: getEnvDuration("VISIONA_SESSION_ABSOLUTE_TTL", 168*time.Hour),
|
||||||
|
IdleTTL: getEnvDuration("VISIONA_SESSION_IDLE_TTL", 24*time.Hour),
|
||||||
|
},
|
||||||
|
Storage: StorageConfig{
|
||||||
|
Backend: getEnvString("VISIONA_STORAGE_BACKEND", "localfs"),
|
||||||
|
RootDir: getEnvString("VISIONA_STORAGE_LOCALFS_ROOT", "./data/storage"),
|
||||||
|
BaseURL: getEnvString("VISIONA_STORAGE_LOCALFS_BASE_URL", "http://localhost:3721/storage"),
|
||||||
|
},
|
||||||
|
Model: ModelConfig{
|
||||||
|
MaxSizeMB: getEnvInt("VISIONA_MODEL_MAX_SIZE_MB", 100),
|
||||||
|
},
|
||||||
|
Tunnel: TunnelConfig{
|
||||||
|
HeartbeatInterval: getEnvDuration("VISIONA_TUNNEL_HEARTBEAT_INTERVAL", 10*time.Second),
|
||||||
|
IdleTimeout: getEnvDuration("VISIONA_TUNNEL_IDLE_TIMEOUT", 30*time.Second),
|
||||||
|
},
|
||||||
|
Logger: LoggerConfig{
|
||||||
|
Level: getEnvString("VISIONA_LOG_LEVEL", "info"),
|
||||||
|
},
|
||||||
|
CORS: CORSConfig{
|
||||||
|
AllowedOrigins: getEnvStringSlice("VISIONA_CORS_ALLOWED_ORIGINS", nil),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvStringSlice 從環境變數取逗號分隔字串,拆成 slice。
|
||||||
|
// 每段都會 TrimSpace;空段會被過濾。若環境變數未設定或為空,回傳 fallback。
|
||||||
|
func getEnvStringSlice(key string, fallback []string) []string {
|
||||||
|
v, ok := os.LookupEnv(key)
|
||||||
|
if !ok || v == "" {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
parts := strings.Split(v, ",")
|
||||||
|
result := make([]string, 0, len(parts))
|
||||||
|
for _, p := range parts {
|
||||||
|
if trimmed := strings.TrimSpace(p); trimmed != "" {
|
||||||
|
result = append(result, trimmed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(result) == 0 {
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvString 從環境變數取字串,不存在或為空則回傳預設值。
|
||||||
|
func getEnvString(key, fallback string) string {
|
||||||
|
if v, ok := os.LookupEnv(key); ok && v != "" {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvInt 從環境變數取整數,若無法解析則回傳預設值。
|
||||||
|
func getEnvInt(key string, fallback int) int {
|
||||||
|
if v, ok := os.LookupEnv(key); ok && v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvDuration 從環境變數取 time.Duration(支援 "10s"、"1m" 等格式)。
|
||||||
|
func getEnvDuration(key string, fallback time.Duration) time.Duration {
|
||||||
|
if v, ok := os.LookupEnv(key); ok && v != "" {
|
||||||
|
if d, err := time.ParseDuration(v); err == nil {
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnvBool 從環境變數取布林值(接受 "true"/"false"/"1"/"0",大小寫不敏感)。
|
||||||
|
// 解析失敗或未設定回傳 fallback。
|
||||||
|
func getEnvBool(key string, fallback bool) bool {
|
||||||
|
if v, ok := os.LookupEnv(key); ok && v != "" {
|
||||||
|
if b, err := strconv.ParseBool(v); err == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
267
visionA-backend/internal/config/load_test.go
Normal file
267
visionA-backend/internal/config/load_test.go
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLoad_Defaults(t *testing.T) {
|
||||||
|
// Arrange:清掉所有相關 env(t.Setenv 自動還原)
|
||||||
|
for _, k := range []string{
|
||||||
|
"VISIONA_HOST", "VISIONA_API_PORT", "VISIONA_TUNNEL_PORT", "VISIONA_PROXY_INTERNAL_PORT",
|
||||||
|
"VISIONA_SESSION_BACKEND", "VISIONA_PROXY_INTERNAL_URL",
|
||||||
|
"VISIONA_AUTH_TYPE", "VISIONA_STATIC_USER_ID", "VISIONA_PAIRING_TOKEN",
|
||||||
|
"VISIONA_STORAGE_BACKEND", "VISIONA_STORAGE_LOCALFS_ROOT",
|
||||||
|
"VISIONA_MODEL_MAX_SIZE_MB",
|
||||||
|
"VISIONA_TUNNEL_HEARTBEAT_INTERVAL", "VISIONA_TUNNEL_IDLE_TIMEOUT",
|
||||||
|
"VISIONA_LOG_LEVEL",
|
||||||
|
} {
|
||||||
|
t.Setenv(k, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Act
|
||||||
|
cfg := Load()
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
assert.Equal(t, "0.0.0.0", cfg.Server.Host)
|
||||||
|
// Port 預設改為 3721 — 對齊 local-tool(B4)。
|
||||||
|
assert.Equal(t, 3721, cfg.Server.Port)
|
||||||
|
assert.Equal(t, 3800, cfg.Server.TunnelPort)
|
||||||
|
assert.Equal(t, 3801, cfg.Server.InternalPort)
|
||||||
|
assert.False(t, cfg.Server.SeedDemoData, "預設不 seed demo data")
|
||||||
|
assert.Equal(t, "inmemory", cfg.Session.Backend)
|
||||||
|
assert.Equal(t, "http://localhost:3801", cfg.Session.ProxyInternalURL)
|
||||||
|
assert.Equal(t, "demo-user", cfg.Auth.StaticUserID)
|
||||||
|
assert.Equal(t, "localfs", cfg.Storage.Backend)
|
||||||
|
assert.Equal(t, "./data/storage", cfg.Storage.RootDir)
|
||||||
|
assert.Equal(t, 100, cfg.Model.MaxSizeMB)
|
||||||
|
assert.Equal(t, 10*time.Second, cfg.Tunnel.HeartbeatInterval)
|
||||||
|
assert.Equal(t, 30*time.Second, cfg.Tunnel.IdleTimeout)
|
||||||
|
assert.Equal(t, "info", cfg.Logger.Level)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad_EnvOverrides(t *testing.T) {
|
||||||
|
t.Setenv("VISIONA_API_PORT", "8080")
|
||||||
|
t.Setenv("VISIONA_STATIC_USER_ID", "custom-user")
|
||||||
|
t.Setenv("VISIONA_MODEL_MAX_SIZE_MB", "500")
|
||||||
|
t.Setenv("VISIONA_TUNNEL_HEARTBEAT_INTERVAL", "5s")
|
||||||
|
t.Setenv("VISIONA_LOG_LEVEL", "debug")
|
||||||
|
|
||||||
|
cfg := Load()
|
||||||
|
|
||||||
|
assert.Equal(t, 8080, cfg.Server.Port)
|
||||||
|
assert.Equal(t, "custom-user", cfg.Auth.StaticUserID)
|
||||||
|
assert.Equal(t, 500, cfg.Model.MaxSizeMB)
|
||||||
|
assert.Equal(t, 5*time.Second, cfg.Tunnel.HeartbeatInterval)
|
||||||
|
assert.Equal(t, "debug", cfg.Logger.Level)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoad_InvalidIntFallback(t *testing.T) {
|
||||||
|
t.Setenv("VISIONA_API_PORT", "not-a-number")
|
||||||
|
cfg := Load()
|
||||||
|
assert.Equal(t, 3721, cfg.Server.Port, "無法解析時應 fallback 到預設值(B4 改為 3721)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoad_SeedDemoData 驗證 VISIONA_SEED_DEMO_DATA env 的解析行為。
|
||||||
|
func TestLoad_SeedDemoData(t *testing.T) {
|
||||||
|
t.Setenv("VISIONA_SEED_DEMO_DATA", "true")
|
||||||
|
cfg := Load()
|
||||||
|
assert.True(t, cfg.Server.SeedDemoData)
|
||||||
|
|
||||||
|
t.Setenv("VISIONA_SEED_DEMO_DATA", "false")
|
||||||
|
cfg = Load()
|
||||||
|
assert.False(t, cfg.Server.SeedDemoData)
|
||||||
|
|
||||||
|
// 無法解析時 fallback 到預設 false
|
||||||
|
t.Setenv("VISIONA_SEED_DEMO_DATA", "not-a-bool")
|
||||||
|
cfg = Load()
|
||||||
|
assert.False(t, cfg.Server.SeedDemoData, "無法解析時應 fallback 到預設值")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoad_OIDCDefaults 驗證未設定任何 VISIONA_OIDC_* 時,OIDC 欄位為空字串。
|
||||||
|
//
|
||||||
|
// OB5 起 OIDC.Enabled 已移除(OIDC 是唯一認證路徑);空字串就是「未設定」,
|
||||||
|
// 此時 Validate() 會回 MissingEnvError,main.go 啟動時 fatal log 退出。
|
||||||
|
func TestLoad_OIDCDefaults(t *testing.T) {
|
||||||
|
for _, k := range []string{
|
||||||
|
"VISIONA_OIDC_ISSUER_URL", "VISIONA_OIDC_CLIENT_ID",
|
||||||
|
"VISIONA_OIDC_CLIENT_SECRET", "VISIONA_OIDC_REDIRECT_URL", "VISIONA_FRONTEND_URL",
|
||||||
|
"VISIONA_OIDC_SERVICE_CLIENT_ID", "VISIONA_OIDC_SERVICE_CLIENT_SECRET",
|
||||||
|
"VISIONA_SESSION_SECRET", "VISIONA_SESSION_COOKIE_NAME", "VISIONA_SESSION_COOKIE_DOMAIN",
|
||||||
|
"VISIONA_SESSION_COOKIE_SECURE", "VISIONA_SESSION_ABSOLUTE_TTL", "VISIONA_SESSION_IDLE_TTL",
|
||||||
|
} {
|
||||||
|
t.Setenv(k, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := Load()
|
||||||
|
|
||||||
|
assert.Empty(t, cfg.OIDC.IssuerURL)
|
||||||
|
assert.Empty(t, cfg.OIDC.ClientID)
|
||||||
|
assert.Empty(t, cfg.OIDC.ClientSecret)
|
||||||
|
assert.Empty(t, cfg.OIDC.RedirectURL)
|
||||||
|
assert.Empty(t, cfg.OIDC.PostLoginURL)
|
||||||
|
assert.Empty(t, cfg.OIDC.ServiceClientID, "ServiceClientID 預設留空(A1:未啟用)")
|
||||||
|
assert.Empty(t, cfg.OIDC.ServiceClientSecret, "ServiceClientSecret 預設留空(A1:未啟用)")
|
||||||
|
|
||||||
|
assert.Empty(t, cfg.UserSession.Secret, "雛形 dev 預設不附 secret,由 caller 注入或啟動失敗")
|
||||||
|
assert.Equal(t, "visiona_session", cfg.UserSession.CookieName)
|
||||||
|
assert.Empty(t, cfg.UserSession.CookieDomain)
|
||||||
|
assert.False(t, cfg.UserSession.CookieSecure)
|
||||||
|
assert.Equal(t, 168*time.Hour, cfg.UserSession.AbsoluteTTL)
|
||||||
|
assert.Equal(t, 24*time.Hour, cfg.UserSession.IdleTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoad_OIDC_ClientSecretOptional:A1(2026-05-01)— 缺 ClientSecret 不再回 MissingEnvError。
|
||||||
|
//
|
||||||
|
// 模擬 Stage 用的 public PKCE-only client(MC 給的 b8093fea... 沒有 client_secret)。
|
||||||
|
func TestLoad_OIDC_ClientSecretOptional(t *testing.T) {
|
||||||
|
t.Setenv("VISIONA_OIDC_ISSUER_URL", "https://stage-9527.innovedus.com:7850/")
|
||||||
|
t.Setenv("VISIONA_OIDC_CLIENT_ID", "b8093fea1a504a5d8f0e04bee9f78f2e")
|
||||||
|
t.Setenv("VISIONA_OIDC_CLIENT_SECRET", "") // 故意留空 — public client
|
||||||
|
t.Setenv("VISIONA_OIDC_REDIRECT_URL", "https://stage-9527.innovedus.com:9527/api/auth/callback")
|
||||||
|
t.Setenv("VISIONA_FRONTEND_URL", "https://stage-9527.innovedus.com:9527")
|
||||||
|
t.Setenv("VISIONA_SESSION_SECRET", "32-byte-or-longer-random-secret-aaaa")
|
||||||
|
|
||||||
|
cfg := Load()
|
||||||
|
assert.Empty(t, cfg.OIDC.ClientSecret, "public client mode:ClientSecret 應為空字串")
|
||||||
|
assert.NoError(t, cfg.Validate(), "ClientSecret 為空不應觸發 MissingEnvError")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoad_OIDC_ServiceClientFields:A1 預留 client_credentials grant 兩個欄位能正確讀取。
|
||||||
|
// 測試固定值故意用顯而易見的 fake — 不要貼任何環境的真實 client_id / secret 進測試。
|
||||||
|
func TestLoad_OIDC_ServiceClientFields(t *testing.T) {
|
||||||
|
const fakeServiceID = "fake-service-client-id-for-test"
|
||||||
|
const fakeServiceSecret = "fake-service-client-secret-for-test"
|
||||||
|
|
||||||
|
t.Setenv("VISIONA_OIDC_SERVICE_CLIENT_ID", fakeServiceID)
|
||||||
|
t.Setenv("VISIONA_OIDC_SERVICE_CLIENT_SECRET", fakeServiceSecret)
|
||||||
|
|
||||||
|
cfg := Load()
|
||||||
|
assert.Equal(t, fakeServiceID, cfg.OIDC.ServiceClientID)
|
||||||
|
assert.Equal(t, fakeServiceSecret, cfg.OIDC.ServiceClientSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoad_OIDCAllSet 驗證 OIDC env vars 設定後能正確讀取。
|
||||||
|
func TestLoad_OIDCAllSet(t *testing.T) {
|
||||||
|
t.Setenv("VISIONA_OIDC_ISSUER_URL", "http://localhost:5050")
|
||||||
|
t.Setenv("VISIONA_OIDC_CLIENT_ID", "visionA")
|
||||||
|
t.Setenv("VISIONA_OIDC_CLIENT_SECRET", "secret")
|
||||||
|
t.Setenv("VISIONA_OIDC_REDIRECT_URL", "http://localhost:3721/api/auth/callback")
|
||||||
|
t.Setenv("VISIONA_FRONTEND_URL", "http://localhost:3000")
|
||||||
|
t.Setenv("VISIONA_SESSION_SECRET", "32-byte-or-longer-random-secret-aaaa")
|
||||||
|
t.Setenv("VISIONA_SESSION_COOKIE_SECURE", "true")
|
||||||
|
t.Setenv("VISIONA_SESSION_ABSOLUTE_TTL", "72h")
|
||||||
|
t.Setenv("VISIONA_SESSION_IDLE_TTL", "12h")
|
||||||
|
|
||||||
|
cfg := Load()
|
||||||
|
|
||||||
|
assert.Equal(t, "http://localhost:5050", cfg.OIDC.IssuerURL)
|
||||||
|
assert.Equal(t, "visionA", cfg.OIDC.ClientID)
|
||||||
|
assert.Equal(t, "secret", cfg.OIDC.ClientSecret)
|
||||||
|
assert.Equal(t, "http://localhost:3721/api/auth/callback", cfg.OIDC.RedirectURL)
|
||||||
|
assert.Equal(t, "http://localhost:3000", cfg.OIDC.PostLoginURL)
|
||||||
|
assert.Equal(t, "32-byte-or-longer-random-secret-aaaa", cfg.UserSession.Secret)
|
||||||
|
assert.True(t, cfg.UserSession.CookieSecure)
|
||||||
|
assert.Equal(t, 72*time.Hour, cfg.UserSession.AbsoluteTTL)
|
||||||
|
assert.Equal(t, 12*time.Hour, cfg.UserSession.IdleTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestConfig_Validate_MissingFields 驗證 OIDC 必填欄位缺失時回 MissingEnvError。
|
||||||
|
//
|
||||||
|
// A1(2026-05-01):ClientSecret 改為選填,已從必填清單移除;剩 5 項必填。
|
||||||
|
func TestConfig_Validate_MissingFields(t *testing.T) {
|
||||||
|
cfg := &Config{} // 全部欄位 zero value
|
||||||
|
err := cfg.Validate()
|
||||||
|
require.Error(t, err)
|
||||||
|
|
||||||
|
var missErr *MissingEnvError
|
||||||
|
require.ErrorAs(t, err, &missErr, "錯誤型別應可被 errors.As 解出")
|
||||||
|
// 應列出 5 個必填欄位(不含 ClientSecret)
|
||||||
|
assert.ElementsMatch(t, []string{
|
||||||
|
"VISIONA_OIDC_ISSUER_URL",
|
||||||
|
"VISIONA_OIDC_CLIENT_ID",
|
||||||
|
"VISIONA_OIDC_REDIRECT_URL",
|
||||||
|
"VISIONA_FRONTEND_URL",
|
||||||
|
"VISIONA_SESSION_SECRET",
|
||||||
|
}, missErr.Vars)
|
||||||
|
assert.NotContains(t, missErr.Vars, "VISIONA_OIDC_CLIENT_SECRET",
|
||||||
|
"A1:ClientSecret 為選填,不應出現在必填缺失清單")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_ConfidentialClient:完整 confidential client(含 ClientSecret)能通過 Validate。
|
||||||
|
func TestValidate_ConfidentialClient(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
IssuerURL: "http://localhost:5050",
|
||||||
|
ClientID: "visionA",
|
||||||
|
ClientSecret: "secret", // 有值 → confidential mode
|
||||||
|
RedirectURL: "http://localhost:3721/api/auth/callback",
|
||||||
|
PostLoginURL: "http://localhost:3000",
|
||||||
|
},
|
||||||
|
UserSession: UserSessionConfig{Secret: "session-secret-32-bytes-aaaaaaaaaaaa"},
|
||||||
|
}
|
||||||
|
assert.NoError(t, cfg.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_PKCEOnlyPublicClient:A1 — 只給 ClientID 沒給 Secret 也能通過 Validate。
|
||||||
|
//
|
||||||
|
// 對應 Stage 部署的真實情境:MC 配給 visionA 的 client `b8093fea1a504a5d8f0e04bee9f78f2e`
|
||||||
|
// 是 public client,沒有 client_secret,靠 PKCE 防 code interception。
|
||||||
|
func TestValidate_PKCEOnlyPublicClient(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
IssuerURL: "https://stage-9527.innovedus.com:7850/",
|
||||||
|
ClientID: "b8093fea1a504a5d8f0e04bee9f78f2e",
|
||||||
|
// ClientSecret 留空 — public PKCE-only client
|
||||||
|
RedirectURL: "https://stage-9527.innovedus.com:9527/api/auth/callback",
|
||||||
|
PostLoginURL: "https://stage-9527.innovedus.com:9527",
|
||||||
|
},
|
||||||
|
UserSession: UserSessionConfig{Secret: "session-secret-32-bytes-aaaaaaaaaaaa"},
|
||||||
|
}
|
||||||
|
assert.NoError(t, cfg.Validate(),
|
||||||
|
"A1:public PKCE-only client(ClientSecret 留空)應通過 Validate")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestValidate_ServiceClientFieldsNotChecked:A1 — ServiceClientID/Secret 留空不影響 Validate。
|
||||||
|
//
|
||||||
|
// 兩個欄位是 client_credentials grant 預留鉤子,A1 階段不啟用、不檢查。
|
||||||
|
func TestValidate_ServiceClientFieldsNotChecked(t *testing.T) {
|
||||||
|
cfg := &Config{
|
||||||
|
OIDC: OIDCConfig{
|
||||||
|
IssuerURL: "http://localhost:5050",
|
||||||
|
ClientID: "visionA",
|
||||||
|
RedirectURL: "http://localhost:3721/api/auth/callback",
|
||||||
|
PostLoginURL: "http://localhost:3000",
|
||||||
|
// 兩個 Service* 都留空 — 預期通過
|
||||||
|
},
|
||||||
|
UserSession: UserSessionConfig{Secret: "session-secret-32-bytes-aaaaaaaaaaaa"},
|
||||||
|
}
|
||||||
|
assert.NoError(t, cfg.Validate())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestLoad_CORSAllowedOrigins 驗證 VISIONA_CORS_ALLOWED_ORIGINS 的逗號分隔解析。
|
||||||
|
// 空字串 / 純分隔字元 → fallback 到 nil(交由 api.Deps.validate 塞預設)。
|
||||||
|
func TestLoad_CORSAllowedOrigins(t *testing.T) {
|
||||||
|
// 未設 → nil
|
||||||
|
t.Setenv("VISIONA_CORS_ALLOWED_ORIGINS", "")
|
||||||
|
cfg := Load()
|
||||||
|
assert.Nil(t, cfg.CORS.AllowedOrigins)
|
||||||
|
|
||||||
|
// 單一 origin
|
||||||
|
t.Setenv("VISIONA_CORS_ALLOWED_ORIGINS", "http://localhost:3000")
|
||||||
|
cfg = Load()
|
||||||
|
assert.Equal(t, []string{"http://localhost:3000"}, cfg.CORS.AllowedOrigins)
|
||||||
|
|
||||||
|
// 多個 origin + trim space
|
||||||
|
t.Setenv("VISIONA_CORS_ALLOWED_ORIGINS", "http://a.com, http://b.com ,http://c.com")
|
||||||
|
cfg = Load()
|
||||||
|
assert.Equal(t, []string{"http://a.com", "http://b.com", "http://c.com"}, cfg.CORS.AllowedOrigins)
|
||||||
|
|
||||||
|
// 只有分隔字元 → fallback(過濾後 len == 0)
|
||||||
|
t.Setenv("VISIONA_CORS_ALLOWED_ORIGINS", " , ,")
|
||||||
|
cfg = Load()
|
||||||
|
assert.Nil(t, cfg.CORS.AllowedOrigins)
|
||||||
|
}
|
||||||
81
visionA-backend/internal/converter/converter.go
Normal file
81
visionA-backend/internal/converter/converter.go
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
// Package converter 定義與 kneron_model_converter 服務互動的 client 介面。
|
||||||
|
//
|
||||||
|
// 對齊 TDD §2.7 與 api/api-converter-contract.md。
|
||||||
|
// 雛形以 StubClient 實作所有方法回 ErrNotImplemented(部分關鍵方法可給假資料讓前端走通 UI);
|
||||||
|
// Phase 2 以 HTTPClient 實作同 interface 呼叫真實 converter。
|
||||||
|
package converter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Errors
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotImplemented 表示雛形尚未實作此方法。
|
||||||
|
ErrNotImplemented = errors.New("converter: not implemented in phase 0")
|
||||||
|
|
||||||
|
// ErrJobNotFound 表示指定 jobID 不存在。
|
||||||
|
ErrJobNotFound = errors.New("converter: job not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Domain types(對齊 database.md §2.6)
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Job 是轉檔任務的狀態快照。
|
||||||
|
type Job struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
OwnerUserID string `json:"ownerUserId"`
|
||||||
|
Status string `json:"status"` // queued / running / succeeded / failed
|
||||||
|
SourceKey string `json:"sourceKey"`
|
||||||
|
ResultKey string `json:"resultKey,omitempty"`
|
||||||
|
TargetChip string `json:"targetChip"`
|
||||||
|
Params map[string]any `json:"params,omitempty"`
|
||||||
|
|
||||||
|
ErrorCode string `json:"errorCode,omitempty"`
|
||||||
|
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
|
StartedAt *time.Time `json:"startedAt,omitempty"`
|
||||||
|
CompletedAt *time.Time `json:"completedAt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertRequest 是提交轉檔任務時的輸入參數。
|
||||||
|
type ConvertRequest struct {
|
||||||
|
OwnerUserID string `json:"ownerUserId"`
|
||||||
|
SourceKey string `json:"sourceKey"` // 已上傳到 Storage 的來源檔 key
|
||||||
|
TargetChip string `json:"targetChip"`
|
||||||
|
Params map[string]any `json:"params,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Client interface
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Client 抽象 converter 服務。
|
||||||
|
//
|
||||||
|
// 對齊 PRD interface-contracts.md §8.5 與 api-converter-contract.md。
|
||||||
|
type Client interface {
|
||||||
|
// SubmitConvert 提交一個新的轉檔任務;回傳 jobID。
|
||||||
|
SubmitConvert(ctx context.Context, req *ConvertRequest) (jobID string, err error)
|
||||||
|
|
||||||
|
// GetJob 查詢任務狀態;不存在回 ErrJobNotFound。
|
||||||
|
GetJob(ctx context.Context, jobID string) (*Job, error)
|
||||||
|
|
||||||
|
// ListJobs 列出使用者的所有轉檔任務。
|
||||||
|
ListJobs(ctx context.Context, userID string) ([]*Job, error)
|
||||||
|
|
||||||
|
// DownloadResult 下載任務產物(.nef)。
|
||||||
|
// 未完成或失敗時回錯;caller 必須 Close reader。
|
||||||
|
DownloadResult(ctx context.Context, jobID string) (io.ReadCloser, error)
|
||||||
|
|
||||||
|
// CancelJob 取消任務。
|
||||||
|
CancelJob(ctx context.Context, jobID string) error
|
||||||
|
}
|
||||||
45
visionA-backend/internal/converter/stub.go
Normal file
45
visionA-backend/internal/converter/stub.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package converter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StubClient 是 Phase 0 的 converter Client stub。
|
||||||
|
//
|
||||||
|
// 所有方法回 ErrNotImplemented,用於讓 DI 流程能編譯運行,
|
||||||
|
// 但前端若真的呼叫到 converter API,會收到 501 / 明確錯誤訊息。
|
||||||
|
//
|
||||||
|
// 未來若需要假資料讓前端 UI 流程走通(PRD §8.5 建議),
|
||||||
|
// 可擴充為 FakeClient(產 fake job_id、模擬 queued → processing → completed)。
|
||||||
|
type StubClient struct{}
|
||||||
|
|
||||||
|
// NewStubClient 建立一個 StubClient 實例。
|
||||||
|
func NewStubClient() *StubClient {
|
||||||
|
return &StubClient{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SubmitConvert 回 ErrNotImplemented。
|
||||||
|
func (s *StubClient) SubmitConvert(ctx context.Context, req *ConvertRequest) (string, error) {
|
||||||
|
return "", ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJob 回 ErrNotImplemented。
|
||||||
|
func (s *StubClient) GetJob(ctx context.Context, jobID string) (*Job, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListJobs 回 ErrNotImplemented。
|
||||||
|
func (s *StubClient) ListJobs(ctx context.Context, userID string) ([]*Job, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownloadResult 回 ErrNotImplemented。
|
||||||
|
func (s *StubClient) DownloadResult(ctx context.Context, jobID string) (io.ReadCloser, error) {
|
||||||
|
return nil, ErrNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelJob 回 ErrNotImplemented。
|
||||||
|
func (s *StubClient) CancelJob(ctx context.Context, jobID string) error {
|
||||||
|
return ErrNotImplemented
|
||||||
|
}
|
||||||
30
visionA-backend/internal/converter/stub_test.go
Normal file
30
visionA-backend/internal/converter/stub_test.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package converter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStubClient_AllMethodsReturnNotImplemented(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
c := NewStubClient()
|
||||||
|
|
||||||
|
_, err := c.SubmitConvert(ctx, &ConvertRequest{})
|
||||||
|
assert.ErrorIs(t, err, ErrNotImplemented)
|
||||||
|
|
||||||
|
_, err = c.GetJob(ctx, "any")
|
||||||
|
assert.ErrorIs(t, err, ErrNotImplemented)
|
||||||
|
|
||||||
|
_, err = c.ListJobs(ctx, "user")
|
||||||
|
assert.ErrorIs(t, err, ErrNotImplemented)
|
||||||
|
|
||||||
|
_, err = c.DownloadResult(ctx, "job")
|
||||||
|
assert.ErrorIs(t, err, ErrNotImplemented)
|
||||||
|
|
||||||
|
assert.ErrorIs(t, c.CancelJob(ctx, "job"), ErrNotImplemented)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 確保 StubClient 滿足 Client interface(編譯時檢查)。
|
||||||
|
var _ Client = (*StubClient)(nil)
|
||||||
213
visionA-backend/internal/device/device.go
Normal file
213
visionA-backend/internal/device/device.go
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
// Package device 定義 Device domain model 與 Repository 介面。
|
||||||
|
//
|
||||||
|
// 對齊 database.md §2.2(雙狀態模型 — Minor-3)與 §3(Repository interface)。
|
||||||
|
// 雛形以 InMemoryRepository 實作;Phase 1 新增 PostgresRepository 取代。
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Errors
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotFound 表示指定 ID 的 Device 不存在。
|
||||||
|
ErrNotFound = errors.New("device: not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Remote / USB 狀態常數(對齊 database.md §2.2)
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// RemoteStatus 是雲端對 tunnel 連線的觀察值。
|
||||||
|
type RemoteStatus = string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RemoteStatusOnline 表示 tunnel 有效、雲端可達。
|
||||||
|
RemoteStatusOnline RemoteStatus = "online"
|
||||||
|
// RemoteStatusOffline 表示 tunnel 斷線或從未連上。
|
||||||
|
RemoteStatusOffline RemoteStatus = "offline"
|
||||||
|
// RemoteStatusReconnecting 表示 tunnel 短暫斷線、local agent 重連中。
|
||||||
|
RemoteStatusReconnecting RemoteStatus = "reconnecting"
|
||||||
|
// RemoteStatusError 表示 tunnel 發生未預期錯誤(yamux 異常等)。
|
||||||
|
RemoteStatusError RemoteStatus = "error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// USBStatus 是 local agent 從 Kneron SDK 讀到的 USB 狀態。
|
||||||
|
type USBStatus = string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// USBStatusOnline USB 插著且可用。
|
||||||
|
USBStatusOnline USBStatus = "online"
|
||||||
|
// USBStatusOffline USB 拔掉了。
|
||||||
|
USBStatusOffline USBStatus = "offline"
|
||||||
|
// USBStatusUnknown 尚未回報 / 初始狀態。
|
||||||
|
USBStatusUnknown USBStatus = "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Device struct
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Device 對應 database.md §2.2 的 Device 實體。
|
||||||
|
//
|
||||||
|
// 雙狀態說明(Minor-3):
|
||||||
|
// - Status(USB-level):local agent 觀察到的 USB 連接狀態
|
||||||
|
// - RemoteStatus(tunnel-level):雲端觀察到的 tunnel 連線狀態
|
||||||
|
//
|
||||||
|
// 前端優先顯示 RemoteStatus,次要顯示 Status(見 TDD §10.5.1)。
|
||||||
|
type Device struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
OwnerUserID string `json:"ownerUserId"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
DeviceType string `json:"deviceType"`
|
||||||
|
SerialNumber string `json:"serialNumber,omitempty"`
|
||||||
|
|
||||||
|
// tunnel-level 狀態
|
||||||
|
RemoteStatus RemoteStatus `json:"remoteStatus"`
|
||||||
|
LastSeenAt *time.Time `json:"lastSeenAt,omitempty"`
|
||||||
|
LastConnectedAt *time.Time `json:"lastConnectedAt,omitempty"`
|
||||||
|
|
||||||
|
// USB-level 狀態
|
||||||
|
Status USBStatus `json:"status"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
|
PairedAt *time.Time `json:"pairedAt,omitempty"`
|
||||||
|
DeletedAt *time.Time `json:"deletedAt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Repository interface
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Repository 是 Device 持久層介面。
|
||||||
|
//
|
||||||
|
// 所有查詢方法**必須略過 DeletedAt != nil 的紀錄**(soft delete)。
|
||||||
|
// Phase 1 的 PostgresRepository 會加上 `WHERE deleted_at IS NULL`。
|
||||||
|
type Repository interface {
|
||||||
|
// Get 取得單一 device;不存在或已軟刪除回 ErrNotFound。
|
||||||
|
Get(ctx context.Context, id string) (*Device, error)
|
||||||
|
|
||||||
|
// GetBySerial 以 (ownerUserID, serialNumber) 查詢(避免同 user 重複註冊同 serial)。
|
||||||
|
GetBySerial(ctx context.Context, ownerUserID, serial string) (*Device, error)
|
||||||
|
|
||||||
|
// List 列出某 user 的所有(未刪除)device。
|
||||||
|
List(ctx context.Context, ownerUserID string) ([]*Device, error)
|
||||||
|
|
||||||
|
// Save 新增或更新一筆 device(upsert 語意,by ID)。
|
||||||
|
// 實作應更新 UpdatedAt;若為新建則同時設定 CreatedAt。
|
||||||
|
Save(ctx context.Context, d *Device) error
|
||||||
|
|
||||||
|
// Delete 標記為軟刪除(設定 DeletedAt)。
|
||||||
|
Delete(ctx context.Context, id string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// InMemoryRepository
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// InMemoryRepository 是 Phase 0 雛形的記憶體實作。
|
||||||
|
type InMemoryRepository struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
devices map[string]*Device
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInMemoryRepository 建立一個空的記憶體 Repository。
|
||||||
|
func NewInMemoryRepository() *InMemoryRepository {
|
||||||
|
return &InMemoryRepository{
|
||||||
|
devices: make(map[string]*Device),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 取得單一 device。
|
||||||
|
func (r *InMemoryRepository) Get(ctx context.Context, id string) (*Device, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
d, ok := r.devices[id]
|
||||||
|
if !ok || d.DeletedAt != nil {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
cp := *d
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBySerial 以 (owner, serial) 查詢。
|
||||||
|
func (r *InMemoryRepository) GetBySerial(ctx context.Context, ownerUserID, serial string) (*Device, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
for _, d := range r.devices {
|
||||||
|
if d.DeletedAt != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if d.OwnerUserID == ownerUserID && d.SerialNumber == serial {
|
||||||
|
cp := *d
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 列出某 user 的所有未刪除 device。
|
||||||
|
func (r *InMemoryRepository) List(ctx context.Context, ownerUserID string) ([]*Device, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
out := make([]*Device, 0)
|
||||||
|
for _, d := range r.devices {
|
||||||
|
if d.DeletedAt != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if d.OwnerUserID == ownerUserID {
|
||||||
|
cp := *d
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save 新增或更新 device(upsert by ID)。
|
||||||
|
func (r *InMemoryRepository) Save(ctx context.Context, d *Device) error {
|
||||||
|
if d == nil || d.ID == "" {
|
||||||
|
return errors.New("device: Save requires non-nil device with ID")
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
// Copy 避免外部後續修改影響 store
|
||||||
|
cp := *d
|
||||||
|
if existing, ok := r.devices[d.ID]; ok && existing.DeletedAt == nil {
|
||||||
|
cp.CreatedAt = existing.CreatedAt // 保留原始 CreatedAt
|
||||||
|
} else if cp.CreatedAt.IsZero() {
|
||||||
|
cp.CreatedAt = now
|
||||||
|
}
|
||||||
|
cp.UpdatedAt = now
|
||||||
|
r.devices[d.ID] = &cp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 標記 device 為軟刪除。
|
||||||
|
func (r *InMemoryRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
d, ok := r.devices[id]
|
||||||
|
if !ok || d.DeletedAt != nil {
|
||||||
|
return ErrNotFound
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
d.DeletedAt = &now
|
||||||
|
d.UpdatedAt = now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 InMemoryRepository 實作 Repository。
|
||||||
|
var _ Repository = (*InMemoryRepository)(nil)
|
||||||
120
visionA-backend/internal/device/inmemory_repository_test.go
Normal file
120
visionA-backend/internal/device/inmemory_repository_test.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInMemoryRepository_SaveAndGet(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
d := &Device{
|
||||||
|
ID: "dev-1",
|
||||||
|
OwnerUserID: "user-1",
|
||||||
|
Name: "Lab KL520",
|
||||||
|
DeviceType: "kl520",
|
||||||
|
SerialNumber: "KL520-AAA",
|
||||||
|
RemoteStatus: RemoteStatusOffline,
|
||||||
|
Status: USBStatusUnknown,
|
||||||
|
}
|
||||||
|
require.NoError(t, r.Save(ctx, d))
|
||||||
|
|
||||||
|
got, err := r.Get(ctx, "dev-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Lab KL520", got.Name)
|
||||||
|
assert.False(t, got.CreatedAt.IsZero())
|
||||||
|
assert.False(t, got.UpdatedAt.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_Get_NotFound(t *testing.T) {
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
_, err := r.Get(context.Background(), "nope")
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_Save_RequiresID(t *testing.T) {
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
err := r.Save(context.Background(), &Device{Name: "no-id"})
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_GetBySerial(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
require.NoError(t, r.Save(ctx, &Device{
|
||||||
|
ID: "dev-1", OwnerUserID: "user-A", SerialNumber: "S-1",
|
||||||
|
}))
|
||||||
|
require.NoError(t, r.Save(ctx, &Device{
|
||||||
|
ID: "dev-2", OwnerUserID: "user-B", SerialNumber: "S-1",
|
||||||
|
}))
|
||||||
|
|
||||||
|
got, err := r.GetBySerial(ctx, "user-A", "S-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "dev-1", got.ID)
|
||||||
|
|
||||||
|
_, err = r.GetBySerial(ctx, "user-C", "S-1")
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_List_ByOwner(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
require.NoError(t, r.Save(ctx, &Device{ID: "a", OwnerUserID: "u1"}))
|
||||||
|
require.NoError(t, r.Save(ctx, &Device{ID: "b", OwnerUserID: "u1"}))
|
||||||
|
require.NoError(t, r.Save(ctx, &Device{ID: "c", OwnerUserID: "u2"}))
|
||||||
|
|
||||||
|
listU1, err := r.List(ctx, "u1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, listU1, 2)
|
||||||
|
|
||||||
|
listU3, err := r.List(ctx, "u3")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, listU3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_Delete_SoftDelete(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
require.NoError(t, r.Save(ctx, &Device{ID: "dev-1", OwnerUserID: "u"}))
|
||||||
|
|
||||||
|
require.NoError(t, r.Delete(ctx, "dev-1"))
|
||||||
|
|
||||||
|
// Get 應該找不到
|
||||||
|
_, err := r.Get(ctx, "dev-1")
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
|
||||||
|
// List 也不該列出
|
||||||
|
list, _ := r.List(ctx, "u")
|
||||||
|
assert.Empty(t, list)
|
||||||
|
|
||||||
|
// 再次 Delete 應回 ErrNotFound(已軟刪除)
|
||||||
|
assert.ErrorIs(t, r.Delete(ctx, "dev-1"), ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_Save_PreservesCreatedAt(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
require.NoError(t, r.Save(ctx, &Device{ID: "dev-1", OwnerUserID: "u"}))
|
||||||
|
first, err := r.Get(ctx, "dev-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
createdAt := first.CreatedAt
|
||||||
|
|
||||||
|
// 更新(應保留 CreatedAt)
|
||||||
|
updated := *first
|
||||||
|
updated.Name = "Updated"
|
||||||
|
require.NoError(t, r.Save(ctx, &updated))
|
||||||
|
|
||||||
|
got, err := r.Get(ctx, "dev-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "Updated", got.Name)
|
||||||
|
assert.Equal(t, createdAt, got.CreatedAt, "CreatedAt 應保留原值")
|
||||||
|
assert.True(t, got.UpdatedAt.After(createdAt) || got.UpdatedAt.Equal(createdAt))
|
||||||
|
}
|
||||||
39
visionA-backend/internal/logger/logger.go
Normal file
39
visionA-backend/internal/logger/logger.go
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
// Package logger 提供最小化的結構化 JSON logger,建構於 Go 1.21+ 的 log/slog。
|
||||||
|
//
|
||||||
|
// 設計原則:
|
||||||
|
// - 所有日誌為 JSON line,便於雲端 log aggregator 解析(CloudWatch / Loki / Datadog)。
|
||||||
|
// - 不包太多層 — 直接回傳 *slog.Logger,由呼叫端自由使用 slog 的 API。
|
||||||
|
// - 輸出到 stdout(12-Factor App 第 11 條:logs as event streams)。
|
||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New 建立一個輸出為 JSON 的結構化 logger。
|
||||||
|
//
|
||||||
|
// level 接受 "debug" / "info" / "warn" / "error"(大小寫不敏感);
|
||||||
|
// 無法解析時預設為 info。
|
||||||
|
func New(level string) *slog.Logger {
|
||||||
|
handler := slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{
|
||||||
|
Level: parseLevel(level),
|
||||||
|
AddSource: false, // 需要時再開;預設關閉以降低額外成本
|
||||||
|
})
|
||||||
|
return slog.New(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseLevel 將字串解析為 slog.Level,無法解析時回傳 LevelInfo。
|
||||||
|
func parseLevel(s string) slog.Level {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(s)) {
|
||||||
|
case "debug":
|
||||||
|
return slog.LevelDebug
|
||||||
|
case "warn", "warning":
|
||||||
|
return slog.LevelWarn
|
||||||
|
case "error", "err":
|
||||||
|
return slog.LevelError
|
||||||
|
default:
|
||||||
|
return slog.LevelInfo
|
||||||
|
}
|
||||||
|
}
|
||||||
35
visionA-backend/internal/logger/logger_test.go
Normal file
35
visionA-backend/internal/logger/logger_test.go
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNew_ReturnsNonNil(t *testing.T) {
|
||||||
|
l := New("info")
|
||||||
|
assert.NotNil(t, l)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseLevel(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
want slog.Level
|
||||||
|
}{
|
||||||
|
{"debug", slog.LevelDebug},
|
||||||
|
{"DEBUG", slog.LevelDebug},
|
||||||
|
{"info", slog.LevelInfo},
|
||||||
|
{"warn", slog.LevelWarn},
|
||||||
|
{"warning", slog.LevelWarn},
|
||||||
|
{"error", slog.LevelError},
|
||||||
|
{"err", slog.LevelError},
|
||||||
|
{"", slog.LevelInfo},
|
||||||
|
{"invalid", slog.LevelInfo},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.in, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tc.want, parseLevel(tc.in))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
105
visionA-backend/internal/model/inmemory_repository_test.go
Normal file
105
visionA-backend/internal/model/inmemory_repository_test.go
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInMemoryRepository_SaveAndGet(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
m := &Model{
|
||||||
|
ID: "m-1",
|
||||||
|
OwnerUserID: "user-1",
|
||||||
|
Name: "yolo-v5",
|
||||||
|
StorageKey: "models/user-1/m-1.nef",
|
||||||
|
FileSize: 1024 * 1024,
|
||||||
|
Source: SourceUploaded,
|
||||||
|
TargetChip: "kl520",
|
||||||
|
}
|
||||||
|
require.NoError(t, r.Save(ctx, m))
|
||||||
|
|
||||||
|
got, err := r.Get(ctx, "m-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "yolo-v5", got.Name)
|
||||||
|
assert.False(t, got.CreatedAt.IsZero())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_Get_NotFound(t *testing.T) {
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
_, err := r.Get(context.Background(), "nope")
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_List_Filter(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
require.NoError(t, r.Save(ctx, &Model{ID: "1", OwnerUserID: "u1", TargetChip: "kl520", Source: SourceUploaded}))
|
||||||
|
require.NoError(t, r.Save(ctx, &Model{ID: "2", OwnerUserID: "u1", TargetChip: "kl720", Source: SourceConverted}))
|
||||||
|
require.NoError(t, r.Save(ctx, &Model{ID: "3", OwnerUserID: "u2", TargetChip: "kl520", Source: SourceUploaded}))
|
||||||
|
|
||||||
|
// 依 owner 過濾
|
||||||
|
list, err := r.List(ctx, ListFilter{OwnerUserID: "u1"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, list, 2)
|
||||||
|
|
||||||
|
// 依 chip 過濾
|
||||||
|
list, err = r.List(ctx, ListFilter{OwnerUserID: "u1", TargetChip: "kl520"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, list, 1)
|
||||||
|
assert.Equal(t, "1", list[0].ID)
|
||||||
|
|
||||||
|
// 依 source 過濾
|
||||||
|
list, err = r.List(ctx, ListFilter{Source: SourceConverted})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, list, 1)
|
||||||
|
|
||||||
|
// 無 owner 過濾(admin 用)
|
||||||
|
list, err = r.List(ctx, ListFilter{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, list, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_Delete_SoftDelete(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
|
||||||
|
require.NoError(t, r.Save(ctx, &Model{ID: "1", OwnerUserID: "u"}))
|
||||||
|
require.NoError(t, r.Delete(ctx, "1"))
|
||||||
|
|
||||||
|
_, err := r.Get(ctx, "1")
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
|
||||||
|
list, _ := r.List(ctx, ListFilter{OwnerUserID: "u"})
|
||||||
|
assert.Empty(t, list)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryRepository_Save_RequiresID(t *testing.T) {
|
||||||
|
r := NewInMemoryRepository()
|
||||||
|
assert.Error(t, r.Save(context.Background(), &Model{Name: "no-id"}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSizeValidator_Check(t *testing.T) {
|
||||||
|
v := NewSizeValidator(100) // 100 MB
|
||||||
|
|
||||||
|
// 剛好在 limit 以內
|
||||||
|
assert.NoError(t, v.Check(100*1024*1024))
|
||||||
|
assert.NoError(t, v.Check(50*1024*1024))
|
||||||
|
|
||||||
|
// 超過
|
||||||
|
err := v.Check(101 * 1024 * 1024)
|
||||||
|
assert.ErrorIs(t, err, ErrFileTooLarge)
|
||||||
|
|
||||||
|
// 大小為 0 不是錯誤(Repository 不管)
|
||||||
|
assert.NoError(t, v.Check(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSizeValidator_NoLimit(t *testing.T) {
|
||||||
|
v := NewSizeValidator(0) // 0 or negative 視為無限制
|
||||||
|
assert.NoError(t, v.Check(1024*1024*1024*10)) // 10 GB
|
||||||
|
}
|
||||||
223
visionA-backend/internal/model/model.go
Normal file
223
visionA-backend/internal/model/model.go
Normal file
@ -0,0 +1,223 @@
|
|||||||
|
// Package model 定義 Model domain(KL 推論模型檔)與 Repository 介面。
|
||||||
|
//
|
||||||
|
// 對齊 database.md §2.3。雛形以 InMemoryRepository 實作;
|
||||||
|
// Phase 1 以 PostgresRepository 取代(同 interface)。
|
||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Errors
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotFound 表示指定 ID 的 Model 不存在。
|
||||||
|
ErrNotFound = errors.New("model: not found")
|
||||||
|
|
||||||
|
// ErrFileTooLarge 表示上傳檔案超過配置的大小上限(MB)。
|
||||||
|
// 由 service 層檢查並回傳;Repository 層本身不驗。
|
||||||
|
ErrFileTooLarge = errors.New("model: file too large")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Source 常數
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Source 描述 Model 的來源。
|
||||||
|
type Source = string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SourceUploaded 使用者直接上傳。
|
||||||
|
SourceUploaded Source = "uploaded"
|
||||||
|
// SourceConverted 透過 converter 產生。
|
||||||
|
SourceConverted Source = "converted"
|
||||||
|
// SourcePreset 系統預設模型。
|
||||||
|
SourcePreset Source = "preset"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Model struct(對齊 database.md §2.3)
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Model 是 KL 推論用的模型檔(通常 .nef 格式)。
|
||||||
|
type Model struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
OwnerUserID string `json:"ownerUserId"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Description string `json:"description,omitempty"`
|
||||||
|
|
||||||
|
// 檔案資訊
|
||||||
|
StorageKey string `json:"storageKey"`
|
||||||
|
FileSize int64 `json:"fileSize"`
|
||||||
|
FileChecksum string `json:"fileChecksum,omitempty"` // sha256 hex
|
||||||
|
|
||||||
|
// 模型 metadata(可選)
|
||||||
|
TargetChip string `json:"targetChip,omitempty"`
|
||||||
|
InputShape []int `json:"inputShape,omitempty"`
|
||||||
|
Classes []string `json:"classes,omitempty"`
|
||||||
|
Framework string `json:"framework,omitempty"`
|
||||||
|
|
||||||
|
// 來源
|
||||||
|
Source Source `json:"source"`
|
||||||
|
SourceJobID string `json:"sourceJobId,omitempty"`
|
||||||
|
|
||||||
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
UpdatedAt time.Time `json:"updatedAt"`
|
||||||
|
UploadedAt *time.Time `json:"uploadedAt,omitempty"`
|
||||||
|
DeletedAt *time.Time `json:"deletedAt,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Filter / Repository
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// ListFilter 提供 List 方法的可選篩選條件。
|
||||||
|
type ListFilter struct {
|
||||||
|
OwnerUserID string // 必填於一般業務查詢;空字串表示不過濾(僅供管理用)
|
||||||
|
TargetChip string // 可選
|
||||||
|
Source Source // 可選
|
||||||
|
}
|
||||||
|
|
||||||
|
// Repository 是 Model 持久層介面。
|
||||||
|
//
|
||||||
|
// 所有查詢必須略過 DeletedAt != nil 的紀錄。
|
||||||
|
type Repository interface {
|
||||||
|
// Get 取得單一 Model;不存在或已刪除回 ErrNotFound。
|
||||||
|
Get(ctx context.Context, id string) (*Model, error)
|
||||||
|
|
||||||
|
// List 依 filter 列出 Model;filter.OwnerUserID 不同於空字串時限定擁有者。
|
||||||
|
List(ctx context.Context, filter ListFilter) ([]*Model, error)
|
||||||
|
|
||||||
|
// Save 新增或更新 Model(upsert by ID)。
|
||||||
|
Save(ctx context.Context, m *Model) error
|
||||||
|
|
||||||
|
// Delete 軟刪除。
|
||||||
|
Delete(ctx context.Context, id string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// SizeValidator — 依 Config.Model.MaxSizeMB 驗證檔案大小
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// SizeValidator 提供 Model 上傳大小上限檢查。
|
||||||
|
//
|
||||||
|
// 由 api handler / service 層呼叫;Repository 不耦合此邏輯。
|
||||||
|
type SizeValidator struct {
|
||||||
|
MaxSizeMB int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSizeValidator 建立檔案大小驗證器;maxSizeMB <= 0 時視為無限制(不建議生產用)。
|
||||||
|
func NewSizeValidator(maxSizeMB int) *SizeValidator {
|
||||||
|
return &SizeValidator{MaxSizeMB: maxSizeMB}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check 檢查 size(bytes)是否超過上限,超過回 ErrFileTooLarge。
|
||||||
|
func (v *SizeValidator) Check(size int64) error {
|
||||||
|
if v.MaxSizeMB <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
limit := int64(v.MaxSizeMB) * 1024 * 1024
|
||||||
|
if size > limit {
|
||||||
|
return fmt.Errorf("%w: %d bytes exceeds %d MB limit", ErrFileTooLarge, size, v.MaxSizeMB)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// InMemoryRepository
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// InMemoryRepository 是 Phase 0 的記憶體實作。
|
||||||
|
type InMemoryRepository struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
models map[string]*Model
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInMemoryRepository 建立一個空的記憶體 Repository。
|
||||||
|
func NewInMemoryRepository() *InMemoryRepository {
|
||||||
|
return &InMemoryRepository{
|
||||||
|
models: make(map[string]*Model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 取得單一 Model。
|
||||||
|
func (r *InMemoryRepository) Get(ctx context.Context, id string) (*Model, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
m, ok := r.models[id]
|
||||||
|
if !ok || m.DeletedAt != nil {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
cp := *m
|
||||||
|
return &cp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 依條件列出 Model。
|
||||||
|
func (r *InMemoryRepository) List(ctx context.Context, filter ListFilter) ([]*Model, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
out := make([]*Model, 0)
|
||||||
|
for _, m := range r.models {
|
||||||
|
if m.DeletedAt != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filter.OwnerUserID != "" && m.OwnerUserID != filter.OwnerUserID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filter.TargetChip != "" && m.TargetChip != filter.TargetChip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if filter.Source != "" && m.Source != filter.Source {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
cp := *m
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save 新增或更新 Model(upsert by ID)。
|
||||||
|
func (r *InMemoryRepository) Save(ctx context.Context, m *Model) error {
|
||||||
|
if m == nil || m.ID == "" {
|
||||||
|
return errors.New("model: Save requires non-nil model with ID")
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now().UTC()
|
||||||
|
cp := *m
|
||||||
|
if existing, ok := r.models[m.ID]; ok && existing.DeletedAt == nil {
|
||||||
|
cp.CreatedAt = existing.CreatedAt
|
||||||
|
} else if cp.CreatedAt.IsZero() {
|
||||||
|
cp.CreatedAt = now
|
||||||
|
}
|
||||||
|
cp.UpdatedAt = now
|
||||||
|
r.models[m.ID] = &cp
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 軟刪除。
|
||||||
|
func (r *InMemoryRepository) Delete(ctx context.Context, id string) error {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
m, ok := r.models[id]
|
||||||
|
if !ok || m.DeletedAt != nil {
|
||||||
|
return ErrNotFound
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
m.DeletedAt = &now
|
||||||
|
m.UpdatedAt = now
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 InMemoryRepository 實作 Repository。
|
||||||
|
var _ Repository = (*InMemoryRepository)(nil)
|
||||||
55
visionA-backend/internal/oidc/errors.go
Normal file
55
visionA-backend/internal/oidc/errors.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
// Package oidc 提供 OpenID Connect (Authorization Code + PKCE) client 的封裝,
|
||||||
|
// 對接 Innovedus Member Center 或任何 OIDC compliant Identity Provider。
|
||||||
|
//
|
||||||
|
// 設計對齊:
|
||||||
|
// - oidc-tdd.md §4.2(internal/oidc/ 模組)
|
||||||
|
// - oidc-tdd.md §6(PKCE 細節)
|
||||||
|
// - oidc-tdd.md §7(id_token 驗證)
|
||||||
|
// - adr-010-oidc-bff.md
|
||||||
|
//
|
||||||
|
// 此 package 僅提供「OIDC client wrapper」職責:
|
||||||
|
// - Discovery / JWKS(藉由 coreos/go-oidc/v3 實作,自動快取)
|
||||||
|
// - PKCE / state / nonce 隨機值產生
|
||||||
|
// - Authorization URL 組裝
|
||||||
|
// - Authorization Code → Token Exchange
|
||||||
|
// - id_token 驗證(簽章 + claim)
|
||||||
|
//
|
||||||
|
// 不負責:HTTP handler、cookie session、frontend redirect — 這些由 OB3-OB4 處理。
|
||||||
|
package oidc
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// 公開 sentinel errors,便於 caller 用 errors.Is 比對。
|
||||||
|
// 命名與 internal/auth 風格一致。
|
||||||
|
var (
|
||||||
|
// ErrDiscoveryFetch 表示 .well-known/openid-configuration 抓取或解析失敗。
|
||||||
|
ErrDiscoveryFetch = errors.New("oidc: discovery fetch failed")
|
||||||
|
|
||||||
|
// ErrJWKSFetch 表示 jwks_uri 抓取或解析失敗。
|
||||||
|
ErrJWKSFetch = errors.New("oidc: jwks fetch failed")
|
||||||
|
|
||||||
|
// ErrTokenExchange 表示 token endpoint 回傳錯誤(非 401)或網路錯誤。
|
||||||
|
ErrTokenExchange = errors.New("oidc: token exchange failed")
|
||||||
|
|
||||||
|
// ErrInvalidGrant 表示 authorization code 已被使用、過期、或 PKCE verifier 不符(HTTP 400/401 with invalid_grant)。
|
||||||
|
ErrInvalidGrant = errors.New("oidc: invalid grant")
|
||||||
|
|
||||||
|
// ErrInvalidIDToken 是 id_token 驗證失敗的 umbrella error;
|
||||||
|
// 包裹下方更精確的 sentinel,caller 可用 errors.Is 逐個檢查。
|
||||||
|
ErrInvalidIDToken = errors.New("oidc: invalid id_token")
|
||||||
|
|
||||||
|
// ErrInvalidIssuer 表示 id_token 的 iss claim 不等於 cfg.IssuerURL。
|
||||||
|
ErrInvalidIssuer = errors.New("oidc: invalid issuer")
|
||||||
|
|
||||||
|
// ErrInvalidAudience 表示 id_token 的 aud claim 不包含 cfg.ClientID。
|
||||||
|
ErrInvalidAudience = errors.New("oidc: invalid audience")
|
||||||
|
|
||||||
|
// ErrTokenExpired 表示 id_token 已過期(exp <= now,含 leeway)。
|
||||||
|
ErrTokenExpired = errors.New("oidc: id_token expired")
|
||||||
|
|
||||||
|
// ErrInvalidNonce 表示 id_token 的 nonce claim 與 caller 提供的 expectedNonce 不符。
|
||||||
|
ErrInvalidNonce = errors.New("oidc: invalid nonce")
|
||||||
|
|
||||||
|
// ErrInvalidConfig 表示 ProviderConfig 缺欄位或欄位格式錯誤。
|
||||||
|
ErrInvalidConfig = errors.New("oidc: invalid config")
|
||||||
|
)
|
||||||
111
visionA-backend/internal/oidc/oidc.go
Normal file
111
visionA-backend/internal/oidc/oidc.go
Normal file
@ -0,0 +1,111 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderConfig 是建立 Provider 所需的所有設定。
|
||||||
|
//
|
||||||
|
// 全部欄位都從環境變數帶入(見 oidc-tdd.md §13.1),不在程式碼中 hardcode。
|
||||||
|
// caller 應在啟動時驗證所有必填欄位非空。
|
||||||
|
type ProviderConfig struct {
|
||||||
|
// IssuerURL 是 OIDC Identity Provider 的 issuer,例如 https://member-center.dev.innovedus.com
|
||||||
|
// (結尾不帶斜線)。NewProvider 會以此為 base 抓 .well-known/openid-configuration。
|
||||||
|
IssuerURL string
|
||||||
|
|
||||||
|
// ClientID 是 visionA 在 Member Center 註冊的 OAuth client_id(confidential 或 public 皆可)。
|
||||||
|
ClientID string
|
||||||
|
|
||||||
|
// ClientSecret 是 confidential client 的 secret;不可外洩到 frontend / log。
|
||||||
|
//
|
||||||
|
// A1(2026-05-01):ClientSecret 為**選填**:
|
||||||
|
// - 有值 → confidential client mode(client_secret + PKCE 雙保險)
|
||||||
|
// - 留空 → public PKCE-only client mode(純依靠 PKCE 防 code interception)
|
||||||
|
// 兩種 mode 都符合 OAuth 2.1,由 IdP 註冊 client 時決定。
|
||||||
|
ClientSecret string
|
||||||
|
|
||||||
|
// RedirectURL 是 visionA-backend 的 callback URL,例如
|
||||||
|
// http://localhost:8080/api/auth/callback(dev)或
|
||||||
|
// https://app.visiona.cloud/api/auth/callback(prod)。
|
||||||
|
// 必須與在 Member Center 註冊的 redirect_uri 完全一致。
|
||||||
|
RedirectURL string
|
||||||
|
|
||||||
|
// Scopes 是 OIDC scope 清單,預設 ["openid", "email", "profile"]。
|
||||||
|
// 若為空,NewProvider 會套用預設值。
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultScopes 是 OIDC 標準 scope 集合,能取得 sub / email / name 三個 claim。
|
||||||
|
// 對齊 oidc-tdd.md §7.3 的 Claim Mapping。
|
||||||
|
var DefaultScopes = []string{"openid", "email", "profile"}
|
||||||
|
|
||||||
|
// Provider 是本 package 對外的唯一 interface,封裝 OIDC Authorization Code + PKCE 流程。
|
||||||
|
//
|
||||||
|
// 設計理由:以 interface 為公開 API,內部實作(目前以 coreos/go-oidc/v3 為基礎)可未來替換
|
||||||
|
// 而不影響 caller(OB3 / OB4 的 OIDCAuthService 與 auth handler)。
|
||||||
|
//
|
||||||
|
// 所有方法都應是 goroutine-safe:底層 coreos provider 與 oauth2.Config 皆為 immutable,
|
||||||
|
// JWKS / discovery 快取由 coreos lib 內部以 RWMutex 保護。
|
||||||
|
type Provider interface {
|
||||||
|
// AuthorizationURL 組出讓 user 跳轉到 IdP 登入畫面的 URL。
|
||||||
|
//
|
||||||
|
// 三個隨機值由 caller(通常是 auth handler)以 GenerateState/Nonce/CodeVerifier 產生並存
|
||||||
|
// pending session;CodeChallenge 是 CodeVerifier 經 SHA256+base64url 後的值。
|
||||||
|
//
|
||||||
|
// 回傳的 URL 已含 response_type=code、scope、PKCE、state、nonce 參數。
|
||||||
|
AuthorizationURL(state, nonce, codeChallenge string) string
|
||||||
|
|
||||||
|
// ExchangeCode 用 authorization code + code_verifier 向 token endpoint 換 token set。
|
||||||
|
//
|
||||||
|
// 錯誤對應:
|
||||||
|
// - 401 / invalid_grant → ErrInvalidGrant(code 用過 / 過期 / verifier 不符)
|
||||||
|
// - 其他 4xx/5xx / 網路錯誤 → ErrTokenExchange(包 inner error)
|
||||||
|
ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error)
|
||||||
|
|
||||||
|
// VerifyIDToken 驗 id_token 簽章與必驗 claim:
|
||||||
|
// - 簽章:以 JWKS 對應 kid 的 public key 驗 RS256(簽章演算法由 IdP 決定,
|
||||||
|
// coreos lib 預設信任 IdP discovery 宣告的 id_token_signing_alg_values_supported)
|
||||||
|
// - iss == cfg.IssuerURL
|
||||||
|
// - aud 包含 cfg.ClientID
|
||||||
|
// - exp > now(含預設 leeway)
|
||||||
|
// - nonce == expectedNonce
|
||||||
|
//
|
||||||
|
// 錯誤對應:簽章/iss/aud/exp 失敗回 ErrInvalidIDToken(包 inner),
|
||||||
|
// nonce 不符回 ErrInvalidNonce。
|
||||||
|
VerifyIDToken(ctx context.Context, rawIDToken, expectedNonce string) (*Claims, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenResponse 是 token endpoint 回傳的 token set。
|
||||||
|
//
|
||||||
|
// 對齊 RFC 6749 §5.1 + OpenID Connect Core §3.1.3.3。
|
||||||
|
// 不包含 IdToken 以外的 raw JWT(已分別放欄位),caller 拿到後通常會:
|
||||||
|
// 1. 把 IDToken 餵給 VerifyIDToken 拿 claims
|
||||||
|
// 2. 把 AccessToken 存進 server-side session(visionA BFF 模式不交給 frontend)
|
||||||
|
// 3. 雛形 Phase 0.6 不用 RefreshToken(見 ADR-010 §「負面影響」)
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string
|
||||||
|
IDToken string
|
||||||
|
RefreshToken string
|
||||||
|
TokenType string // 預期固定 "Bearer"
|
||||||
|
ExpiresIn int // access_token 有效秒數(IdP 指定)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claims 是 id_token 驗證通過後解出來的標準 + 自定 claim。
|
||||||
|
//
|
||||||
|
// Subject / Email / Name 對齊 oidc-tdd.md §7.3 的 Claim Mapping,
|
||||||
|
// 後續 OB3 的 OIDCAuthService 會以這三個欄位建 user session。
|
||||||
|
//
|
||||||
|
// Raw 保留底層 lib 解出的完整 claim map,未來若需要 picture / preferred_username
|
||||||
|
// 等額外欄位,可從 Raw 取出而不需要改 Claims struct。
|
||||||
|
type Claims struct {
|
||||||
|
Subject string // OIDC sub
|
||||||
|
Email string // OIDC email(scope=email)
|
||||||
|
Name string // OIDC name(scope=profile)
|
||||||
|
Issuer string // iss
|
||||||
|
Audience string // aud(取第一個 audience;OIDC 多 aud 時 Member Center 不使用)
|
||||||
|
IssuedAt time.Time // iat
|
||||||
|
ExpiresAt time.Time // exp
|
||||||
|
Nonce string // nonce
|
||||||
|
Raw map[string]any
|
||||||
|
}
|
||||||
77
visionA-backend/internal/oidc/pkce.go
Normal file
77
visionA-backend/internal/oidc/pkce.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// 隨機值長度常數(位元組數)。
|
||||||
|
//
|
||||||
|
// RFC 7636 §4.1 規定 code_verifier 是 43-128 個字元([A-Z a-z 0-9 - . _ ~])。
|
||||||
|
// 32 bytes 經 base64url(無 padding)編碼後 = ceil(32 * 4 / 3) = 43 字元,
|
||||||
|
// 剛好等於最小邊界 43,落在合規範圍內,且提供 256 bits 熵。
|
||||||
|
//
|
||||||
|
// state 與 nonce 沒有 RFC 規範長度,採同樣 32 bytes(256 bits)足以抵抗暴力猜測。
|
||||||
|
const (
|
||||||
|
// pkceVerifierBytes = 32 bytes → base64url 後 43 字元,符合 RFC 7636 範圍下界。
|
||||||
|
pkceVerifierBytes = 32
|
||||||
|
|
||||||
|
// stateNonceBytes = 32 bytes → 256 bits 隨機性。
|
||||||
|
stateNonceBytes = 32
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateCodeVerifier 產生 RFC 7636 PKCE code_verifier。
|
||||||
|
//
|
||||||
|
// 回傳值是 base64url(無 padding)編碼的 32 byte 隨機值,等於 43 個字元,
|
||||||
|
// 落在 RFC 7636 §4.1 規定的 43-128 字元範圍內(剛好踩在下界)。
|
||||||
|
//
|
||||||
|
// 字元集為 [A-Z a-z 0-9 - _](base64url 規範),符合 RFC 7636 §4.1 的
|
||||||
|
// unreserved characters 子集(少了 `.` 和 `~`,但仍合規)。
|
||||||
|
func GenerateCodeVerifier() (string, error) {
|
||||||
|
return randomBase64URL(pkceVerifierBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodeChallenge 由 code_verifier 算出 RFC 7636 §4.2 規定的 code_challenge:
|
||||||
|
//
|
||||||
|
// BASE64URL-ENCODE(SHA256(ASCII(code_verifier)))
|
||||||
|
//
|
||||||
|
// challenge_method 固定 S256(不支援 plain,因 OAuth 2.1 已標 plain 為 deprecated)。
|
||||||
|
//
|
||||||
|
// caller 應將回傳值放在 authorization request 的 code_challenge 參數,
|
||||||
|
// 並把 GenerateCodeVerifier() 的原值(verifier)安全地存在 server-side pending session,
|
||||||
|
// 後續 ExchangeCode 時帶 verifier 給 token endpoint 完成 proof。
|
||||||
|
func CodeChallenge(verifier string) string {
|
||||||
|
sum := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateState 產生 OAuth 2.0 state 值,用於 CSRF 防護。
|
||||||
|
//
|
||||||
|
// caller 應將其存在 server-side pending session,並在 callback 收到後比對。
|
||||||
|
func GenerateState() (string, error) {
|
||||||
|
return randomBase64URL(stateNonceBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateNonce 產生 OIDC nonce 值,用於 id_token replay 防護。
|
||||||
|
//
|
||||||
|
// caller 應將其存在 server-side pending session,並在 VerifyIDToken 時比對 claims.Nonce。
|
||||||
|
func GenerateNonce() (string, error) {
|
||||||
|
return randomBase64URL(stateNonceBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomBase64URL 是內部 helper:產生 n bytes 隨機值並做 base64url(無 padding)編碼。
|
||||||
|
//
|
||||||
|
// 使用 crypto/rand 為密碼學安全亂數源;任何讀取錯誤都向上傳遞給 caller 處理
|
||||||
|
// (通常代表系統 entropy 出問題,應該讓請求 fail 而非回退到不安全的預設值)。
|
||||||
|
func randomBase64URL(n int) (string, error) {
|
||||||
|
if n <= 0 {
|
||||||
|
return "", fmt.Errorf("oidc: random length must be positive, got %d", n)
|
||||||
|
}
|
||||||
|
b := make([]byte, n)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", fmt.Errorf("oidc: read random bytes: %w", err)
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
106
visionA-backend/internal/oidc/pkce_test.go
Normal file
106
visionA-backend/internal/oidc/pkce_test.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// base64url(無 padding)允許的字元集,對齊 RFC 4648 §5。
|
||||||
|
// 順帶涵蓋 RFC 7636 §4.1 對 code_verifier 的字元集要求子集。
|
||||||
|
var base64URLPattern = regexp.MustCompile(`^[A-Za-z0-9_-]+$`)
|
||||||
|
|
||||||
|
func TestGenerateCodeVerifier_LengthAndCharset(t *testing.T) {
|
||||||
|
v, err := GenerateCodeVerifier()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 32 bytes → base64url 後 43 字元(無 padding)。
|
||||||
|
// RFC 7636 規定範圍 43-128 字元;43 字元剛好符合最小邊界。
|
||||||
|
assert.Len(t, v, 43, "verifier 應為 43 字元(32 bytes base64url)")
|
||||||
|
assert.GreaterOrEqual(t, len(v), 43, "RFC 7636 最小 43 字元")
|
||||||
|
assert.LessOrEqual(t, len(v), 128, "RFC 7636 最大 128 字元")
|
||||||
|
|
||||||
|
assert.Regexp(t, base64URLPattern, v, "verifier 應只含 base64url 字元")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateCodeVerifier_Randomness(t *testing.T) {
|
||||||
|
const n = 50
|
||||||
|
seen := make(map[string]struct{}, n)
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
v, err := GenerateCodeVerifier()
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, dup := seen[v]
|
||||||
|
assert.Falsef(t, dup, "第 %d 次產生與先前重複,亂數源異常", i)
|
||||||
|
seen[v] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodeChallenge_KnownVector(t *testing.T) {
|
||||||
|
// RFC 7636 Appendix B 提供的 known answer test:
|
||||||
|
// verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
// challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||||
|
const verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||||
|
const want = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||||
|
|
||||||
|
got := CodeChallenge(verifier)
|
||||||
|
assert.Equal(t, want, got, "challenge 與 RFC 7636 Appendix B test vector 不符")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCodeChallenge_MatchesSHA256(t *testing.T) {
|
||||||
|
v, err := GenerateCodeVerifier()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
want := base64.RawURLEncoding.EncodeToString(sha256Sum([]byte(v)))
|
||||||
|
got := CodeChallenge(v)
|
||||||
|
assert.Equal(t, want, got, "challenge 應為 base64url(SHA256(verifier))")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateState_Format(t *testing.T) {
|
||||||
|
s, err := GenerateState()
|
||||||
|
require.NoError(t, err)
|
||||||
|
// 32 bytes → 43 字元 base64url(無 padding)
|
||||||
|
assert.Len(t, s, 43)
|
||||||
|
assert.Regexp(t, base64URLPattern, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateNonce_Format(t *testing.T) {
|
||||||
|
n, err := GenerateNonce()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, n, 43)
|
||||||
|
assert.Regexp(t, base64URLPattern, n)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStateAndNonce_Independent(t *testing.T) {
|
||||||
|
// state 和 nonce 雖然產生方式相同,但兩次連續呼叫不應產生相同值。
|
||||||
|
s1, err := GenerateState()
|
||||||
|
require.NoError(t, err)
|
||||||
|
s2, err := GenerateState()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, s1, s2, "兩次 GenerateState 不應重複")
|
||||||
|
|
||||||
|
n1, err := GenerateNonce()
|
||||||
|
require.NoError(t, err)
|
||||||
|
n2, err := GenerateNonce()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEqual(t, n1, n2, "兩次 GenerateNonce 不應重複")
|
||||||
|
|
||||||
|
assert.NotEqual(t, s1, n1, "state 與 nonce 應為獨立隨機值")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRandomBase64URL_RejectNonPositive(t *testing.T) {
|
||||||
|
_, err := randomBase64URL(0)
|
||||||
|
assert.Error(t, err, "n=0 應拒絕")
|
||||||
|
|
||||||
|
_, err = randomBase64URL(-1)
|
||||||
|
assert.Error(t, err, "n<0 應拒絕")
|
||||||
|
}
|
||||||
|
|
||||||
|
// sha256Sum 是 test helper,避免在測試中每次都寫 [:]。
|
||||||
|
func sha256Sum(b []byte) []byte {
|
||||||
|
s := sha256.Sum256(b)
|
||||||
|
return s[:]
|
||||||
|
}
|
||||||
270
visionA-backend/internal/oidc/provider.go
Normal file
270
visionA-backend/internal/oidc/provider.go
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
coreosoidc "github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// provider 是 Provider interface 的預設實作,底層用 coreos/go-oidc/v3。
|
||||||
|
//
|
||||||
|
// 為什麼選 coreos/go-oidc/v3:
|
||||||
|
// - 業界事實標準,廣泛採用、長期維護
|
||||||
|
// - 自動處理 discovery(/.well-known/openid-configuration)
|
||||||
|
// - 自動處理 JWKS 抓取與快取(內建 1h refresh)
|
||||||
|
// - 與 golang.org/x/oauth2 標準 OAuth2 lib 整合無縫
|
||||||
|
// - id_token 驗證涵蓋 iss / aud / exp / 簽章;nonce 與額外 claim 由我們補上
|
||||||
|
//
|
||||||
|
// 為什麼仍包一層 wrapper:
|
||||||
|
// - 公開 API 在我們手上,未來若要換 lib(例如自刻或換 lestrrat-go/jwx)caller 不受影響
|
||||||
|
// - 集中錯誤型別轉換(coreos 各種錯誤 → 我們的 sentinel errors)
|
||||||
|
// - 集中 nonce 比對(coreos 預設不驗 nonce,留給 caller 處理)
|
||||||
|
type provider struct {
|
||||||
|
cfg ProviderConfig
|
||||||
|
oauth2Cfg *oauth2.Config
|
||||||
|
idTokenVerif *coreosoidc.IDTokenVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider 以 cfg 建立一個 Provider 實例。
|
||||||
|
//
|
||||||
|
// 過程:
|
||||||
|
// 1. 驗 cfg 必填欄位
|
||||||
|
// 2. 用 coreos lib 抓 discovery(含 jwks_uri、authorization_endpoint、token_endpoint)
|
||||||
|
// — 此步驟有網路 I/O,會以 ctx 控制 timeout
|
||||||
|
// 3. 建 oauth2.Config(後續 ExchangeCode / AuthorizationURL 會用到)
|
||||||
|
// 4. 建 IDTokenVerifier(內部會持有 JWKS 快取,自動 refresh)
|
||||||
|
//
|
||||||
|
// caller 通常在程式啟動時呼叫一次,存在 long-lived 的 Deps 容器中重複使用。
|
||||||
|
// 若 IdP 不可達會回 ErrDiscoveryFetch(包 inner error)— 啟動時 fail-fast。
|
||||||
|
func NewProvider(ctx context.Context, cfg ProviderConfig) (Provider, error) {
|
||||||
|
if err := validateConfig(&cfg); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
coreosProv, err := coreosoidc.NewProvider(ctx, cfg.IssuerURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: %v", ErrDiscoveryFetch, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A1:ClientSecret 留空 → public PKCE-only client mode。
|
||||||
|
//
|
||||||
|
// oauth2 lib 的 token request 行為(golang.org/x/oauth2 v0.36 internal/token.go):
|
||||||
|
//
|
||||||
|
// - AuthStyleInParams:clientID / clientSecret 寫進 POST form。空 secret 時
|
||||||
|
// `if clientSecret != ""` 判斷成立 → **完全不送 client_secret 欄位**,
|
||||||
|
// 符合 RFC 6749 §2.3.1 對 public client 的規範。
|
||||||
|
// - AuthStyleInHeader:永遠 SetBasicAuth(clientID, clientSecret) → 即使空
|
||||||
|
// secret 也會送 `Authorization: Basic base64(clientID:)`,多數 IdP 會把這個
|
||||||
|
// 視為「confidential client 但 secret 錯」而 401。
|
||||||
|
// - AuthStyleAutoDetect(zero value):第一輪試 InHeader,4xx 後 fallback 到
|
||||||
|
// InParams。對 public client 多了一次失敗 round-trip。
|
||||||
|
//
|
||||||
|
// 所以 public client mode 強制 InParams,跳過 InHeader 探測;
|
||||||
|
// confidential client mode 維持 AutoDetect(沿用 lib 預設行為,與 OB1 一致)。
|
||||||
|
endpoint := coreosProv.Endpoint()
|
||||||
|
if cfg.ClientSecret == "" {
|
||||||
|
endpoint.AuthStyle = oauth2.AuthStyleInParams
|
||||||
|
}
|
||||||
|
|
||||||
|
oauth2Cfg := &oauth2.Config{
|
||||||
|
ClientID: cfg.ClientID,
|
||||||
|
ClientSecret: cfg.ClientSecret, // 空字串 → token endpoint 不送 client_secret
|
||||||
|
RedirectURL: cfg.RedirectURL,
|
||||||
|
Endpoint: endpoint,
|
||||||
|
Scopes: cfg.Scopes,
|
||||||
|
}
|
||||||
|
|
||||||
|
verifier := coreosProv.Verifier(&coreosoidc.Config{
|
||||||
|
ClientID: cfg.ClientID,
|
||||||
|
})
|
||||||
|
|
||||||
|
return &provider{
|
||||||
|
cfg: cfg,
|
||||||
|
oauth2Cfg: oauth2Cfg,
|
||||||
|
idTokenVerif: verifier,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateConfig 檢查 ProviderConfig 必填欄位,並套用預設 Scopes。
|
||||||
|
//
|
||||||
|
// A1(2026-05-01):ClientSecret 為**選填**,留空時走 public PKCE-only client mode。
|
||||||
|
// 必填欄位剩 IssuerURL / ClientID / RedirectURL。
|
||||||
|
//
|
||||||
|
// 注意:cfg 是 *指標*,會被就地修改(套預設 Scopes)。這是有意為之 —
|
||||||
|
// caller 通常從 env 載入 ProviderConfig 一次性傳入,套預設後立刻被 NewProvider 拷貝進
|
||||||
|
// internal struct,不會有別名問題。
|
||||||
|
func validateConfig(cfg *ProviderConfig) error {
|
||||||
|
missing := make([]string, 0, 3)
|
||||||
|
if strings.TrimSpace(cfg.IssuerURL) == "" {
|
||||||
|
missing = append(missing, "IssuerURL")
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(cfg.ClientID) == "" {
|
||||||
|
missing = append(missing, "ClientID")
|
||||||
|
}
|
||||||
|
// ClientSecret 不檢查(A1:public PKCE-only client 留空合法)。
|
||||||
|
if strings.TrimSpace(cfg.RedirectURL) == "" {
|
||||||
|
missing = append(missing, "RedirectURL")
|
||||||
|
}
|
||||||
|
if len(missing) > 0 {
|
||||||
|
return fmt.Errorf("%w: missing required fields: %s",
|
||||||
|
ErrInvalidConfig, strings.Join(missing, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssuerURL 必須是合法 URL;coreos lib 會再驗一次但訊息較不友善。
|
||||||
|
if _, err := url.Parse(cfg.IssuerURL); err != nil {
|
||||||
|
return fmt.Errorf("%w: IssuerURL invalid: %v", ErrInvalidConfig, err)
|
||||||
|
}
|
||||||
|
if _, err := url.Parse(cfg.RedirectURL); err != nil {
|
||||||
|
return fmt.Errorf("%w: RedirectURL invalid: %v", ErrInvalidConfig, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cfg.Scopes) == 0 {
|
||||||
|
// 套預設值;不深拷貝,因為 DefaultScopes 不會被修改。
|
||||||
|
cfg.Scopes = DefaultScopes
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizationURL 實作 Provider.AuthorizationURL。
|
||||||
|
//
|
||||||
|
// 用 oauth2.Config.AuthCodeURL 組 URL,加上 PKCE 與 nonce 兩個額外參數
|
||||||
|
// (oauth2 lib 原生不知道這兩個東西,需以 oauth2.SetAuthURLParam 注入)。
|
||||||
|
func (p *provider) AuthorizationURL(state, nonce, codeChallenge string) string {
|
||||||
|
return p.oauth2Cfg.AuthCodeURL(
|
||||||
|
state,
|
||||||
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||||
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||||
|
oauth2.SetAuthURLParam("nonce", nonce),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCode 實作 Provider.ExchangeCode。
|
||||||
|
//
|
||||||
|
// 把 code_verifier 注入 token request,由 IdP 驗 PKCE proof。
|
||||||
|
//
|
||||||
|
// 錯誤分類邏輯:
|
||||||
|
// - oauth2 回的 *oauth2.RetrieveError 如果 ErrorCode == "invalid_grant"
|
||||||
|
// → ErrInvalidGrant(典型情境:code 已用過、過期、verifier 不符)
|
||||||
|
// - 其他 → ErrTokenExchange + 包 inner error(如 IdP 5xx、connection refused)
|
||||||
|
func (p *provider) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) {
|
||||||
|
tok, err := p.oauth2Cfg.Exchange(
|
||||||
|
ctx, code,
|
||||||
|
oauth2.SetAuthURLParam("code_verifier", codeVerifier),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, classifyExchangeError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rawIDToken, ok := tok.Extra("id_token").(string)
|
||||||
|
if !ok || rawIDToken == "" {
|
||||||
|
// IdP 回了 200 但沒給 id_token — 違反 OIDC spec
|
||||||
|
return nil, fmt.Errorf("%w: id_token missing from token response", ErrTokenExchange)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresIn := 0
|
||||||
|
if !tok.Expiry.IsZero() {
|
||||||
|
expiresIn = int(time.Until(tok.Expiry).Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
return &TokenResponse{
|
||||||
|
AccessToken: tok.AccessToken,
|
||||||
|
IDToken: rawIDToken,
|
||||||
|
RefreshToken: tok.RefreshToken,
|
||||||
|
TokenType: tok.TokenType,
|
||||||
|
ExpiresIn: expiresIn,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifyExchangeError 把 oauth2 lib 的錯誤對應到我們的 sentinel error。
|
||||||
|
//
|
||||||
|
// oauth2.RetrieveError 在新版 lib 中是公開型別;它的 ErrorCode 對應 RFC 6749 §5.2 的
|
||||||
|
// error 欄位,invalid_grant 是 PKCE/code 失敗最常見的錯誤碼。
|
||||||
|
func classifyExchangeError(err error) error {
|
||||||
|
var retrieveErr *oauth2.RetrieveError
|
||||||
|
if errors.As(err, &retrieveErr) {
|
||||||
|
if retrieveErr.ErrorCode == "invalid_grant" {
|
||||||
|
return fmt.Errorf("%w: %v", ErrInvalidGrant, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%w: %v", ErrTokenExchange, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyIDToken 實作 Provider.VerifyIDToken。
|
||||||
|
//
|
||||||
|
// coreos verifier 自動驗:簽章、iss、aud、exp(含預設 leeway)。
|
||||||
|
// 我們在外層補:
|
||||||
|
// - nonce 比對(caller 帶 expectedNonce)
|
||||||
|
// - claim 解析成我們自己的 Claims struct
|
||||||
|
// - 錯誤型別轉換(coreos 訊息 → 我們的 sentinel)
|
||||||
|
func (p *provider) VerifyIDToken(ctx context.Context, rawIDToken, expectedNonce string) (*Claims, error) {
|
||||||
|
if rawIDToken == "" {
|
||||||
|
return nil, fmt.Errorf("%w: empty id_token", ErrInvalidIDToken)
|
||||||
|
}
|
||||||
|
if expectedNonce == "" {
|
||||||
|
// nonce 是 OIDC replay 防護的核心,caller 必須提供 — 強制 fail 而非 silently skip
|
||||||
|
return nil, fmt.Errorf("%w: expectedNonce is required", ErrInvalidNonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := p.idTokenVerif.Verify(ctx, rawIDToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, classifyVerifyError(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解出標準 + 自定 claim。coreos IDToken 已驗 iss/aud/exp/簽章,
|
||||||
|
// 我們再補 nonce 比對。
|
||||||
|
var raw map[string]any
|
||||||
|
if err := idToken.Claims(&raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("%w: parse claims: %v", ErrInvalidIDToken, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// nonce 比對(coreos 不會驗,因為它無法知道 expected 值)
|
||||||
|
tokenNonce, _ := raw["nonce"].(string)
|
||||||
|
if tokenNonce != expectedNonce {
|
||||||
|
return nil, ErrInvalidNonce
|
||||||
|
}
|
||||||
|
|
||||||
|
email, _ := raw["email"].(string)
|
||||||
|
name, _ := raw["name"].(string)
|
||||||
|
|
||||||
|
// audience:coreos 已驗 aud 包含 ClientID,我們選 ClientID 作為「使用中的 audience」
|
||||||
|
// 而非從 raw 取第一個 — 後者在 multi-aud 場景會誤導。
|
||||||
|
audience := p.cfg.ClientID
|
||||||
|
|
||||||
|
return &Claims{
|
||||||
|
Subject: idToken.Subject,
|
||||||
|
Email: email,
|
||||||
|
Name: name,
|
||||||
|
Issuer: idToken.Issuer,
|
||||||
|
Audience: audience,
|
||||||
|
IssuedAt: idToken.IssuedAt,
|
||||||
|
ExpiresAt: idToken.Expiry,
|
||||||
|
Nonce: tokenNonce,
|
||||||
|
Raw: raw,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// classifyVerifyError 把 coreos verifier 的錯誤轉成我們的 sentinel error。
|
||||||
|
//
|
||||||
|
// coreos lib 沒有 typed error(除了少數例外),所以以字串 contains 判斷。
|
||||||
|
// 這雖然脆弱(lib 升級可能改訊息),但符合事實上的慣例;
|
||||||
|
// 真要嚴謹可以改用 errors.As 看 coreos 內部 type,但訊息穩定性目前 OK。
|
||||||
|
func classifyVerifyError(err error) error {
|
||||||
|
msg := err.Error()
|
||||||
|
switch {
|
||||||
|
case strings.Contains(msg, "expired"):
|
||||||
|
return fmt.Errorf("%w: %v", ErrTokenExpired, err)
|
||||||
|
case strings.Contains(msg, "issuer"):
|
||||||
|
return fmt.Errorf("%w: %v", ErrInvalidIssuer, err)
|
||||||
|
case strings.Contains(msg, "audience"):
|
||||||
|
return fmt.Errorf("%w: %v", ErrInvalidAudience, err)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("%w: %v", ErrInvalidIDToken, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
737
visionA-backend/internal/oidc/provider_test.go
Normal file
737
visionA-backend/internal/oidc/provider_test.go
Normal file
@ -0,0 +1,737 @@
|
|||||||
|
package oidc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
"github.com/go-jose/go-jose/v4/jwt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeOIDC 是一個用 httptest 起來的最小化 OIDC 模擬器:
|
||||||
|
// - GET /.well-known/openid-configuration → discovery doc
|
||||||
|
// - GET /jwks → JWKS(含 1 把 RSA public key)
|
||||||
|
// - POST /token → 接 authorization_code,回 token set
|
||||||
|
//
|
||||||
|
// 簽 id_token 用 go-jose(coreos/go-oidc 的內部依賴,已在 go.sum 中),不需引額外 lib。
|
||||||
|
type fakeOIDC struct {
|
||||||
|
server *httptest.Server
|
||||||
|
signingKey *rsa.PrivateKey
|
||||||
|
keyID string
|
||||||
|
clientID string
|
||||||
|
|
||||||
|
// 可由各測試修改的「下一個 token 行為」控制旗標
|
||||||
|
mu chan struct{} // 簡單以 buffered chan 當 mutex(避免 import sync)
|
||||||
|
expectVerifier string // POST /token 時驗 code_verifier 是否相符;空字串=不驗
|
||||||
|
respondCode int // POST /token 回應 status code(0 = 200)
|
||||||
|
respondBody string // 非空時直接回此 body 取代正常 token response
|
||||||
|
idTokenClaims jwt.Claims // 自訂簽 token 的 standard claims(zero = 預設)
|
||||||
|
idTokenExtra map[string]any
|
||||||
|
idTokenAlg jose.SignatureAlgorithm // 預設 RS256
|
||||||
|
skipIDToken bool // true 時 token response 不含 id_token
|
||||||
|
|
||||||
|
// 觀測:最後一次 /token 收到的 form / Authorization header(A1 加:驗 public client mode)
|
||||||
|
lastTokenForm url.Values
|
||||||
|
lastTokenAuthHdr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeOIDC(t *testing.T, clientID string) *fakeOIDC {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err, "產 RSA key 失敗")
|
||||||
|
|
||||||
|
f := &fakeOIDC{
|
||||||
|
signingKey: priv,
|
||||||
|
keyID: "test-key-1",
|
||||||
|
clientID: clientID,
|
||||||
|
mu: make(chan struct{}, 1),
|
||||||
|
idTokenAlg: jose.RS256,
|
||||||
|
}
|
||||||
|
f.mu <- struct{}{} // 初始化「鎖可用」
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
// discovery doc 必須在 server 起來後才知道 issuer URL,所以用 closure 延遲組
|
||||||
|
mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
base := "http://" + r.Host
|
||||||
|
doc := map[string]any{
|
||||||
|
"issuer": base,
|
||||||
|
"authorization_endpoint": base + "/authorize",
|
||||||
|
"token_endpoint": base + "/token",
|
||||||
|
"jwks_uri": base + "/jwks",
|
||||||
|
"response_types_supported": []string{
|
||||||
|
"code",
|
||||||
|
},
|
||||||
|
"id_token_signing_alg_values_supported": []string{"RS256"},
|
||||||
|
"subject_types_supported": []string{"public"},
|
||||||
|
"scopes_supported": []string{"openid", "email", "profile"},
|
||||||
|
}
|
||||||
|
writeJSON(w, doc)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/jwks", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
jwks := map[string]any{
|
||||||
|
"keys": []map[string]any{
|
||||||
|
rsaPublicKeyToJWK(&priv.PublicKey, f.keyID),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
writeJSON(w, jwks)
|
||||||
|
})
|
||||||
|
|
||||||
|
mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
f.handleToken(t, w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
f.server = httptest.NewServer(mux)
|
||||||
|
t.Cleanup(f.server.Close)
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDC) issuer() string { return f.server.URL }
|
||||||
|
|
||||||
|
// withState 是 helper:在 chan-as-mutex 保護下安全修改控制旗標。
|
||||||
|
func (f *fakeOIDC) withState(fn func()) {
|
||||||
|
<-f.mu
|
||||||
|
defer func() { f.mu <- struct{}{} }()
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
// snapshot 是 helper:原子讀取所有控制旗標。
|
||||||
|
func (f *fakeOIDC) snapshot() (verifier string, code int, body string, claims jwt.Claims, extra map[string]any, alg jose.SignatureAlgorithm, skipID bool) {
|
||||||
|
<-f.mu
|
||||||
|
defer func() { f.mu <- struct{}{} }()
|
||||||
|
return f.expectVerifier, f.respondCode, f.respondBody, f.idTokenClaims, f.idTokenExtra, f.idTokenAlg, f.skipIDToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeOIDC) handleToken(t *testing.T, w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, "bad form", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 觀測:抄一份 form / Authorization header 給測試驗 public vs confidential mode
|
||||||
|
f.withState(func() {
|
||||||
|
f.lastTokenForm = make(url.Values, len(r.PostForm))
|
||||||
|
for k, vv := range r.PostForm {
|
||||||
|
f.lastTokenForm[k] = append([]string(nil), vv...)
|
||||||
|
}
|
||||||
|
f.lastTokenAuthHdr = r.Header.Get("Authorization")
|
||||||
|
})
|
||||||
|
|
||||||
|
expectVerifier, code, body, claims, extra, alg, skipID := f.snapshot()
|
||||||
|
|
||||||
|
// caller 可指定強制錯誤 / 自訂 body
|
||||||
|
if body != "" {
|
||||||
|
if code == 0 {
|
||||||
|
code = http.StatusBadRequest
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(code)
|
||||||
|
_, _ = w.Write([]byte(body))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectVerifier != "" && r.Form.Get("code_verifier") != expectVerifier {
|
||||||
|
// PKCE proof 失敗:回 RFC 6749 §5.2 規範的 invalid_grant
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
_, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"PKCE verifier mismatch"}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 簽 id_token
|
||||||
|
now := time.Now()
|
||||||
|
if claims.Issuer == "" {
|
||||||
|
claims = jwt.Claims{
|
||||||
|
Issuer: f.issuer(),
|
||||||
|
Subject: "sub-fake-user-001",
|
||||||
|
Audience: jwt.Audience{f.clientID},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
NotBefore: jwt.NewNumericDate(now),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if extra == nil {
|
||||||
|
extra = map[string]any{
|
||||||
|
"email": "fake-user@example.com",
|
||||||
|
"name": "Fake User",
|
||||||
|
"nonce": r.Form.Get("__test_nonce_passthrough"), // 不會由真 OIDC server 帶;測試以另路徑注入
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rawIDToken, err := signIDToken(f.signingKey, f.keyID, alg, claims, extra)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "sign error: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]any{
|
||||||
|
"access_token": "fake-access-token",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
}
|
||||||
|
if !skipID {
|
||||||
|
resp["id_token"] = rawIDToken
|
||||||
|
}
|
||||||
|
writeJSON(w, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// signIDToken 用 RSA private key 簽一個 OIDC id_token(JWS / RS256)。
|
||||||
|
func signIDToken(priv *rsa.PrivateKey, kid string, alg jose.SignatureAlgorithm, std jwt.Claims, extra map[string]any) (string, error) {
|
||||||
|
signerOpts := (&jose.SignerOptions{}).WithType("JWT")
|
||||||
|
signerOpts.WithHeader("kid", kid)
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: priv}, signerOpts)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
builder := jwt.Signed(signer).Claims(std)
|
||||||
|
if len(extra) > 0 {
|
||||||
|
builder = builder.Claims(extra)
|
||||||
|
}
|
||||||
|
return builder.Serialize()
|
||||||
|
}
|
||||||
|
|
||||||
|
// rsaPublicKeyToJWK 把 RSA public key 編成 JWKS spec 的 key 物件。
|
||||||
|
func rsaPublicKeyToJWK(pub *rsa.PublicKey, kid string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"kty": "RSA",
|
||||||
|
"alg": "RS256",
|
||||||
|
"use": "sig",
|
||||||
|
"kid": kid,
|
||||||
|
"n": base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
|
||||||
|
"e": base64.RawURLEncoding.EncodeToString(bigIntBytes(pub.E)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bigIntBytes(e int) []byte {
|
||||||
|
// RSA exponent 通常是 65537(0x010001)= 3 bytes。手動編碼。
|
||||||
|
out := []byte{}
|
||||||
|
for e > 0 {
|
||||||
|
out = append([]byte{byte(e & 0xff)}, out...)
|
||||||
|
e >>= 8
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
out = []byte{0}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, v any) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// =====================================================================
|
||||||
|
// 測試
|
||||||
|
// =====================================================================
|
||||||
|
|
||||||
|
const (
|
||||||
|
testClientID = "visiona-backend-test"
|
||||||
|
testClientSecret = "test-secret"
|
||||||
|
testRedirect = "http://localhost:8080/api/auth/callback"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newProviderForTest(t *testing.T, fake *fakeOIDC) Provider {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
p, err := NewProvider(ctx, ProviderConfig{
|
||||||
|
IssuerURL: fake.issuer(),
|
||||||
|
ClientID: testClientID,
|
||||||
|
ClientSecret: testClientSecret,
|
||||||
|
RedirectURL: testRedirect,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProvider_Discovery(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
require.NotNil(t, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProvider_RejectInvalidConfig(t *testing.T) {
|
||||||
|
// A1(2026-05-01):ClientSecret 為選填,因此 "missing secret" case 已移除。
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
cfg ProviderConfig
|
||||||
|
}{
|
||||||
|
{"missing issuer", ProviderConfig{ClientID: "x", ClientSecret: "y", RedirectURL: "http://z"}},
|
||||||
|
{"missing client id", ProviderConfig{IssuerURL: "http://x", ClientSecret: "y", RedirectURL: "http://z"}},
|
||||||
|
{"missing redirect", ProviderConfig{IssuerURL: "http://x", ClientID: "y", ClientSecret: "z"}},
|
||||||
|
}
|
||||||
|
for _, c := range cases {
|
||||||
|
t.Run(c.name, func(t *testing.T) {
|
||||||
|
_, err := NewProvider(context.Background(), c.cfg)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidConfig)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProvider_DiscoveryFailure(t *testing.T) {
|
||||||
|
// 用一個立刻關掉的 server 模擬 IdP 不可達
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
_, err := NewProvider(context.Background(), ProviderConfig{
|
||||||
|
IssuerURL: srv.URL,
|
||||||
|
ClientID: "c",
|
||||||
|
ClientSecret: "s",
|
||||||
|
RedirectURL: "http://r",
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrDiscoveryFetch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAuthorizationURL_Format(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
state, _ := GenerateState()
|
||||||
|
nonce, _ := GenerateNonce()
|
||||||
|
verifier, _ := GenerateCodeVerifier()
|
||||||
|
challenge := CodeChallenge(verifier)
|
||||||
|
|
||||||
|
raw := p.AuthorizationURL(state, nonce, challenge)
|
||||||
|
u, err := url.Parse(raw)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
q := u.Query()
|
||||||
|
assert.Equal(t, "code", q.Get("response_type"))
|
||||||
|
assert.Equal(t, testClientID, q.Get("client_id"))
|
||||||
|
assert.Equal(t, testRedirect, q.Get("redirect_uri"))
|
||||||
|
assert.Equal(t, state, q.Get("state"))
|
||||||
|
assert.Equal(t, nonce, q.Get("nonce"))
|
||||||
|
assert.Equal(t, challenge, q.Get("code_challenge"))
|
||||||
|
assert.Equal(t, "S256", q.Get("code_challenge_method"))
|
||||||
|
|
||||||
|
// scope 應含 openid email profile(順序由 oauth2 lib 決定,用 contains 驗)
|
||||||
|
scope := q.Get("scope")
|
||||||
|
for _, s := range []string{"openid", "email", "profile"} {
|
||||||
|
assert.Truef(t, strings.Contains(scope, s), "scope 應含 %q,得 %q", s, scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
// authorization_endpoint 應指向 fake server 的 /authorize
|
||||||
|
assert.Equal(t, fake.issuer()+"/authorize", u.Scheme+"://"+u.Host+u.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeCode_Success(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
verifier, _ := GenerateCodeVerifier()
|
||||||
|
fake.withState(func() { fake.expectVerifier = verifier })
|
||||||
|
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "fake-code", verifier)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "fake-access-token", tok.AccessToken)
|
||||||
|
assert.NotEmpty(t, tok.IDToken, "id_token 應有值")
|
||||||
|
assert.Equal(t, "Bearer", tok.TokenType)
|
||||||
|
assert.Greater(t, tok.ExpiresIn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeCode_PKCEMismatch(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
fake.withState(func() { fake.expectVerifier = "the-real-verifier" })
|
||||||
|
|
||||||
|
_, err := p.ExchangeCode(context.Background(), "fake-code", "wrong-verifier")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidGrant, "PKCE 不符應對應到 ErrInvalidGrant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeCode_ServerError(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
fake.withState(func() {
|
||||||
|
fake.respondCode = http.StatusInternalServerError
|
||||||
|
fake.respondBody = `{"error":"server_error"}`
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := p.ExchangeCode(context.Background(), "fake-code", "any")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenExchange)
|
||||||
|
assert.NotErrorIs(t, err, ErrInvalidGrant, "5xx 不應被分類為 invalid_grant")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExchangeCode_MissingIDToken(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
fake.withState(func() { fake.skipIDToken = true })
|
||||||
|
|
||||||
|
_, err := p.ExchangeCode(context.Background(), "fake-code", "any")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrTokenExchange, "缺 id_token 應視為 token_exchange 失敗")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDToken_HappyPath(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
const expectedNonce = "nonce-happy-path"
|
||||||
|
now := time.Now()
|
||||||
|
fake.withState(func() {
|
||||||
|
fake.idTokenClaims = jwt.Claims{
|
||||||
|
Issuer: fake.issuer(),
|
||||||
|
Subject: "user-123",
|
||||||
|
Audience: jwt.Audience{testClientID},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}
|
||||||
|
fake.idTokenExtra = map[string]any{
|
||||||
|
"email": "alice@example.com",
|
||||||
|
"name": "Alice",
|
||||||
|
"nonce": expectedNonce,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "code", "verifier")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
claims, err := p.VerifyIDToken(context.Background(), tok.IDToken, expectedNonce)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "user-123", claims.Subject)
|
||||||
|
assert.Equal(t, "alice@example.com", claims.Email)
|
||||||
|
assert.Equal(t, "Alice", claims.Name)
|
||||||
|
assert.Equal(t, fake.issuer(), claims.Issuer)
|
||||||
|
assert.Equal(t, testClientID, claims.Audience)
|
||||||
|
assert.Equal(t, expectedNonce, claims.Nonce)
|
||||||
|
assert.False(t, claims.ExpiresAt.IsZero())
|
||||||
|
assert.NotEmpty(t, claims.Raw)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDToken_WrongNonce(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fake.withState(func() {
|
||||||
|
fake.idTokenClaims = jwt.Claims{
|
||||||
|
Issuer: fake.issuer(),
|
||||||
|
Subject: "user-x",
|
||||||
|
Audience: jwt.Audience{testClientID},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}
|
||||||
|
fake.idTokenExtra = map[string]any{"nonce": "actual-nonce"}
|
||||||
|
})
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "code", "verifier")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = p.VerifyIDToken(context.Background(), tok.IDToken, "expected-different-nonce")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidNonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDToken_WrongAudience(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fake.withState(func() {
|
||||||
|
fake.idTokenClaims = jwt.Claims{
|
||||||
|
Issuer: fake.issuer(),
|
||||||
|
Subject: "user-x",
|
||||||
|
Audience: jwt.Audience{"some-other-client"}, // 故意錯
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}
|
||||||
|
fake.idTokenExtra = map[string]any{"nonce": "n1"}
|
||||||
|
})
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "code", "verifier")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = p.VerifyIDToken(context.Background(), tok.IDToken, "n1")
|
||||||
|
require.Error(t, err)
|
||||||
|
// 涵蓋 audience 失敗 → 應映射到 ErrInvalidAudience 或至少 ErrInvalidIDToken
|
||||||
|
assert.True(t, errors.Is(err, ErrInvalidAudience) || errors.Is(err, ErrInvalidIDToken),
|
||||||
|
"audience 錯誤應對應到 ErrInvalidAudience(或 fallback ErrInvalidIDToken),得 %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDToken_Expired(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
past := time.Now().Add(-1 * time.Hour)
|
||||||
|
fake.withState(func() {
|
||||||
|
fake.idTokenClaims = jwt.Claims{
|
||||||
|
Issuer: fake.issuer(),
|
||||||
|
Subject: "user-x",
|
||||||
|
Audience: jwt.Audience{testClientID},
|
||||||
|
IssuedAt: jwt.NewNumericDate(past.Add(-5 * time.Minute)),
|
||||||
|
Expiry: jwt.NewNumericDate(past), // 已過期
|
||||||
|
}
|
||||||
|
fake.idTokenExtra = map[string]any{"nonce": "n1"}
|
||||||
|
})
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "code", "verifier")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = p.VerifyIDToken(context.Background(), tok.IDToken, "n1")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, ErrTokenExpired) || errors.Is(err, ErrInvalidIDToken),
|
||||||
|
"過期應對應到 ErrTokenExpired(或 fallback),得 %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDToken_BadSignature(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
fake.withState(func() {
|
||||||
|
fake.idTokenClaims = jwt.Claims{
|
||||||
|
Issuer: fake.issuer(),
|
||||||
|
Subject: "u",
|
||||||
|
Audience: jwt.Audience{testClientID},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}
|
||||||
|
fake.idTokenExtra = map[string]any{"nonce": "n1"}
|
||||||
|
})
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "code", "verifier")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 把簽章部分整段替換成「保證無效」的 base64url 字串。
|
||||||
|
//
|
||||||
|
// 為什麼不用「翻最後一個字元」:base64url 末字元只承載部分原始 bit(取決於整段
|
||||||
|
// 長度對 3 取餘的結果),翻轉某些字元時 padding bit 不變,仍然會解碼出相同的
|
||||||
|
// 原始 bytes → 簽章值不變 → test 偶發 fail。改翻「中間字元」雖大幅降低風險
|
||||||
|
// 仍非 0;最穩定的作法是直接替換成完全不同的合法 base64url 字串,確保解碼出
|
||||||
|
// 來的 bytes 與原簽章不同。
|
||||||
|
parts := strings.Split(tok.IDToken, ".")
|
||||||
|
require.Len(t, parts, 3)
|
||||||
|
// 用相同長度的 'A' 串覆寫,base64url('A' * n) 解碼結果與原簽章 bytes 不同的機率近乎 100%
|
||||||
|
// (唯有原簽章本身就剛好全 0 才會碰撞,RSA 簽章機率為 0)。
|
||||||
|
parts[2] = strings.Repeat("A", len(parts[2]))
|
||||||
|
tampered := strings.Join(parts, ".")
|
||||||
|
|
||||||
|
_, err = p.VerifyIDToken(context.Background(), tampered, "n1")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidIDToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyIDToken_EmptyInputs(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake)
|
||||||
|
|
||||||
|
_, err := p.VerifyIDToken(context.Background(), "", "n1")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidIDToken)
|
||||||
|
|
||||||
|
// 空 nonce 也應拒絕(caller 必須提供)
|
||||||
|
now := time.Now()
|
||||||
|
fake.withState(func() {
|
||||||
|
fake.idTokenClaims = jwt.Claims{
|
||||||
|
Issuer: fake.issuer(),
|
||||||
|
Subject: "u",
|
||||||
|
Audience: jwt.Audience{testClientID},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||||
|
}
|
||||||
|
fake.idTokenExtra = map[string]any{"nonce": "actual"}
|
||||||
|
})
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "code", "verifier")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = p.VerifyIDToken(context.Background(), tok.IDToken, "")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidNonce)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 確保 NewProvider 套用預設 Scopes(caller 沒填時)
|
||||||
|
func TestProviderConfig_DefaultScopes(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
ctx := context.Background()
|
||||||
|
cfg := ProviderConfig{
|
||||||
|
IssuerURL: fake.issuer(),
|
||||||
|
ClientID: testClientID,
|
||||||
|
ClientSecret: testClientSecret,
|
||||||
|
RedirectURL: testRedirect,
|
||||||
|
// 不填 Scopes
|
||||||
|
}
|
||||||
|
p, err := NewProvider(ctx, cfg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
state, _ := GenerateState()
|
||||||
|
nonce, _ := GenerateNonce()
|
||||||
|
verifier, _ := GenerateCodeVerifier()
|
||||||
|
raw := p.AuthorizationURL(state, nonce, CodeChallenge(verifier))
|
||||||
|
u, _ := url.Parse(raw)
|
||||||
|
scope := u.Query().Get("scope")
|
||||||
|
for _, s := range DefaultScopes {
|
||||||
|
assert.Truef(t, strings.Contains(scope, s), "預設 scope 應含 %q,得 %q", s, scope)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =====================================================================
|
||||||
|
// A1:Public PKCE-only client mode 測試
|
||||||
|
// =====================================================================
|
||||||
|
|
||||||
|
// TestNewProvider_PublicClient 驗 ClientSecret 留空時能成功初始化 Provider,
|
||||||
|
// 且行為與 confidential client 等價(除了 token request 的 auth method 不同)。
|
||||||
|
func TestNewProvider_PublicClient(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
p, err := NewProvider(ctx, ProviderConfig{
|
||||||
|
IssuerURL: fake.issuer(),
|
||||||
|
ClientID: testClientID,
|
||||||
|
// ClientSecret 故意留空 — A1:public PKCE-only client mode
|
||||||
|
RedirectURL: testRedirect,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, p)
|
||||||
|
|
||||||
|
// AuthorizationURL 應含 PKCE 參數但**不含** client_secret(OAuth 規格上 authorize 端點本來就不該帶)
|
||||||
|
state, _ := GenerateState()
|
||||||
|
nonce, _ := GenerateNonce()
|
||||||
|
verifier, _ := GenerateCodeVerifier()
|
||||||
|
authURL := p.AuthorizationURL(state, nonce, CodeChallenge(verifier))
|
||||||
|
u, err := url.Parse(authURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, u.Query().Get("client_secret"), "authorize URL 不應出現 client_secret")
|
||||||
|
assert.Equal(t, "S256", u.Query().Get("code_challenge_method"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestNewProvider_ConfidentialClient 是 baseline:確認既有 confidential mode 仍能初始化。
|
||||||
|
// 既有 TestNewProvider_Discovery 其實已涵蓋此情境,這個測試明示「兩種 mode 共存」。
|
||||||
|
func TestNewProvider_ConfidentialClient(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
p, err := NewProvider(ctx, ProviderConfig{
|
||||||
|
IssuerURL: fake.issuer(),
|
||||||
|
ClientID: testClientID,
|
||||||
|
ClientSecret: testClientSecret, // 有值 → confidential
|
||||||
|
RedirectURL: testRedirect,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExchangeCode_PublicClientNoSecretSent 驗 public client mode 下,
|
||||||
|
// /token request 的 form 不含 client_secret 欄位、且 Authorization header 不是 Basic auth。
|
||||||
|
//
|
||||||
|
// 這是 A1 改造的核心驗證:oauth2 lib 在 ClientSecret="" + AuthStyleInParams 時,
|
||||||
|
// 完全不送 client_secret(符合 RFC 6749 §2.3.1 對 public client 的規範)。
|
||||||
|
func TestExchangeCode_PublicClientNoSecretSent(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
p, err := NewProvider(ctx, ProviderConfig{
|
||||||
|
IssuerURL: fake.issuer(),
|
||||||
|
ClientID: testClientID,
|
||||||
|
// ClientSecret 留空 → public PKCE-only client
|
||||||
|
RedirectURL: testRedirect,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifier, _ := GenerateCodeVerifier()
|
||||||
|
fake.withState(func() { fake.expectVerifier = verifier })
|
||||||
|
|
||||||
|
tok, err := p.ExchangeCode(context.Background(), "fake-code", verifier)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, tok.IDToken)
|
||||||
|
|
||||||
|
// 取出 fake server 觀察到的 token request
|
||||||
|
var (
|
||||||
|
form url.Values
|
||||||
|
auth string
|
||||||
|
ok bool
|
||||||
|
)
|
||||||
|
fake.withState(func() {
|
||||||
|
form = fake.lastTokenForm
|
||||||
|
auth = fake.lastTokenAuthHdr
|
||||||
|
ok = form != nil
|
||||||
|
})
|
||||||
|
require.True(t, ok, "fake server 應該已收到一次 token request")
|
||||||
|
|
||||||
|
// public client 應該:
|
||||||
|
// - form 帶 client_id 但**不帶** client_secret
|
||||||
|
// - 不送 Authorization: Basic header(Authorization 應為空字串)
|
||||||
|
// - 帶 code_verifier(PKCE proof)
|
||||||
|
assert.Equal(t, testClientID, form.Get("client_id"), "public client 仍需在 form 帶 client_id")
|
||||||
|
assert.Empty(t, form.Get("client_secret"), "public client 不應送 client_secret form 欄位")
|
||||||
|
assert.Empty(t, auth, "public client 不應送 Authorization header(更不應是 Basic)")
|
||||||
|
assert.Equal(t, verifier, form.Get("code_verifier"), "PKCE verifier 必須帶")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestExchangeCode_ConfidentialClientSendsSecret 對照組:
|
||||||
|
// confidential client mode 下,/token request 必須帶 client_secret(或 Basic auth)。
|
||||||
|
func TestExchangeCode_ConfidentialClientSendsSecret(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
p := newProviderForTest(t, fake) // 用既有 helper:帶 testClientSecret
|
||||||
|
|
||||||
|
verifier, _ := GenerateCodeVerifier()
|
||||||
|
fake.withState(func() { fake.expectVerifier = verifier })
|
||||||
|
|
||||||
|
_, err := p.ExchangeCode(context.Background(), "fake-code", verifier)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var (
|
||||||
|
form url.Values
|
||||||
|
auth string
|
||||||
|
)
|
||||||
|
fake.withState(func() {
|
||||||
|
form = fake.lastTokenForm
|
||||||
|
auth = fake.lastTokenAuthHdr
|
||||||
|
})
|
||||||
|
|
||||||
|
// confidential 可以走兩種:Basic auth header,或 form 帶 client_secret。
|
||||||
|
// oauth2 lib 預設先試 InHeader → 第一輪通常是 Basic。
|
||||||
|
// 我們不挑路徑,只要「至少一邊」帶 secret 就算正確。
|
||||||
|
hasFormSecret := form.Get("client_secret") == testClientSecret
|
||||||
|
hasBasicAuth := strings.HasPrefix(auth, "Basic ")
|
||||||
|
assert.Truef(t, hasFormSecret || hasBasicAuth,
|
||||||
|
"confidential client 應透過 form 或 Basic auth 帶 secret;form=%v auth=%q",
|
||||||
|
form, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanity check:fakeOIDC 自身沒寫錯(簽出來的 token coreos 能驗)
|
||||||
|
func TestFakeOIDC_SelfSignedTokenIsValid(t *testing.T) {
|
||||||
|
fake := newFakeOIDC(t, testClientID)
|
||||||
|
now := time.Now()
|
||||||
|
std := jwt.Claims{
|
||||||
|
Issuer: fake.issuer(),
|
||||||
|
Subject: "self-test",
|
||||||
|
Audience: jwt.Audience{testClientID},
|
||||||
|
IssuedAt: jwt.NewNumericDate(now),
|
||||||
|
Expiry: jwt.NewNumericDate(now.Add(time.Minute)),
|
||||||
|
}
|
||||||
|
tok, err := signIDToken(fake.signingKey, fake.keyID, jose.RS256, std,
|
||||||
|
map[string]any{"nonce": "test-nonce"})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, strings.Count(tok, ".") == 2, "JWT 應有 3 段")
|
||||||
|
// 確保非空 payload
|
||||||
|
parts := strings.SplitN(tok, ".", 3)
|
||||||
|
require.Len(t, parts, 3)
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(payload), "self-test")
|
||||||
|
assert.Contains(t, string(payload), "test-nonce")
|
||||||
|
// 也驗證 fmt 可印(避免 unused import 'fmt' 在小幅 refactor 後消失)
|
||||||
|
_ = fmt.Sprintf("%s", tok)
|
||||||
|
}
|
||||||
87
visionA-backend/internal/oidctest/flow.go
Normal file
87
visionA-backend/internal/oidctest/flow.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
// Package oidctest — flow.go
|
||||||
|
//
|
||||||
|
// 提供「站在 caller (BFF backend) 角度」模擬完整 OIDC redirect flow 的 helper。
|
||||||
|
// 主要用於 e2e 整合測試:把「使用者打開瀏覽器、輸入帳密、按下同意」這幾個人工步驟
|
||||||
|
// 黑箱化成一個函式呼叫。
|
||||||
|
|
||||||
|
package oidctest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SimulateAuthorizationFlow 模擬「使用者打開 /authorize → 同意登入 → IdP 302 回 redirect_uri」。
|
||||||
|
//
|
||||||
|
// 給定一個由 visionA-backend 產出的 authorize URL(內含 client_id / redirect_uri /
|
||||||
|
// state / code_challenge / nonce),本函式:
|
||||||
|
//
|
||||||
|
// 1. 對 fake server 的 /authorize 發 GET,禁止 redirect 自動跟隨
|
||||||
|
// 2. fake server 會 302 回 redirect_uri?code=<code>&state=<state>
|
||||||
|
// 3. 把 Location header 取出回傳 — 這就是 caller 接著要打的 callback URL
|
||||||
|
//
|
||||||
|
// caller 通常拿到 callback URL 之後,會「以 BFF backend client 的角色」打
|
||||||
|
// /api/auth/callback?code=...&state=...,讓 BFF 完成 token exchange + 建 session。
|
||||||
|
//
|
||||||
|
// 用 *testing.T 直接 Fatalf 而非回 error,是因為 e2e test 寫法統一:
|
||||||
|
// 任何模擬步驟出錯都應該讓 test 立即停。caller 不必到處檢查 err。
|
||||||
|
func (s *Server) SimulateAuthorizationFlow(t *testing.T, authorizeURL string) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if authorizeURL == "" {
|
||||||
|
t.Fatalf("oidctest: SimulateAuthorizationFlow: authorizeURL is empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用一個拒絕 redirect 的 client,這樣我們才能取到 Location header。
|
||||||
|
client := &http.Client{
|
||||||
|
CheckRedirect: func(req *http.Request, via []*http.Request) error {
|
||||||
|
return http.ErrUseLastResponse
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequest(http.MethodGet, authorizeURL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("oidctest: build authorize request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("oidctest: GET /authorize failed: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusFound && resp.StatusCode != http.StatusSeeOther {
|
||||||
|
t.Fatalf("oidctest: /authorize 預期 302/303,得 %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
loc := resp.Header.Get("Location")
|
||||||
|
if loc == "" {
|
||||||
|
t.Fatalf("oidctest: /authorize 回應缺 Location header")
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanity check:確認 Location 是合法 URL 且帶 code 參數
|
||||||
|
u, err := url.Parse(loc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("oidctest: callback URL 不是合法 URL: %v", err)
|
||||||
|
}
|
||||||
|
if u.Query().Get("code") == "" {
|
||||||
|
t.Fatalf("oidctest: callback URL 缺 code 參數: %s", loc)
|
||||||
|
}
|
||||||
|
|
||||||
|
return loc
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeRedirectError 是 ForceAuthorizeFailure 模擬「IdP 直接拒絕授權」場景時的回傳 error。
|
||||||
|
// 不直接讓 fake server 回 4xx 是因為真 IdP 通常會 302 帶 ?error=... 回 redirect_uri,
|
||||||
|
// 讓 caller (BFF) 自行處理。
|
||||||
|
type AuthorizeRedirectError struct {
|
||||||
|
Error string
|
||||||
|
ErrorDescription string
|
||||||
|
}
|
||||||
|
|
||||||
|
// String 讓 caller 容易在 test failure 訊息中顯示。
|
||||||
|
func (e AuthorizeRedirectError) String() string {
|
||||||
|
return fmt.Sprintf("authorize_error code=%q desc=%q", e.Error, e.ErrorDescription)
|
||||||
|
}
|
||||||
600
visionA-backend/internal/oidctest/server.go
Normal file
600
visionA-backend/internal/oidctest/server.go
Normal file
@ -0,0 +1,600 @@
|
|||||||
|
// Package oidctest 提供測試專用的 fake OpenID Connect Identity Provider,
|
||||||
|
// 模擬 Innovedus Member Center 的對外行為(OIDC discovery / JWKS / token exchange /
|
||||||
|
// authorize redirect),讓 visionA-backend 的 OIDC client 與 BFF 流程可以在純
|
||||||
|
// in-process 環境完成 end-to-end 整合測試 — 不需要 docker、不需要 docker-compose、
|
||||||
|
// 不需要真的 Member Center。
|
||||||
|
//
|
||||||
|
// # 設計理由
|
||||||
|
//
|
||||||
|
// OB1(internal/oidc)已經為自己的 unit test 寫過一份 fake OIDC server
|
||||||
|
// (internal/oidc/provider_test.go 裡的 fakeOIDC)。OT1 把它「再寫一次成可重用的
|
||||||
|
// 公開 package」而不是直接 export 那個 fake,理由有二:
|
||||||
|
//
|
||||||
|
// 1. 邊界乾淨:OB1 的 fakeOIDC 是 unit test 用的(檔名 *_test.go 不會出現在
|
||||||
|
// production binary 也不會出現在其他 package 的 import path),刻意不拿來當
|
||||||
|
// 公開 fixture 是為了讓 unit test 自成一格、不被外部依賴牽動。
|
||||||
|
//
|
||||||
|
// 2. API 形狀不同:unit test 的 fake 暴露很多 hook(snapshot/withState/skipIDToken
|
||||||
|
// …)給「驗測 OIDC client 的錯誤分類」這種白箱測試用;e2e 整合測試需要的是
|
||||||
|
// 「黑箱模擬完整 BFF flow」— issuer URL、ExchangeCode / authorize-redirect 的
|
||||||
|
// 整體行為。兩邊的 API 形狀一旦混用反而綁手綁腳。
|
||||||
|
//
|
||||||
|
// 實作仍刻意對齊 OB1 的 fakeOIDC(同樣 RS256 / 同樣 endpoint paths /
|
||||||
|
// 同樣 PKCE 與 nonce 處理),如果未來雙方有差異要對齊,更新本 package 即可。
|
||||||
|
//
|
||||||
|
// # 對齊文件
|
||||||
|
//
|
||||||
|
// - oidc-tdd.md §3 BFF Flow 詳細時序圖
|
||||||
|
// - oidc-tdd.md §6 PKCE 實作細節
|
||||||
|
// - oidc-tdd.md §7 id_token 驗證
|
||||||
|
// - adr-010-oidc-bff.md
|
||||||
|
//
|
||||||
|
// # 使用範例
|
||||||
|
//
|
||||||
|
// srv := oidctest.NewServer(t,
|
||||||
|
// oidctest.WithClientCredentials("visiona-backend-test", "test-secret"),
|
||||||
|
// )
|
||||||
|
// defer srv.Close()
|
||||||
|
//
|
||||||
|
// // 把 srv.URL 當作 IssuerURL 給 visionA-backend 的 OIDC provider。
|
||||||
|
// provider, _ := oidc.NewProvider(ctx, oidc.ProviderConfig{
|
||||||
|
// IssuerURL: srv.URL,
|
||||||
|
// ClientID: srv.ClientID,
|
||||||
|
// ClientSecret: srv.ClientSecret,
|
||||||
|
// RedirectURL: "http://localhost:8080/api/auth/callback",
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// // 預先告知 fake server:下一個 ExchangeCode 之後要簽發的 id_token claims
|
||||||
|
// srv.SetNextIDTokenClaims(map[string]any{
|
||||||
|
// "sub": "user-from-mc",
|
||||||
|
// "email": "alice@innovedus.com",
|
||||||
|
// "name": "Alice",
|
||||||
|
// })
|
||||||
|
package oidctest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
"github.com/go-jose/go-jose/v4/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Server 是用 httptest.Server 包起來的 fake OIDC IdP。
|
||||||
|
//
|
||||||
|
// 提供以下端點:
|
||||||
|
// - GET /.well-known/openid-configuration → discovery doc
|
||||||
|
// - GET /jwks → JWKS(含 1 把 RSA public key)
|
||||||
|
// - POST /oauth/token → form-encoded code exchange,回 id_token + access_token
|
||||||
|
// - GET /authorize → 自動「同意登入」並 redirect 回 client redirect_uri 帶 code
|
||||||
|
//
|
||||||
|
// Server 是 goroutine-safe(內部 RWMutex 保護所有 mutable state),
|
||||||
|
// 但 caller 仍應在「設定下一輪 token 行為 → 觸發 ExchangeCode」之間以同步方式進行,
|
||||||
|
// 否則 race 行為可預期但難以推理。
|
||||||
|
type Server struct {
|
||||||
|
// URL 是 httptest.Server 的 URL,同時也是 OIDC discovery 的 issuer。
|
||||||
|
// caller 直接拿這個當 ProviderConfig.IssuerURL。
|
||||||
|
URL string
|
||||||
|
|
||||||
|
// Issuer 等同 URL;保留欄位是為了測試「issuer mismatch」場景時能暫時 override。
|
||||||
|
Issuer string
|
||||||
|
|
||||||
|
// ClientID 與 ClientSecret 是 fake server 認可的 confidential client 憑證。
|
||||||
|
// /oauth/token 會驗 client_secret;不符直接 401。
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
|
||||||
|
// PrivateKey 是用來簽 id_token 的 RSA key;JWKS endpoint 公開對應的 public key。
|
||||||
|
// caller 一般不需要直接用,但 IssueIDToken 暴露給「自定 token claims」場景。
|
||||||
|
PrivateKey *rsa.PrivateKey
|
||||||
|
KeyID string
|
||||||
|
|
||||||
|
httpServer *httptest.Server
|
||||||
|
|
||||||
|
// ─── mutable state(test 之間用 SetXXX 改寫;rwmu 保護) ───
|
||||||
|
rwmu sync.RWMutex
|
||||||
|
|
||||||
|
// nextIDTokenClaims:若非 nil,下一個 /oauth/token response 的 id_token 用這份 claims 簽發。
|
||||||
|
// caller 通常在每個 test 開頭 Set 一次;handleToken 用完不會自動清空,方便同個
|
||||||
|
// fake server 在多次 ExchangeCode 中重複簽發同一個使用者的 token。
|
||||||
|
nextIDTokenClaims map[string]any
|
||||||
|
|
||||||
|
// nextAccessToken:若非空字串,下一個 /oauth/token response 的 access_token 用這個值。
|
||||||
|
// 預設為 "fake-access-token"。
|
||||||
|
nextAccessToken string
|
||||||
|
|
||||||
|
// 觀測欄位:記錄最後一次 /authorize 與 /oauth/token 收到的關鍵參數,
|
||||||
|
// e2e test 可用來驗 BFF 是否把 PKCE / state / nonce 正確帶過來。
|
||||||
|
lastAuthorizeQuery url.Values
|
||||||
|
lastTokenForm url.Values
|
||||||
|
|
||||||
|
// codeStore 是 authorization code → 該 code 對應的 PKCE challenge / nonce 的暫存。
|
||||||
|
// 模擬真 IdP 會把 code 與當時的 PKCE challenge 綁定,token endpoint 才能驗 PKCE proof。
|
||||||
|
codeStore map[string]issuedCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// issuedCode 是 SimulateAuthorizationFlow / IssueAuthCode 簽發 code 時記下的元資料。
|
||||||
|
// 之後 /oauth/token 用 code 反查出當時的 challenge 比對 PKCE。
|
||||||
|
type issuedCode struct {
|
||||||
|
CodeChallenge string
|
||||||
|
CodeChallengeMethod string
|
||||||
|
Nonce string
|
||||||
|
RedirectURI string
|
||||||
|
ClientID string
|
||||||
|
IssuedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Option 是 NewServer 的功能選項。
|
||||||
|
type Option func(*Server)
|
||||||
|
|
||||||
|
// WithClientCredentials 設定 fake server 認可的 OAuth client_id / client_secret。
|
||||||
|
// 不呼叫此 option 則使用預設值(visiona-backend-test / test-secret)。
|
||||||
|
func WithClientCredentials(clientID, clientSecret string) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.ClientID = clientID
|
||||||
|
s.ClientSecret = clientSecret
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithIssuer 強制 override discovery doc 的 issuer claim。
|
||||||
|
// 一般場景請勿使用;此選項只給「測 issuer mismatch」這種對抗性測試用。
|
||||||
|
func WithIssuer(issuer string) Option {
|
||||||
|
return func(s *Server) {
|
||||||
|
s.Issuer = issuer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer 啟動一個 fake OIDC server。
|
||||||
|
//
|
||||||
|
// 會立即在 t.Cleanup 註冊關閉動作,caller 不必自己呼叫 Close
|
||||||
|
// (但保留 Close 公開方法供「同個 test 內提早關閉以驗錯誤情境」使用)。
|
||||||
|
func NewServer(t *testing.T, opts ...Option) *Server {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("oidctest: rsa key gen failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &Server{
|
||||||
|
PrivateKey: priv,
|
||||||
|
KeyID: "oidctest-key-1",
|
||||||
|
ClientID: "visiona-backend-test",
|
||||||
|
ClientSecret: "test-secret",
|
||||||
|
nextAccessToken: "fake-access-token",
|
||||||
|
codeStore: make(map[string]issuedCode),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/.well-known/openid-configuration", s.handleDiscovery)
|
||||||
|
mux.HandleFunc("/jwks", s.handleJWKS)
|
||||||
|
mux.HandleFunc("/oauth/token", s.handleToken)
|
||||||
|
mux.HandleFunc("/authorize", s.handleAuthorize)
|
||||||
|
|
||||||
|
s.httpServer = httptest.NewServer(mux)
|
||||||
|
s.URL = s.httpServer.URL
|
||||||
|
if s.Issuer == "" {
|
||||||
|
s.Issuer = s.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Cleanup(s.Close)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 停掉內部 httptest.Server。重複呼叫安全。
|
||||||
|
func (s *Server) Close() {
|
||||||
|
if s.httpServer != nil {
|
||||||
|
s.httpServer.Close()
|
||||||
|
s.httpServer = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetNextIDTokenClaims 設定下一次 /oauth/token 回應裡 id_token 的 claims。
|
||||||
|
//
|
||||||
|
// 注意:傳入的 map 不會被 merge,而是「整份覆蓋」預設值(除了 iss/aud/exp 由 server 補足)。
|
||||||
|
// caller 可以放 sub / email / name / nonce 等自定 claim;若漏傳 nonce,handleToken
|
||||||
|
// 會用 lastAuthorizeQuery 收到的 nonce 補上(模擬真 IdP 行為:authorize 收到的 nonce 會回灌到 id_token)。
|
||||||
|
//
|
||||||
|
// 若傳 nil 等同呼叫 ResetIDTokenClaims,回到預設 sub。
|
||||||
|
func (s *Server) SetNextIDTokenClaims(claims map[string]any) {
|
||||||
|
s.rwmu.Lock()
|
||||||
|
defer s.rwmu.Unlock()
|
||||||
|
if claims == nil {
|
||||||
|
s.nextIDTokenClaims = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cp := make(map[string]any, len(claims))
|
||||||
|
for k, v := range claims {
|
||||||
|
cp[k] = v
|
||||||
|
}
|
||||||
|
s.nextIDTokenClaims = cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetIDTokenClaims 把「下一輪 id_token claims」回到預設值。
|
||||||
|
func (s *Server) ResetIDTokenClaims() { s.SetNextIDTokenClaims(nil) }
|
||||||
|
|
||||||
|
// SetNextAccessToken 改下一次 /oauth/token response 的 access_token 字串。
|
||||||
|
// 主要供「驗 backend 是否正確存了 access_token」這種測試。
|
||||||
|
func (s *Server) SetNextAccessToken(tok string) {
|
||||||
|
s.rwmu.Lock()
|
||||||
|
defer s.rwmu.Unlock()
|
||||||
|
s.nextAccessToken = tok
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastAuthorizeQuery 回傳上一次 /authorize 收到的 query string(複製,可安全修改)。
|
||||||
|
// e2e test 通常用來驗 BFF 是否正確產 PKCE / state / nonce。
|
||||||
|
func (s *Server) LastAuthorizeQuery() url.Values {
|
||||||
|
s.rwmu.RLock()
|
||||||
|
defer s.rwmu.RUnlock()
|
||||||
|
return cloneValues(s.lastAuthorizeQuery)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LastTokenForm 回傳上一次 /oauth/token 收到的 form value(複製)。
|
||||||
|
// 用來驗 ExchangeCode 是否正確帶 client_secret / code_verifier。
|
||||||
|
func (s *Server) LastTokenForm() url.Values {
|
||||||
|
s.rwmu.RLock()
|
||||||
|
defer s.rwmu.RUnlock()
|
||||||
|
return cloneValues(s.lastTokenForm)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssueIDToken 直接用 fake server 的 RSA private key 簽一個 id_token,回傳 raw JWT 字串。
|
||||||
|
//
|
||||||
|
// 用途:少數場景需要「跳過 token endpoint,直接拿 id_token 餵給 VerifyIDToken」測試
|
||||||
|
// (例如測 backend 對「不正確 issuer 的 id_token」的拒絕行為)。
|
||||||
|
func (s *Server) IssueIDToken(claims map[string]any) (string, error) {
|
||||||
|
return signJWT(s.PrivateKey, s.KeyID, jose.RS256, claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IssueAuthCode 預先簽發一個 authorization code,並把對應的 PKCE challenge / nonce
|
||||||
|
// 記在 codeStore 中。後續 /oauth/token 收到此 code + 正確 code_verifier 才會放行。
|
||||||
|
//
|
||||||
|
// 主要供「不走完整 redirect 流程、直接構造 callback」的測試用。
|
||||||
|
// 如果你只是要 e2e 跑完整 flow,呼叫 SimulateAuthorizationFlow 即可(會自動 issue code)。
|
||||||
|
func (s *Server) IssueAuthCode(challenge, challengeMethod, nonce, redirectURI string) (string, error) {
|
||||||
|
code, err := randomURLToken(24)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
s.rwmu.Lock()
|
||||||
|
defer s.rwmu.Unlock()
|
||||||
|
s.codeStore[code] = issuedCode{
|
||||||
|
CodeChallenge: challenge,
|
||||||
|
CodeChallengeMethod: challengeMethod,
|
||||||
|
Nonce: nonce,
|
||||||
|
RedirectURI: redirectURI,
|
||||||
|
ClientID: s.ClientID,
|
||||||
|
IssuedAt: time.Now(),
|
||||||
|
}
|
||||||
|
return code, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ───────────────────────── HTTP handlers ─────────────────────────
|
||||||
|
|
||||||
|
func (s *Server) handleDiscovery(w http.ResponseWriter, r *http.Request) {
|
||||||
|
s.rwmu.RLock()
|
||||||
|
issuer := s.Issuer
|
||||||
|
s.rwmu.RUnlock()
|
||||||
|
|
||||||
|
doc := map[string]any{
|
||||||
|
"issuer": issuer,
|
||||||
|
"authorization_endpoint": s.URL + "/authorize",
|
||||||
|
"token_endpoint": s.URL + "/oauth/token",
|
||||||
|
"jwks_uri": s.URL + "/jwks",
|
||||||
|
"response_types_supported": []string{"code"},
|
||||||
|
"id_token_signing_alg_values_supported": []string{"RS256"},
|
||||||
|
"subject_types_supported": []string{"public"},
|
||||||
|
"scopes_supported": []string{"openid", "email", "profile"},
|
||||||
|
"code_challenge_methods_supported": []string{"S256"},
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, doc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handleJWKS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
jwks := map[string]any{
|
||||||
|
"keys": []map[string]any{
|
||||||
|
rsaPublicKeyToJWK(&s.PrivateKey.PublicKey, s.KeyID),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, jwks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAuthorize 模擬「使用者打開 /authorize → 同意登入 → IdP redirect 回 client redirect_uri 帶 code」。
|
||||||
|
//
|
||||||
|
// 為了讓測試可以「不開瀏覽器」就跑通整段,我們不顯示登入頁,
|
||||||
|
// 而是直接把 code 帶上 redirect_uri 立刻 302 回去(等同「使用者已存在 SSO 並自動同意」)。
|
||||||
|
//
|
||||||
|
// 我們同時把 code 與當時的 PKCE challenge / nonce 綁起來,
|
||||||
|
// 後續 /oauth/token 才能驗 PKCE proof,符合真 IdP 行為。
|
||||||
|
func (s *Server) handleAuthorize(w http.ResponseWriter, r *http.Request) {
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
s.rwmu.Lock()
|
||||||
|
s.lastAuthorizeQuery = cloneValues(q)
|
||||||
|
s.rwmu.Unlock()
|
||||||
|
|
||||||
|
redirectURI := q.Get("redirect_uri")
|
||||||
|
state := q.Get("state")
|
||||||
|
challenge := q.Get("code_challenge")
|
||||||
|
challengeMethod := q.Get("code_challenge_method")
|
||||||
|
nonce := q.Get("nonce")
|
||||||
|
clientID := q.Get("client_id")
|
||||||
|
|
||||||
|
// 基本驗:redirect_uri / client_id 缺則 400。
|
||||||
|
// 真 IdP 還會驗 redirect_uri 是否在註冊白名單;fake server 簡化掉,反正測試 caller 一定會
|
||||||
|
// 帶正確的 redirect_uri。
|
||||||
|
if redirectURI == "" || clientID == "" {
|
||||||
|
http.Error(w, "missing redirect_uri or client_id", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 簽發 code(記 challenge / nonce 進 codeStore)
|
||||||
|
code, err := s.IssueAuthCode(challenge, challengeMethod, nonce, redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "issue code failed: "+err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 組 callback URL:redirect_uri?code=<code>&state=<state>
|
||||||
|
cbURL, err := url.Parse(redirectURI)
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, "invalid redirect_uri: "+err.Error(), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rq := cbURL.Query()
|
||||||
|
rq.Set("code", code)
|
||||||
|
if state != "" {
|
||||||
|
rq.Set("state", state)
|
||||||
|
}
|
||||||
|
cbURL.RawQuery = rq.Encode()
|
||||||
|
|
||||||
|
http.Redirect(w, r, cbURL.String(), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleToken 處理 POST /oauth/token(authorization_code grant)。
|
||||||
|
//
|
||||||
|
// 流程:
|
||||||
|
// 1. 驗 client_id / client_secret
|
||||||
|
// 2. 驗 grant_type == authorization_code
|
||||||
|
// 3. 從 codeStore 取出 code 對應的 challenge / nonce
|
||||||
|
// 4. 驗 PKCE:sha256(code_verifier) == challenge(base64url 比對)
|
||||||
|
// 5. 簽 id_token(用 nextIDTokenClaims 或預設 claims;nonce 自動補入)
|
||||||
|
// 6. 回 token response
|
||||||
|
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
writeOAuthError(w, http.StatusBadRequest, "invalid_request", "parse form: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTP Basic auth or form fields — 兩種都支援,跟真 OIDC IdP 一致
|
||||||
|
clientID, clientSecret := extractClientCredentials(r)
|
||||||
|
|
||||||
|
s.rwmu.Lock()
|
||||||
|
s.lastTokenForm = cloneValues(r.Form)
|
||||||
|
s.rwmu.Unlock()
|
||||||
|
|
||||||
|
// ─── 1. client credentials ───
|
||||||
|
if clientID != s.ClientID || clientSecret != s.ClientSecret {
|
||||||
|
writeOAuthError(w, http.StatusUnauthorized, "invalid_client", "client credentials mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 2. grant type ───
|
||||||
|
if gt := r.Form.Get("grant_type"); gt != "authorization_code" {
|
||||||
|
writeOAuthError(w, http.StatusBadRequest, "unsupported_grant_type", "got "+gt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 3. code 取對應 metadata ───
|
||||||
|
code := r.Form.Get("code")
|
||||||
|
if code == "" {
|
||||||
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "missing code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.rwmu.Lock()
|
||||||
|
meta, ok := s.codeStore[code]
|
||||||
|
if ok {
|
||||||
|
// 真 IdP 的 code 是「一次性」— 用過就刪,避免 replay。
|
||||||
|
delete(s.codeStore, code)
|
||||||
|
}
|
||||||
|
s.rwmu.Unlock()
|
||||||
|
if !ok {
|
||||||
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "unknown or already-used code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 4. PKCE ───
|
||||||
|
verifier := r.Form.Get("code_verifier")
|
||||||
|
if meta.CodeChallenge != "" {
|
||||||
|
if verifier == "" {
|
||||||
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "code_verifier required (PKCE was used)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !verifyPKCE(verifier, meta.CodeChallenge, meta.CodeChallengeMethod) {
|
||||||
|
writeOAuthError(w, http.StatusBadRequest, "invalid_grant", "PKCE verifier mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── 5. 簽 id_token ───
|
||||||
|
s.rwmu.RLock()
|
||||||
|
customClaims := cloneClaims(s.nextIDTokenClaims)
|
||||||
|
accessToken := s.nextAccessToken
|
||||||
|
issuer := s.Issuer
|
||||||
|
s.rwmu.RUnlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
claims := map[string]any{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": s.ClientID,
|
||||||
|
"iat": now.Unix(),
|
||||||
|
"exp": now.Add(5 * time.Minute).Unix(),
|
||||||
|
"nbf": now.Unix(),
|
||||||
|
"sub": "sub-fake-user-001",
|
||||||
|
"email": "fake-user@example.com",
|
||||||
|
"name": "Fake User",
|
||||||
|
}
|
||||||
|
// 把 caller 指定的 claims 蓋過預設(sub/email/name 等)
|
||||||
|
for k, v := range customClaims {
|
||||||
|
claims[k] = v
|
||||||
|
}
|
||||||
|
// 永遠把 authorize 收到的 nonce 灌回去(除非 caller 已經自行指定)
|
||||||
|
if _, has := claims["nonce"]; !has && meta.Nonce != "" {
|
||||||
|
claims["nonce"] = meta.Nonce
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := signJWT(s.PrivateKey, s.KeyID, jose.RS256, claims)
|
||||||
|
if err != nil {
|
||||||
|
writeOAuthError(w, http.StatusInternalServerError, "server_error", "sign id_token: "+err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := map[string]any{
|
||||||
|
"access_token": accessToken,
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"id_token": idToken,
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ───────────────────────── helpers ─────────────────────────
|
||||||
|
|
||||||
|
// extractClientCredentials 從 HTTP Basic auth 或 form field 取出 client_id / client_secret。
|
||||||
|
// 真 IdP 通常兩種都接(RFC 6749 §2.3.1),fake server 也比照辦理。
|
||||||
|
func extractClientCredentials(r *http.Request) (string, string) {
|
||||||
|
if cid, csec, ok := r.BasicAuth(); ok && cid != "" {
|
||||||
|
return cid, csec
|
||||||
|
}
|
||||||
|
return r.Form.Get("client_id"), r.Form.Get("client_secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyPKCE 對照 RFC 7636 §4.6 規定驗 code_verifier 對 challenge:
|
||||||
|
//
|
||||||
|
// S256: BASE64URL(SHA256(verifier)) == challenge
|
||||||
|
//
|
||||||
|
// 不支援 plain(OAuth 2.1 已 deprecated;fake server 也只認 S256)。
|
||||||
|
func verifyPKCE(verifier, challenge, method string) bool {
|
||||||
|
if method == "" {
|
||||||
|
method = "S256" // 真 IdP 的預設可能不同,但 visionA 一律用 S256
|
||||||
|
}
|
||||||
|
if method != "S256" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
expected := pkceS256(verifier)
|
||||||
|
return expected == challenge
|
||||||
|
}
|
||||||
|
|
||||||
|
// pkceS256:BASE64URL(SHA256(verifier))。重新實作而不 import internal/oidc,
|
||||||
|
// 保持 oidctest 不依賴 production package(避免循環依賴 + 確保 oidctest 可被
|
||||||
|
// internal/oidc 自己未來想用而不打死)。
|
||||||
|
func pkceS256(verifier string) string {
|
||||||
|
sum := sha256.Sum256([]byte(verifier))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeOAuthError 寫 RFC 6749 §5.2 規範的 token error response。
|
||||||
|
func writeOAuthError(w http.ResponseWriter, status int, code, desc string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||||
|
"error": code,
|
||||||
|
"error_description": desc,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneValues 複製 url.Values,避免 caller 改到我們存的觀測值。
|
||||||
|
func cloneValues(in url.Values) url.Values {
|
||||||
|
if in == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(url.Values, len(in))
|
||||||
|
for k, vs := range in {
|
||||||
|
cp := make([]string, len(vs))
|
||||||
|
copy(cp, vs)
|
||||||
|
out[k] = cp
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneClaims 複製 claims map,避免並發 race。
|
||||||
|
func cloneClaims(in map[string]any) map[string]any {
|
||||||
|
if in == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make(map[string]any, len(in))
|
||||||
|
for k, v := range in {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// rsaPublicKeyToJWK 把 RSA public key 編成 JWKS spec 的 key 物件。
|
||||||
|
// 與 OB1 的 fakeOIDC 寫法一致(base64url 無 padding;exponent 手動轉 byte slice)。
|
||||||
|
func rsaPublicKeyToJWK(pub *rsa.PublicKey, kid string) map[string]any {
|
||||||
|
return map[string]any{
|
||||||
|
"kty": "RSA",
|
||||||
|
"alg": "RS256",
|
||||||
|
"use": "sig",
|
||||||
|
"kid": kid,
|
||||||
|
"n": base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
|
||||||
|
"e": base64.RawURLEncoding.EncodeToString(bigIntBytes(pub.E)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bigIntBytes(e int) []byte {
|
||||||
|
out := []byte{}
|
||||||
|
for e > 0 {
|
||||||
|
out = append([]byte{byte(e & 0xff)}, out...)
|
||||||
|
e >>= 8
|
||||||
|
}
|
||||||
|
if len(out) == 0 {
|
||||||
|
out = []byte{0}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// signJWT 用 RSA private key 簽出 RS256 JWT。
|
||||||
|
//
|
||||||
|
// 接受 map[string]any 而非 jwt.Claims struct,方便 caller 灌任意 claim
|
||||||
|
// (包含 OIDC 的 sub/email/name 與測試用的非標準欄位)。
|
||||||
|
func signJWT(priv *rsa.PrivateKey, kid string, alg jose.SignatureAlgorithm, claims map[string]any) (string, error) {
|
||||||
|
signerOpts := (&jose.SignerOptions{}).WithType("JWT")
|
||||||
|
signerOpts.WithHeader("kid", kid)
|
||||||
|
signer, err := jose.NewSigner(jose.SigningKey{Algorithm: alg, Key: priv}, signerOpts)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("oidctest: new signer: %w", err)
|
||||||
|
}
|
||||||
|
tok, err := jwt.Signed(signer).Claims(claims).Serialize()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("oidctest: sign jwt: %w", err)
|
||||||
|
}
|
||||||
|
return tok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// randomURLToken 產生 base64url 編碼的隨機 token,給 authorization code 用。
|
||||||
|
func randomURLToken(nBytes int) (string, error) {
|
||||||
|
b := make([]byte, nBytes)
|
||||||
|
if _, err := rand.Read(b); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||||
|
}
|
||||||
428
visionA-backend/internal/oidctest/server_test.go
Normal file
428
visionA-backend/internal/oidctest/server_test.go
Normal file
@ -0,0 +1,428 @@
|
|||||||
|
// Package oidctest 的自我驗證測試。
|
||||||
|
//
|
||||||
|
// 這個檔案只測 oidctest package「自己」的行為,不涉及 visionA-backend 任何 production code。
|
||||||
|
// 目的:在 oidctest 被 e2e test 大量依賴之前,先單獨確保它每個 endpoint 都符合預期 —
|
||||||
|
// 否則 e2e test 失敗時很難區分「OIDC client 寫錯」還是「fake server 寫錯」。
|
||||||
|
//
|
||||||
|
// 不重複測 OB1 的 internal/oidc/provider_test.go 已涵蓋的「provider 串接 fake server」場景,
|
||||||
|
// 這邊純粹從 HTTP wire 層驗 fake server 的對外 contract。
|
||||||
|
|
||||||
|
package oidctest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v4"
|
||||||
|
"github.com/go-jose/go-jose/v4/jwt"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testRedirectURI = "http://localhost:8080/api/auth/callback"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ───────────────────────── Discovery ─────────────────────────
|
||||||
|
|
||||||
|
func TestServer_Discovery_Endpoints(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
resp, err := http.Get(srv.URL + "/.well-known/openid-configuration")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var doc map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&doc))
|
||||||
|
|
||||||
|
// issuer 必須等於 server URL(除非 caller 用 WithIssuer 覆蓋)
|
||||||
|
assert.Equal(t, srv.URL, doc["issuer"])
|
||||||
|
|
||||||
|
// 各 endpoint 都應指向 server URL
|
||||||
|
assert.Equal(t, srv.URL+"/authorize", doc["authorization_endpoint"])
|
||||||
|
assert.Equal(t, srv.URL+"/oauth/token", doc["token_endpoint"])
|
||||||
|
assert.Equal(t, srv.URL+"/jwks", doc["jwks_uri"])
|
||||||
|
|
||||||
|
// 必要支援列表
|
||||||
|
assert.Contains(t, doc["response_types_supported"], "code")
|
||||||
|
assert.Contains(t, doc["id_token_signing_alg_values_supported"], "RS256")
|
||||||
|
assert.Contains(t, doc["code_challenge_methods_supported"], "S256")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Discovery_RespectsWithIssuer(t *testing.T) {
|
||||||
|
const customIssuer = "https://example.com/custom-issuer"
|
||||||
|
srv := NewServer(t, WithIssuer(customIssuer))
|
||||||
|
|
||||||
|
resp, err := http.Get(srv.URL + "/.well-known/openid-configuration")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
var doc map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&doc))
|
||||||
|
assert.Equal(t, customIssuer, doc["issuer"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ───────────────────────── JWKS ─────────────────────────
|
||||||
|
|
||||||
|
// TestServer_JWKS_CanVerifyServerSignedToken 確認 JWKS endpoint 公佈的 public key
|
||||||
|
// 真的能驗 fake server 用 IssueIDToken 簽出來的 token。這是 fake server 整體 contract
|
||||||
|
// 中「最關鍵」的一環 — 一旦這條路斷掉,所有 e2e test 都會失敗。
|
||||||
|
func TestServer_JWKS_CanVerifyServerSignedToken(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
// 1. 從 JWKS 取 public key
|
||||||
|
resp, err := http.Get(srv.URL + "/jwks")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var jwks struct {
|
||||||
|
Keys []map[string]any `json:"keys"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&jwks))
|
||||||
|
require.Len(t, jwks.Keys, 1)
|
||||||
|
|
||||||
|
jwk := jwks.Keys[0]
|
||||||
|
assert.Equal(t, "RSA", jwk["kty"])
|
||||||
|
assert.Equal(t, "RS256", jwk["alg"])
|
||||||
|
assert.Equal(t, srv.KeyID, jwk["kid"])
|
||||||
|
|
||||||
|
// 2. 把 JWK reconstruct 成 *rsa.PublicKey
|
||||||
|
pub := jwkToRSAPublicKey(t, jwk)
|
||||||
|
|
||||||
|
// 3. 請 server 簽一個 token
|
||||||
|
tok, err := srv.IssueIDToken(map[string]any{
|
||||||
|
"sub": "verify-test",
|
||||||
|
"iss": srv.URL,
|
||||||
|
"aud": srv.ClientID,
|
||||||
|
"exp": time.Now().Add(time.Minute).Unix(),
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 4. 用 jose 驗簽
|
||||||
|
parsed, err := jwt.ParseSigned(tok, []jose.SignatureAlgorithm{jose.RS256})
|
||||||
|
require.NoError(t, err, "簽出的 token 應為合法 JWS")
|
||||||
|
|
||||||
|
var out map[string]any
|
||||||
|
require.NoError(t, parsed.Claims(pub, &out), "JWKS 公開的 public key 應能驗 server 自己簽的 token")
|
||||||
|
assert.Equal(t, "verify-test", out["sub"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// ───────────────────────── /oauth/token ─────────────────────────
|
||||||
|
|
||||||
|
func TestServer_Token_RejectsWrongClientSecret(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
// 先模擬 authorize 拿一個 code(避免 invalid_grant 先擋掉)
|
||||||
|
verifier := "verifier-xyz-1234567890123456789012345"
|
||||||
|
challenge := pkceS256(verifier)
|
||||||
|
code, err := srv.IssueAuthCode(challenge, "S256", "n-1", testRedirectURI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 用「錯的 client_secret」打 token endpoint
|
||||||
|
resp := postToken(t, srv, url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"code_verifier": {verifier},
|
||||||
|
"client_id": {srv.ClientID},
|
||||||
|
"client_secret": {"wrong-secret"},
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
|
||||||
|
var errBody map[string]string
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&errBody))
|
||||||
|
assert.Equal(t, "invalid_client", errBody["error"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Token_AcceptsBasicAuth(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
verifier := "verifier-basicauth-12345678901234567890123"
|
||||||
|
challenge := pkceS256(verifier)
|
||||||
|
code, err := srv.IssueAuthCode(challenge, "S256", "n-2", testRedirectURI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"code_verifier": {verifier},
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest(http.MethodPost, srv.URL+"/oauth/token",
|
||||||
|
strings.NewReader(form.Encode()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.SetBasicAuth(srv.ClientID, srv.ClientSecret)
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode, "Basic auth 應被接受")
|
||||||
|
|
||||||
|
var tokResp map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&tokResp))
|
||||||
|
assert.NotEmpty(t, tokResp["id_token"])
|
||||||
|
assert.NotEmpty(t, tokResp["access_token"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Token_PKCEMatch(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
verifier := "verifier-good-1234567890abcdefghij1234567"
|
||||||
|
challenge := pkceS256(verifier)
|
||||||
|
code, err := srv.IssueAuthCode(challenge, "S256", "nonce-good", testRedirectURI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := postToken(t, srv, url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"code_verifier": {verifier},
|
||||||
|
"client_id": {srv.ClientID},
|
||||||
|
"client_secret": {srv.ClientSecret},
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var tokResp map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&tokResp))
|
||||||
|
assert.NotEmpty(t, tokResp["id_token"])
|
||||||
|
|
||||||
|
// nonce 應被灌到 id_token claims 裡
|
||||||
|
idToken := tokResp["id_token"].(string)
|
||||||
|
claims := decodeJWTPayload(t, idToken)
|
||||||
|
assert.Equal(t, "nonce-good", claims["nonce"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Token_PKCEMismatch(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
correct := "verifier-A-1234567890abcdefghij12345678"
|
||||||
|
wrong := "verifier-B-1234567890abcdefghij12345678"
|
||||||
|
|
||||||
|
code, err := srv.IssueAuthCode(pkceS256(correct), "S256", "n-3", testRedirectURI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := postToken(t, srv, url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"code_verifier": {wrong}, // 故意錯
|
||||||
|
"client_id": {srv.ClientID},
|
||||||
|
"client_secret": {srv.ClientSecret},
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
|
||||||
|
var errBody map[string]string
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&errBody))
|
||||||
|
assert.Equal(t, "invalid_grant", errBody["error"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Token_CodeIsOneTimeUse(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
verifier := "verifier-once-1234567890abcdefghij12345"
|
||||||
|
code, err := srv.IssueAuthCode(pkceS256(verifier), "S256", "n-4", testRedirectURI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
form := url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"code_verifier": {verifier},
|
||||||
|
"client_id": {srv.ClientID},
|
||||||
|
"client_secret": {srv.ClientSecret},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一次:成功
|
||||||
|
r1 := postToken(t, srv, form)
|
||||||
|
require.Equal(t, http.StatusOK, r1.StatusCode)
|
||||||
|
r1.Body.Close()
|
||||||
|
|
||||||
|
// 第二次:同個 code,invalid_grant
|
||||||
|
r2 := postToken(t, srv, form)
|
||||||
|
defer r2.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadRequest, r2.StatusCode)
|
||||||
|
|
||||||
|
var errBody map[string]string
|
||||||
|
require.NoError(t, json.NewDecoder(r2.Body).Decode(&errBody))
|
||||||
|
assert.Equal(t, "invalid_grant", errBody["error"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Token_AppliesNextIDTokenClaims(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
srv.SetNextIDTokenClaims(map[string]any{
|
||||||
|
"sub": "user-overridden",
|
||||||
|
"email": "override@example.com",
|
||||||
|
"name": "Override User",
|
||||||
|
})
|
||||||
|
|
||||||
|
verifier := "verifier-claims-1234567890abcdefghij1234"
|
||||||
|
code, err := srv.IssueAuthCode(pkceS256(verifier), "S256", "nonce-claims", testRedirectURI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := postToken(t, srv, url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"code_verifier": {verifier},
|
||||||
|
"client_id": {srv.ClientID},
|
||||||
|
"client_secret": {srv.ClientSecret},
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var tokResp map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&tokResp))
|
||||||
|
|
||||||
|
claims := decodeJWTPayload(t, tokResp["id_token"].(string))
|
||||||
|
assert.Equal(t, "user-overridden", claims["sub"])
|
||||||
|
assert.Equal(t, "override@example.com", claims["email"])
|
||||||
|
assert.Equal(t, "Override User", claims["name"])
|
||||||
|
assert.Equal(t, "nonce-claims", claims["nonce"], "nonce 仍會自動補入")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ───────────────────────── /authorize ─────────────────────────
|
||||||
|
|
||||||
|
func TestServer_Authorize_RedirectsWithCodeAndState(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
authorizeURL := srv.URL + "/authorize?" + url.Values{
|
||||||
|
"response_type": {"code"},
|
||||||
|
"client_id": {srv.ClientID},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"scope": {"openid email profile"},
|
||||||
|
"state": {"state-abc"},
|
||||||
|
"code_challenge": {pkceS256("any-verifier")},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
"nonce": {"nonce-abc"},
|
||||||
|
}.Encode()
|
||||||
|
|
||||||
|
cb := srv.SimulateAuthorizationFlow(t, authorizeURL)
|
||||||
|
|
||||||
|
u, err := url.Parse(cb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// callback 應為 redirect_uri,帶 code & state
|
||||||
|
assert.Equal(t, "http", u.Scheme)
|
||||||
|
assert.Equal(t, "localhost:8080", u.Host)
|
||||||
|
assert.Equal(t, "/api/auth/callback", u.Path)
|
||||||
|
assert.NotEmpty(t, u.Query().Get("code"))
|
||||||
|
assert.Equal(t, "state-abc", u.Query().Get("state"))
|
||||||
|
|
||||||
|
// LastAuthorizeQuery 應記下 caller 帶的 PKCE / nonce
|
||||||
|
q := srv.LastAuthorizeQuery()
|
||||||
|
assert.Equal(t, "state-abc", q.Get("state"))
|
||||||
|
assert.Equal(t, "nonce-abc", q.Get("nonce"))
|
||||||
|
assert.Equal(t, "S256", q.Get("code_challenge_method"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_Authorize_MissingRedirectURIReturns400(t *testing.T) {
|
||||||
|
srv := NewServer(t)
|
||||||
|
|
||||||
|
// 沒帶 redirect_uri
|
||||||
|
resp, err := http.Get(srv.URL + "/authorize?client_id=" + srv.ClientID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ───────────────────────── WithClientCredentials ─────────────────────────
|
||||||
|
|
||||||
|
func TestServer_WithClientCredentials(t *testing.T) {
|
||||||
|
srv := NewServer(t,
|
||||||
|
WithClientCredentials("custom-client", "custom-secret"),
|
||||||
|
)
|
||||||
|
assert.Equal(t, "custom-client", srv.ClientID)
|
||||||
|
assert.Equal(t, "custom-secret", srv.ClientSecret)
|
||||||
|
|
||||||
|
// 用預設 secret 應失敗
|
||||||
|
verifier := "verifier-cc-1234567890abcdefghij12345678"
|
||||||
|
code, err := srv.IssueAuthCode(pkceS256(verifier), "S256", "n", testRedirectURI)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
resp := postToken(t, srv, url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": {testRedirectURI},
|
||||||
|
"code_verifier": {verifier},
|
||||||
|
"client_id": {"custom-client"},
|
||||||
|
"client_secret": {"test-secret"}, // 預設值,不是 custom-secret
|
||||||
|
})
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ───────────────────────── helpers ─────────────────────────
|
||||||
|
|
||||||
|
func postToken(t *testing.T, srv *Server, form url.Values) *http.Response {
|
||||||
|
t.Helper()
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost,
|
||||||
|
srv.URL+"/oauth/token", strings.NewReader(form.Encode()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
// jwkToRSAPublicKey 把 JWK map reconstruct 成 *rsa.PublicKey。
|
||||||
|
// 這是 OB1 fakeOIDC 沒做的(OB1 直接用 coreos lib 內部把 JWKS 解出來),
|
||||||
|
// 我們用最少程式碼自己 decode 一次驗證 server 的 JWKS contract。
|
||||||
|
func jwkToRSAPublicKey(t *testing.T, jwk map[string]any) *rsa.PublicKey {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
nB64, _ := jwk["n"].(string)
|
||||||
|
eB64, _ := jwk["e"].(string)
|
||||||
|
require.NotEmpty(t, nB64)
|
||||||
|
require.NotEmpty(t, eB64)
|
||||||
|
|
||||||
|
nBytes, err := base64.RawURLEncoding.DecodeString(nB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
eBytes, err := base64.RawURLEncoding.DecodeString(eB64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// e 是大端 byte slice → int
|
||||||
|
e := 0
|
||||||
|
for _, b := range eBytes {
|
||||||
|
e = e<<8 | int(b)
|
||||||
|
}
|
||||||
|
n := new(big.Int).SetBytes(nBytes)
|
||||||
|
|
||||||
|
return &rsa.PublicKey{N: n, E: e}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeJWTPayload 取 JWT 中間段的 payload 解 JSON 出 claims。
|
||||||
|
// 不驗簽(呼叫者已在別處驗過簽章)。
|
||||||
|
func decodeJWTPayload(t *testing.T, tok string) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
parts := strings.Split(tok, ".")
|
||||||
|
require.Len(t, parts, 3, "JWT 應為 3 段 (header.payload.signature)")
|
||||||
|
raw, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
require.NoError(t, err)
|
||||||
|
var out map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(raw, &out))
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// 確保 io 沒被當成 unused(某些版本 lint 嚴苛)— 故意使用一次。
|
||||||
|
var _ = io.EOF
|
||||||
0
visionA-backend/internal/relay/.gitkeep
Normal file
0
visionA-backend/internal/relay/.gitkeep
Normal file
211
visionA-backend/internal/relay/integration_raw_test.go
Normal file
211
visionA-backend/internal/relay/integration_raw_test.go
Normal file
@ -0,0 +1,211 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestEndToEnd_RawForward 驗證 B3 Review Major 1 修復後新增的
|
||||||
|
// `POST /internal/forward/raw` endpoint 能走完 raw TCP forwarding 路徑,
|
||||||
|
// 並支援 streaming body(MJPEG / chunked)。
|
||||||
|
//
|
||||||
|
// 路徑:
|
||||||
|
//
|
||||||
|
// fake api-server (raw TCP dial)
|
||||||
|
// └─► POST /internal/forward/raw ──► remote-proxy internal server
|
||||||
|
// └─► Hijack + OpenStream + 雙向 io.Copy
|
||||||
|
// └─► fake tunnel client (yamux stream)
|
||||||
|
// └─► fake local server(chunked response)
|
||||||
|
//
|
||||||
|
// 驗證重點:
|
||||||
|
// 1. 「HTTP/1.1 200 Connected」握手成功
|
||||||
|
// 2. 完整 HTTP request 能寫進 hijacked 連線 → local server 收到
|
||||||
|
// 3. Response status / headers / body 能正確回來
|
||||||
|
// 4. Chunked streaming body 的 trailing chunks 也能收完(不像 JSON 版會一次收完)
|
||||||
|
func TestEndToEnd_RawForward(t *testing.T) {
|
||||||
|
// 1. Fake local server — 回 chunked streaming body 模擬 MJPEG / SSE
|
||||||
|
localSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("X-Test-Route", r.URL.Path)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
flusher, _ := w.(http.Flusher)
|
||||||
|
// 送 3 個 chunk 模擬 streaming
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
fmt.Fprintf(w, "data: chunk-%d\n\n", i)
|
||||||
|
if flusher != nil {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer localSrv.Close()
|
||||||
|
localAddr := strings.TrimPrefix(localSrv.URL, "http://")
|
||||||
|
|
||||||
|
// 2. 起 remote-proxy(tunnel + internal server)
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
relaySrv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
internalSrv := NewInternalServer(store, slog.Default())
|
||||||
|
|
||||||
|
tunnelMux := http.NewServeMux()
|
||||||
|
tunnelMux.HandleFunc("/tunnel/connect", relaySrv.HandleTunnelConnect)
|
||||||
|
tunnelTS := httptest.NewServer(tunnelMux)
|
||||||
|
defer tunnelTS.Close()
|
||||||
|
|
||||||
|
internalMux := http.NewServeMux()
|
||||||
|
internalSrv.Routes(internalMux)
|
||||||
|
internalTS := httptest.NewServer(internalMux)
|
||||||
|
defer internalTS.Close()
|
||||||
|
|
||||||
|
// 3. Fake tunnel client — 把 stream 收到的 HTTP request 真 TCP 轉發給 localSrv
|
||||||
|
const token = "vAc_cafecafecafecafecafecafecafecafe"
|
||||||
|
stopTunnel := startTunnelClientForwardingTo(t, tunnelTS.URL, token, localAddr)
|
||||||
|
defer stopTunnel()
|
||||||
|
|
||||||
|
// 4. 等 session register
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), token)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
// 5. 模擬 api-server 端:raw TCP dial → hijack 握手 → 送 HTTP request → 讀 response
|
||||||
|
conn := dialRawForward(t, internalTS.URL, token)
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// 送一個真正的 HTTP GET / (走完整的 RFC 7230 格式,local agent 要會 parse)
|
||||||
|
reqLine := "GET /api/stream HTTP/1.1\r\n" +
|
||||||
|
"Host: 127.0.0.1\r\n" +
|
||||||
|
"X-From-Api-Server: raw-test\r\n" +
|
||||||
|
"Accept: text/event-stream\r\n" +
|
||||||
|
"\r\n"
|
||||||
|
_, werr := conn.Write([]byte(reqLine))
|
||||||
|
require.NoError(t, werr)
|
||||||
|
|
||||||
|
// 讀 response — 用 http.ReadResponse 解析 chunked body
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
httpReq, _ := http.NewRequest(http.MethodGet, "/api/stream", nil)
|
||||||
|
resp, rerr := http.ReadResponse(reader, httpReq)
|
||||||
|
require.NoError(t, rerr)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
assert.Equal(t, "/api/stream", resp.Header.Get("X-Test-Route"),
|
||||||
|
"response header 應該被原封轉發")
|
||||||
|
assert.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
|
||||||
|
|
||||||
|
// 讀 streaming body — 驗證三個 chunk 都收到
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
bodyStr := string(body)
|
||||||
|
assert.Contains(t, bodyStr, "data: chunk-0")
|
||||||
|
assert.Contains(t, bodyStr, "data: chunk-1")
|
||||||
|
assert.Contains(t, bodyStr, "data: chunk-2")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEndToEnd_RawForward_TunnelDisconnected 當 token 不存在時,
|
||||||
|
// raw forward endpoint 應在 hijack 前回 502 JSON(可被一般 HTTP client 讀到)。
|
||||||
|
func TestEndToEnd_RawForward_TunnelDisconnected(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
internalSrv := NewInternalServer(store, slog.Default())
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
internalSrv.Routes(mux)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// 用一般 http client 打(沒 session 時還沒 hijack,會回一般 JSON response)
|
||||||
|
resp, err := http.Post(
|
||||||
|
ts.URL+"/internal/forward/raw?token=vAc_dddddddddddddddddddddddddddddddd",
|
||||||
|
"application/octet-stream",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEndToEnd_RawForward_MissingToken 沒帶 token 應回 400。
|
||||||
|
func TestEndToEnd_RawForward_MissingToken(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
internalSrv := NewInternalServer(store, slog.Default())
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
internalSrv.Routes(mux)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
resp, err := http.Post(ts.URL+"/internal/forward/raw", "application/octet-stream", nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// dialRawForward 模擬 api-server 端:raw TCP dial remote-proxy,
|
||||||
|
// 發一個帶 token 的 POST /internal/forward/raw 請求,讀取 "HTTP/1.1 200 Connected"
|
||||||
|
// 握手回應,然後回傳這條已經「接管為 raw TCP」的連線供 caller 直接 io 使用。
|
||||||
|
//
|
||||||
|
// 對齊 `HandleForwardRaw` 的協議。
|
||||||
|
func dialRawForward(t *testing.T, internalURL, token string) net.Conn {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
u, err := url.Parse(internalURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// TCP dial
|
||||||
|
conn, err := net.DialTimeout("tcp", u.Host, 5*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// 送 HTTP POST request(含 token query)
|
||||||
|
reqLine := fmt.Sprintf(
|
||||||
|
"POST /internal/forward/raw?token=%s HTTP/1.1\r\n"+
|
||||||
|
"Host: %s\r\n"+
|
||||||
|
"Content-Length: 0\r\n"+
|
||||||
|
"\r\n",
|
||||||
|
token, u.Host,
|
||||||
|
)
|
||||||
|
_, werr := conn.Write([]byte(reqLine))
|
||||||
|
require.NoError(t, werr)
|
||||||
|
|
||||||
|
// 讀握手行 — 預期 "HTTP/1.1 200 Connected\r\n\r\n"
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
statusLine, err := reader.ReadString('\n')
|
||||||
|
require.NoError(t, err, "failed to read status line")
|
||||||
|
require.Contains(t, statusLine, "200 Connected",
|
||||||
|
"expected 200 Connected, got: %q", statusLine)
|
||||||
|
|
||||||
|
// 讀掉空白行(header 結束)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
require.NoError(t, err)
|
||||||
|
if line == "\r\n" || line == "\n" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Buffer 裡可能還有 reader 預讀的資料 — 不影響,因為後續我們會用 bufio.NewReader(conn) 再讀
|
||||||
|
// 但為了避免 reader 裡殘留的預讀資料被吃掉,caller 要自己管 bufio.NewReader
|
||||||
|
// 這裡回傳 conn;caller 在 Write request 後,要用新的 bufio.NewReader(conn) 讀 response
|
||||||
|
//
|
||||||
|
// 注意:實務上 reader.Buffered() 應該是 0(server 還沒送 response body),
|
||||||
|
// 所以直接回 conn 即可。
|
||||||
|
assert.Equal(t, 0, reader.Buffered(),
|
||||||
|
"reader 不應該有預讀資料;若有則 caller 必須用此 reader 而非新建 bufio")
|
||||||
|
|
||||||
|
return conn
|
||||||
|
}
|
||||||
243
visionA-backend/internal/relay/integration_test.go
Normal file
243
visionA-backend/internal/relay/integration_test.go
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
"visiona-backend/internal/wsconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestEndToEnd_ForwardFromInternalToFakeLocal 模擬雛形完整 forwarding 路徑:
|
||||||
|
//
|
||||||
|
// api-server (internal HTTP client)
|
||||||
|
// └─► POST /internal/forward/http ──► remote-proxy (this test's HTTP server)
|
||||||
|
// └─► OpenStream 透過 tunnel ──► fake local agent
|
||||||
|
// └─► 轉發到 fake local server (in-process httptest.Server)
|
||||||
|
//
|
||||||
|
// 這是 B3 任務 prompt 要求的「integration test」:通則代表整個 tunnel forwarding
|
||||||
|
// 路徑可用,B4(api-server)與 B5(API handlers)可以安心往上疊。
|
||||||
|
//
|
||||||
|
// 比對點:
|
||||||
|
// - 請求能走完 internal → tunnel → local agent → local server
|
||||||
|
// - local server 的 response body 能被讀回 api-server 端(base64 解碼)
|
||||||
|
// - HTTP headers / status / body 都能保留
|
||||||
|
func TestEndToEnd_ForwardFromInternalToFakeLocal(t *testing.T) {
|
||||||
|
// 1. 起一個 fake local server(模擬 `127.0.0.1:3721`)
|
||||||
|
localSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Test-Route", r.URL.Path)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"method": r.Method,
|
||||||
|
"path": r.URL.Path,
|
||||||
|
"ok": true,
|
||||||
|
"receivedHeader": r.Header.Get("X-From-Api-Server"),
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer localSrv.Close()
|
||||||
|
localAddr := strings.TrimPrefix(localSrv.URL, "http://")
|
||||||
|
|
||||||
|
// 2. 起 remote-proxy 的兩個 server(tunnel + internal)
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
relaySrv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
internalSrv := NewInternalServer(store, slog.Default())
|
||||||
|
|
||||||
|
tunnelMux := http.NewServeMux()
|
||||||
|
tunnelMux.HandleFunc("/tunnel/connect", relaySrv.HandleTunnelConnect)
|
||||||
|
tunnelTS := httptest.NewServer(tunnelMux)
|
||||||
|
defer tunnelTS.Close()
|
||||||
|
|
||||||
|
internalMux := http.NewServeMux()
|
||||||
|
internalSrv.Routes(internalMux)
|
||||||
|
internalTS := httptest.NewServer(internalMux)
|
||||||
|
defer internalTS.Close()
|
||||||
|
|
||||||
|
// 3. 起 fake tunnel client(模擬 POC edge-ai-server 或未來的 local agent),
|
||||||
|
// 把 tunnel stream 上收到的 HTTP request 轉發給 localSrv。
|
||||||
|
const token = "vAc_feedbeeffeedbeeffeedbeeffeedbeef"
|
||||||
|
stop := startTunnelClientForwardingTo(t, tunnelTS.URL, token, localAddr)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
// 4. 等 session register
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), token)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
// 5. 透過 /internal/forward/http 送一個請求
|
||||||
|
reqBody := ForwardHTTPRequest{
|
||||||
|
SessionToken: token,
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/api/devices",
|
||||||
|
Headers: map[string]string{
|
||||||
|
"X-From-Api-Server": "test-value",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
bb, _ := json.Marshal(reqBody)
|
||||||
|
resp, err := http.Post(internalTS.URL+"/internal/forward/http",
|
||||||
|
"application/json", bytes.NewReader(bb))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var fr ForwardHTTPResponse
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&fr))
|
||||||
|
require.Nil(t, fr.Error, "forward should not error: %+v", fr.Error)
|
||||||
|
require.Equal(t, http.StatusOK, fr.Status)
|
||||||
|
|
||||||
|
// 6. 解 body 驗證
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(fr.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
var payload map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal(decoded, &payload))
|
||||||
|
assert.Equal(t, "GET", payload["method"])
|
||||||
|
assert.Equal(t, "/api/devices", payload["path"])
|
||||||
|
assert.Equal(t, true, payload["ok"])
|
||||||
|
assert.Equal(t, "test-value", payload["receivedHeader"],
|
||||||
|
"X-From-Api-Server header 應該被保留到 fake local server")
|
||||||
|
|
||||||
|
// 7. 驗證 response header 也被保留
|
||||||
|
if vals, ok := fr.Headers["X-Test-Route"]; ok {
|
||||||
|
assert.Equal(t, "/api/devices", vals[0])
|
||||||
|
} else {
|
||||||
|
t.Errorf("expected X-Test-Route header to propagate; got: %+v", fr.Headers)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestEndToEnd_Forward_TunnelDisconnected 當 token 沒對應 session 時回 error。
|
||||||
|
func TestEndToEnd_Forward_TunnelDisconnected(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
internalSrv := NewInternalServer(store, slog.Default())
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
internalSrv.Routes(mux)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
reqBody := ForwardHTTPRequest{
|
||||||
|
SessionToken: "vAc_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
bb, _ := json.Marshal(reqBody)
|
||||||
|
resp, err := http.Post(ts.URL+"/internal/forward/http",
|
||||||
|
"application/json", bytes.NewReader(bb))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||||
|
|
||||||
|
var fr ForwardHTTPResponse
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&fr))
|
||||||
|
require.NotNil(t, fr.Error)
|
||||||
|
assert.Equal(t, "TUNNEL_DISCONNECTED", fr.Error.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSessionCleanup_RemovesOnDisconnect 驗證當 local agent 斷線時
|
||||||
|
// session 會從 store 移除(tunnel.md §2.2 斷線處理)。
|
||||||
|
func TestSessionCleanup_RemovesOnDisconnect(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
relaySrv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/tunnel/connect", relaySrv.HandleTunnelConnect)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
const token = "vAc_bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
|
||||||
|
stop := startFakeLocalAgent(t, "ws"+strings.TrimPrefix(ts.URL, "http")+"/tunnel/connect",
|
||||||
|
token, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), token)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
stop() // 斷線
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), token)
|
||||||
|
return !ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond, "disconnected session should be unregistered")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
// Helper:起一個 fake tunnel client,把 tunnel stream 上收到的 HTTP request
|
||||||
|
// 透過真實 TCP 轉發給 localAddr(完整模擬 local agent 行為)
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
func startTunnelClientForwardingTo(t *testing.T, relayHTTPURL, token, localAddr string) func() {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(relayHTTPURL, "http") + "/tunnel/connect"
|
||||||
|
u, err := url.Parse(wsURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("token", token)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
rawWS, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
netConn := wsconn.New(rawWS)
|
||||||
|
ym, err := yamux.Client(netConn, yamux.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
stream, aerr := ym.Accept()
|
||||||
|
if aerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func(s net.Conn) {
|
||||||
|
defer s.Close()
|
||||||
|
req, rerr := http.ReadRequest(bufio.NewReader(s))
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 把 request forward 到 localAddr(真 TCP)
|
||||||
|
req.URL.Scheme = "http"
|
||||||
|
req.URL.Host = localAddr
|
||||||
|
req.Host = localAddr
|
||||||
|
req.RequestURI = ""
|
||||||
|
|
||||||
|
resp, rerr := http.DefaultTransport.RoundTrip(req)
|
||||||
|
if rerr != nil {
|
||||||
|
errResp := &http.Response{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
ProtoMajor: 1,
|
||||||
|
ProtoMinor: 1,
|
||||||
|
Header: make(http.Header),
|
||||||
|
Body: io.NopCloser(bytes.NewReader(nil)),
|
||||||
|
}
|
||||||
|
_ = errResp.Write(s)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
_ = resp.Write(s)
|
||||||
|
}(stream)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
_ = ym.Close()
|
||||||
|
_ = rawWS.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
356
visionA-backend/internal/relay/internal_forward.go
Normal file
356
visionA-backend/internal/relay/internal_forward.go
Normal file
@ -0,0 +1,356 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InternalServer 提供 api-server → remote-proxy 的 internal HTTP API。
|
||||||
|
//
|
||||||
|
// 對齊 `.autoflow/04-architecture/api/api-internal.md`:
|
||||||
|
// - POST /internal/forward/http — api-server 轉發非 WS 請求給指定 session
|
||||||
|
// - GET /internal/forward/ws — api-server 轉發 WS upgrade(Phase 0 暫為 stub)
|
||||||
|
// - GET /internal/session/:token — 查 session 是否存在與基本資訊
|
||||||
|
// - GET /internal/sessions — 列出所有在線 session(debug / metrics 用)
|
||||||
|
// - POST /internal/session/:token/close — 後台運維強制斷 tunnel
|
||||||
|
//
|
||||||
|
// 雛形安全性:只監聽 internal port(`VISIONA_PROXY_INTERNAL_PORT`,預設 3801),
|
||||||
|
// 生產環境須由網路層(security group / NetworkPolicy)阻擋外部存取。
|
||||||
|
// Phase 1 再加 mTLS / shared secret(見 api-internal.md §安全)。
|
||||||
|
type InternalServer struct {
|
||||||
|
store session.Store
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInternalServer 建立 internal HTTP handler。
|
||||||
|
func NewInternalServer(store session.Store, logger *slog.Logger) *InternalServer {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
return &InternalServer{store: store, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Routes 把所有 internal endpoints 註冊到 mux;caller 自行決定 listen 埠號。
|
||||||
|
//
|
||||||
|
// 兩個 forward endpoint 並存(B3 Review Major 1 修復):
|
||||||
|
// - `POST /internal/forward/http` — JSON + base64 封裝,適合簡單 JSON request/response(如 GET /healthz)
|
||||||
|
// - `POST /internal/forward/raw` — hijack 成 raw TCP,支援 streaming(MJPEG / SSE)、長連線、任意 HTTP
|
||||||
|
// 上 WS upgrade;`session.ProxyClient.OpenStream(ctx) net.Conn` 的真實底層(B4 用)
|
||||||
|
//
|
||||||
|
// 詳見 `.autoflow/04-architecture/api/api-internal.md`。
|
||||||
|
func (s *InternalServer) Routes(mux *http.ServeMux) {
|
||||||
|
mux.HandleFunc("/internal/forward/http", s.HandleForwardHTTP)
|
||||||
|
mux.HandleFunc("/internal/forward/raw", s.HandleForwardRaw)
|
||||||
|
mux.HandleFunc("/internal/forward/ws", s.HandleForwardWS)
|
||||||
|
mux.HandleFunc("/internal/session/", s.handleSessionByToken) // 含 :token 與 :token/close
|
||||||
|
mux.HandleFunc("/internal/sessions", s.HandleListSessions)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardHTTPRequest 是 api-server 丟給 /internal/forward/http 的請求 body(JSON)。
|
||||||
|
//
|
||||||
|
// 為了讓雛形簡單易測,我們採用 **JSON 結構化封裝** 的方式傳遞,而非 api-internal.md
|
||||||
|
// 所描述的「raw HTTP bytes」。兩者等價,未來可在不影響 API handler 的情況下切換。
|
||||||
|
// 這個 JSON 格式是 Phase 0 雛形的便利選擇(見 task B3 prompt)。
|
||||||
|
type ForwardHTTPRequest struct {
|
||||||
|
// SessionToken 是 local agent tunnel 的 token。
|
||||||
|
SessionToken string `json:"session_token"`
|
||||||
|
// Method 例:GET / POST。
|
||||||
|
Method string `json:"method"`
|
||||||
|
// Path 例:/api/devices(不含 scheme + host;local agent 會自己補)。
|
||||||
|
Path string `json:"path"`
|
||||||
|
// Headers 要帶的 HTTP headers。
|
||||||
|
Headers map[string]string `json:"headers,omitempty"`
|
||||||
|
// Body 是 base64 編碼的 request body;空字串 → 無 body。
|
||||||
|
Body string `json:"body,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardHTTPResponse 是 /internal/forward/http 的回應 body(JSON)。
|
||||||
|
type ForwardHTTPResponse struct {
|
||||||
|
Status int `json:"status"`
|
||||||
|
Headers map[string][]string `json:"headers,omitempty"`
|
||||||
|
// Body 是 base64 編碼的 response body。
|
||||||
|
Body string `json:"body,omitempty"`
|
||||||
|
// Error 在轉發過程失敗時填寫(tunnel 斷、stream 失敗等)。
|
||||||
|
Error *ForwardHTTPError `json:"error,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardHTTPError 描述轉發失敗的原因。
|
||||||
|
type ForwardHTTPError struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleForwardHTTP 實作 POST /internal/forward/http。
|
||||||
|
//
|
||||||
|
// 雛形採 JSON 格式(ForwardHTTPRequest/Response),適合簡單的一次性 JSON request/response。
|
||||||
|
// **不支援 streaming(MJPEG / SSE / chunked)與 WebSocket**;這類呼叫應改走
|
||||||
|
// `POST /internal/forward/raw`(B3 Review Major 1 修復後新增的 raw bytes endpoint)。
|
||||||
|
//
|
||||||
|
// 為何保留兩條路徑?
|
||||||
|
// - JSON 版:api-server 對簡單 API(如 GET /healthz、POST /api/devices)好寫好測
|
||||||
|
// - Raw 版:`session.ProxyClient.OpenStream(ctx) net.Conn` 的底層,streaming friendly
|
||||||
|
//
|
||||||
|
// 注意:此 handler 不支援 Flusher 串流回寫(JSON 封裝本質上不能串流)。
|
||||||
|
func (s *InternalServer) HandleForwardHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeJSONError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "POST required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req ForwardHTTPRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeJSONError(w, http.StatusBadRequest, "INVALID_JSON", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.SessionToken == "" {
|
||||||
|
writeJSONError(w, http.StatusBadRequest, "MISSING_SESSION_TOKEN", "session_token required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Method == "" {
|
||||||
|
req.Method = http.MethodGet
|
||||||
|
}
|
||||||
|
if req.Path == "" {
|
||||||
|
req.Path = "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, ferr := forwardOverTunnel(r.Context(), s.store, req)
|
||||||
|
if ferr != nil {
|
||||||
|
s.logger.Warn("internal forward failed",
|
||||||
|
"error", ferr.Error(),
|
||||||
|
"token_prefix", tokenPrefix(req.SessionToken),
|
||||||
|
"method", req.Method,
|
||||||
|
"path", req.Path)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if resp.Error != nil {
|
||||||
|
// 轉發錯誤 → 回 502(同 api-internal.md TUNNEL_DISCONNECTED 語意)
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleForwardWS 是 /internal/forward/ws 的 stub(Phase 0 雛形)。
|
||||||
|
//
|
||||||
|
// 完整實作需要 Hijack + WebSocket relay,與 HandleProxy.proxyWebSocket 類似;
|
||||||
|
// 為了不過度複雜化 B3,這裡先回 501;真正的 WS forward 在 B5 接入前端時補齊。
|
||||||
|
func (s *InternalServer) HandleForwardWS(w http.ResponseWriter, r *http.Request) {
|
||||||
|
writeJSONError(w, http.StatusNotImplemented, "NOT_IMPLEMENTED",
|
||||||
|
"WS forward stub — will be implemented in B5 when frontend connects")
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSessionByToken 分派 /internal/session/:token 與 /internal/session/:token/close。
|
||||||
|
func (s *InternalServer) handleSessionByToken(w http.ResponseWriter, r *http.Request) {
|
||||||
|
const prefix = "/internal/session/"
|
||||||
|
rest := r.URL.Path[len(prefix):]
|
||||||
|
if rest == "" {
|
||||||
|
writeJSONError(w, http.StatusBadRequest, "MISSING_TOKEN", "token required in path")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// /internal/session/:token/close
|
||||||
|
if idx := indexByte(rest, '/'); idx != -1 {
|
||||||
|
token := rest[:idx]
|
||||||
|
action := rest[idx+1:]
|
||||||
|
if action == "close" && r.Method == http.MethodPost {
|
||||||
|
s.closeSession(w, r, token)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSONError(w, http.StatusNotFound, "NOT_FOUND", "unknown action: "+action)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// /internal/session/:token
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
writeJSONError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "GET required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.getSession(w, r, rest)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InternalServer) getSession(w http.ResponseWriter, r *http.Request, token string) {
|
||||||
|
h, err := s.store.Lookup(r.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
writeJSONError(w, http.StatusNotFound, "NOT_FOUND", "session not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sum := h.Summary()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"token": sum.Token,
|
||||||
|
"connected": !h.IsClosed(),
|
||||||
|
"connected_at": sum.ConnectedAt,
|
||||||
|
"last_heartbeat": sum.LastHeartbeat,
|
||||||
|
"remote_addr": sum.RemoteAddr,
|
||||||
|
"user_id": sum.UserID,
|
||||||
|
"device_id": sum.DeviceID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *InternalServer) closeSession(w http.ResponseWriter, r *http.Request, token string) {
|
||||||
|
h, err := s.store.Lookup(r.Context(), token)
|
||||||
|
if err != nil {
|
||||||
|
writeJSONError(w, http.StatusNotFound, "NOT_FOUND", "session not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = h.Close()
|
||||||
|
_ = s.store.Unregister(r.Context(), token)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{"closed": true})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleListSessions 實作 GET /internal/sessions。
|
||||||
|
func (s *InternalServer) HandleListSessions(w http.ResponseWriter, r *http.Request) {
|
||||||
|
summaries, err := s.store.List(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "LIST_FAILED", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"sessions": summaries,
|
||||||
|
"total": len(summaries),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// forwardOverTunnel 是 /internal/forward/http 的核心:
|
||||||
|
// 1. store.Lookup 找到 handle
|
||||||
|
// 2. OpenStream
|
||||||
|
// 3. 在 stream 上組 HTTP request 並讀取 response
|
||||||
|
// 4. 封裝回 ForwardHTTPResponse
|
||||||
|
func forwardOverTunnel(ctx context.Context, store session.Store, req ForwardHTTPRequest) (ForwardHTTPResponse, error) {
|
||||||
|
h, err := store.Lookup(ctx, req.SessionToken)
|
||||||
|
if err != nil || h.IsClosed() {
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Error: &ForwardHTTPError{
|
||||||
|
Code: "TUNNEL_DISCONNECTED",
|
||||||
|
Message: "session not connected",
|
||||||
|
},
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := h.OpenStream(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Error: &ForwardHTTPError{
|
||||||
|
Code: "TUNNEL_ERROR",
|
||||||
|
Message: "open stream failed: " + err.Error(),
|
||||||
|
},
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
bodyBytes, err := decodeBase64(req.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Error: &ForwardHTTPError{
|
||||||
|
Code: "INVALID_BODY",
|
||||||
|
Message: "body base64 decode failed",
|
||||||
|
},
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpReq, err := http.NewRequest(req.Method, req.Path, bytesReader(bodyBytes))
|
||||||
|
if err != nil {
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Error: &ForwardHTTPError{
|
||||||
|
Code: "INVALID_REQUEST",
|
||||||
|
Message: "build request failed: " + err.Error(),
|
||||||
|
},
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
// local agent 會自行覆寫 Host;這裡只保留 "127.0.0.1"(placeholder)
|
||||||
|
httpReq.URL.Scheme = "http"
|
||||||
|
httpReq.URL.Host = "127.0.0.1"
|
||||||
|
httpReq.RequestURI = ""
|
||||||
|
httpReq.Host = "127.0.0.1"
|
||||||
|
if len(bodyBytes) > 0 {
|
||||||
|
httpReq.ContentLength = int64(len(bodyBytes))
|
||||||
|
}
|
||||||
|
for k, v := range req.Headers {
|
||||||
|
httpReq.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
// 設定 Close=false 保留長連線語意由 yamux / local agent 決定
|
||||||
|
httpReq.Close = false
|
||||||
|
|
||||||
|
if err := httpReq.Write(stream); err != nil {
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Error: &ForwardHTTPError{
|
||||||
|
Code: "TUNNEL_WRITE_ERROR",
|
||||||
|
Message: "write request to tunnel failed: " + err.Error(),
|
||||||
|
},
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
httpResp, err := http.ReadResponse(bufio.NewReader(stream), httpReq)
|
||||||
|
if err != nil {
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Error: &ForwardHTTPError{
|
||||||
|
Code: "TUNNEL_READ_ERROR",
|
||||||
|
Message: "read response from tunnel failed: " + err.Error(),
|
||||||
|
},
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
defer httpResp.Body.Close()
|
||||||
|
|
||||||
|
respBody, err := io.ReadAll(httpResp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Error: &ForwardHTTPError{
|
||||||
|
Code: "TUNNEL_READ_ERROR",
|
||||||
|
Message: "read response body failed: " + err.Error(),
|
||||||
|
},
|
||||||
|
}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ForwardHTTPResponse{
|
||||||
|
Status: httpResp.StatusCode,
|
||||||
|
Headers: httpResp.Header,
|
||||||
|
Body: encodeBase64(respBody),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// indexByte 找 rune b 在 s 的第一個位置;沒找到回 -1。
|
||||||
|
// 另寫是為了不 import strings / bytes,避免 forwardOverTunnel 附近多一個 dep。
|
||||||
|
func indexByte(s string, b byte) int {
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
if s[i] == b {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
|
||||||
|
// bytesReader 建立 io.Reader;body 為 nil / 空時回 nil(http.NewRequest 接受)。
|
||||||
|
func bytesReader(body []byte) io.Reader {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return bytes.NewReader(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeBase64 解碼 Forward request body;空字串回 nil。
|
||||||
|
func decodeBase64(s string) ([]byte, error) {
|
||||||
|
if s == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.DecodeString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeBase64 編碼 Forward response body。
|
||||||
|
func encodeBase64(b []byte) string {
|
||||||
|
if len(b) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return base64.StdEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
143
visionA-backend/internal/relay/internal_forward_raw.go
Normal file
143
visionA-backend/internal/relay/internal_forward_raw.go
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HandleForwardRaw 實作 POST /internal/forward/raw — 對齊 api-internal.md §POST /internal/forward/http
|
||||||
|
// 所描述的「raw HTTP bytes」行為。
|
||||||
|
//
|
||||||
|
// 與 `HandleForwardHTTP`(JSON + base64 封裝)的差別:
|
||||||
|
//
|
||||||
|
// ┌──────────────────────┬──────────────┬──────────────────────┐
|
||||||
|
// │ │ /forward/http │ /forward/raw │
|
||||||
|
// ├──────────────────────┼──────────────┼──────────────────────┤
|
||||||
|
// │ request 封裝 │ JSON + base64 │ hijack 成 raw TCP │
|
||||||
|
// │ 支援 streaming body │ ❌ │ ✅ │
|
||||||
|
// │ 支援 MJPEG / SSE │ ❌ │ ✅ │
|
||||||
|
// │ 支援 WebSocket-like │ ❌ │ ✅(只要走 HTTP bytes)│
|
||||||
|
// │ 適合場景 │ 簡單 JSON API │ ProxyClient.OpenStream│
|
||||||
|
// └──────────────────────┴──────────────┴──────────────────────┘
|
||||||
|
//
|
||||||
|
// 兩個 endpoint 同時存在是**刻意為之**:
|
||||||
|
// - JSON 版對於 api-server 一次性 JSON request/response(例如 GET /healthz)較好寫、好測
|
||||||
|
// - Raw 版是 `session.ProxyClient.OpenStream(ctx) net.Conn` 語意的真實底層
|
||||||
|
// (api-server 端會拿這條 hijacked 連線當 net.Conn 直接 `r.Write(conn)` + `http.ReadResponse(conn)`)
|
||||||
|
//
|
||||||
|
// 協議(API server 端怎麼用):
|
||||||
|
// 1. POST /internal/forward/raw?token=<session_token>
|
||||||
|
// (可不帶 body;hijack 在收到 request 後立刻做)
|
||||||
|
// 2. remote-proxy 找到 session → 寫回 `HTTP/1.1 200 Connected\r\n\r\n` 代表「session ready」
|
||||||
|
// → Hijack 自己的連線 → 從 yamux 開一個 stream → 雙向 io.Copy
|
||||||
|
// 3. API server 端拿到連線後,依照 HTTP 協定把完整 request 丟進去,local agent 回的 response
|
||||||
|
// bytes 會原封不動從同條連線讀回來;保留 chunked / streaming / WS upgrade 語意
|
||||||
|
//
|
||||||
|
// 雛形範例(api-server 端,B4 會實作):
|
||||||
|
//
|
||||||
|
// dial raw to /internal/forward/raw?token=xxx
|
||||||
|
// → 讀一行 "HTTP/1.1 200 Connected" + 空行
|
||||||
|
// → 拿下面那條 net.Conn:
|
||||||
|
// - r.Write(conn) // 送出 HTTP request
|
||||||
|
// - resp, _ := http.ReadResponse(bufio.NewReader(conn), r)
|
||||||
|
// - io.Copy(clientResponseWriter, resp.Body) // streaming 友善
|
||||||
|
//
|
||||||
|
// 失敗處理:
|
||||||
|
// - session 不存在 → 502 JSON(在 hijack 之前回 statusline + body)
|
||||||
|
// - hijack 不支援 → 500 JSON
|
||||||
|
// - OpenStream 失敗 → hijack 後寫回 `HTTP/1.1 502 Bad Gateway\r\n\r\n<body>` 再關閉
|
||||||
|
func (s *InternalServer) HandleForwardRaw(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
writeJSONError(w, http.StatusMethodNotAllowed, "METHOD_NOT_ALLOWED", "POST required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := r.URL.Query().Get("token")
|
||||||
|
if token == "" {
|
||||||
|
writeJSONError(w, http.StatusBadRequest, "MISSING_SESSION_TOKEN", "token query param required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. 先查 session — 若不存在,直接用一般 JSON error 回,尚未 hijack
|
||||||
|
handle, err := s.store.Lookup(r.Context(), token)
|
||||||
|
if err != nil || handle == nil || handle.IsClosed() {
|
||||||
|
s.logger.Warn("raw forward: session not found or closed",
|
||||||
|
"token_prefix", tokenPrefix(token),
|
||||||
|
"error", err)
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "TUNNEL_DISCONNECTED", "session not connected")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Hijack — 把連線從 http.Server 接管成 raw TCP
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "HIJACK_UNSUPPORTED", "hijacking not supported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, _, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("raw forward: hijack failed",
|
||||||
|
"error", err,
|
||||||
|
"token_prefix", tokenPrefix(token))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
// 3. 通知 caller「session 已 ready」— 用最小 HTTP/1.1 回應行
|
||||||
|
// 這是 Connect-style 的慣例(類似 HTTP CONNECT tunneling)
|
||||||
|
// 必須在 hijack 之後自己寫,因為 http.ResponseWriter 已失效
|
||||||
|
if _, werr := clientConn.Write([]byte("HTTP/1.1 200 Connected\r\n\r\n")); werr != nil {
|
||||||
|
s.logger.Warn("raw forward: write connected line failed",
|
||||||
|
"error", werr,
|
||||||
|
"token_prefix", tokenPrefix(token))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 從 session 開 yamux stream
|
||||||
|
stream, err := handle.OpenStream(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("raw forward: open stream failed",
|
||||||
|
"error", err,
|
||||||
|
"token_prefix", tokenPrefix(token))
|
||||||
|
// Hijack 後還能寫原 bytes — 回一個 HTTP 502 幫助 caller debug
|
||||||
|
_, _ = clientConn.Write([]byte(
|
||||||
|
"HTTP/1.1 502 Bad Gateway\r\n" +
|
||||||
|
"Content-Type: application/json\r\n" +
|
||||||
|
"Connection: close\r\n\r\n" +
|
||||||
|
`{"error":{"code":"TUNNEL_ERROR","message":"open stream failed"}}`,
|
||||||
|
))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
s.logger.Info("raw forward: stream opened",
|
||||||
|
"token_prefix", tokenPrefix(token),
|
||||||
|
"remote_addr", r.RemoteAddr)
|
||||||
|
|
||||||
|
// 5. 雙向 pipe — 把接管的連線和 yamux stream 連起來
|
||||||
|
// clientConn <---> stream (raw bytes,不做任何 HTTP 解析)
|
||||||
|
// 任一方向 EOF / error 就關閉另一邊,確保兩個 goroutine 都會退出
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
// 從 caller(api-server)讀 → 寫到 tunnel stream
|
||||||
|
_, _ = io.Copy(stream, clientConn)
|
||||||
|
// 關 stream 的寫入端讓另一邊的 Copy 收到 EOF;yamux stream 沒有
|
||||||
|
// CloseWrite,直接 Close 整條 stream 讓另一側也 EOF
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
// 從 tunnel stream 讀 → 寫回 caller
|
||||||
|
_, _ = io.Copy(clientConn, stream)
|
||||||
|
_ = clientConn.Close()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
s.logger.Info("raw forward: stream closed",
|
||||||
|
"token_prefix", tokenPrefix(token))
|
||||||
|
}
|
||||||
90
visionA-backend/internal/relay/local_handle.go
Normal file
90
visionA-backend/internal/relay/local_handle.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LocalHandle 是 remote-proxy 端的 session.Handle 實作,
|
||||||
|
// 直接包住一個 yamux.Session(真實持有 tunnel 連線的地方)。
|
||||||
|
//
|
||||||
|
// 為修 B2 Review M1(Heartbeat vs CleanupExpired race),
|
||||||
|
// LocalHandle 以 mu 保護 summary 的 LastHeartbeat 讀寫;
|
||||||
|
// Summary() 回傳 snapshot(副本)、RecordHeartbeat 在 lock 下寫入。
|
||||||
|
type LocalHandle struct {
|
||||||
|
yamuxSession *yamux.Session
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
summary session.Summary
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLocalHandle 建立一個 LocalHandle。
|
||||||
|
//
|
||||||
|
// token / remoteAddr 由 relay server 在 handleTunnelConnect 時傳入。
|
||||||
|
// ConnectedAt 與 LastHeartbeat 初始為 time.Now().UTC()。
|
||||||
|
func NewLocalHandle(yamuxSession *yamux.Session, token, remoteAddr string) *LocalHandle {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
return &LocalHandle{
|
||||||
|
yamuxSession: yamuxSession,
|
||||||
|
summary: session.Summary{
|
||||||
|
Token: token,
|
||||||
|
ConnectedAt: now,
|
||||||
|
LastHeartbeat: now,
|
||||||
|
RemoteAddr: remoteAddr,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStream 在 yamux session 上開一條新 stream。
|
||||||
|
// 若 session 已關閉回 ErrSessionClosed。
|
||||||
|
func (h *LocalHandle) OpenStream(ctx context.Context) (net.Conn, error) {
|
||||||
|
if h.yamuxSession.IsClosed() {
|
||||||
|
return nil, session.ErrSessionClosed
|
||||||
|
}
|
||||||
|
// yamux.Session.Open() 不接受 context;若 ctx 已取消應盡量早退。
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
stream, err := h.yamuxSession.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 關閉底層 yamux session(會同時關閉底下的 WebSocket)。
|
||||||
|
func (h *LocalHandle) Close() error {
|
||||||
|
return h.yamuxSession.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed 回報 yamux session 是否已關閉。
|
||||||
|
func (h *LocalHandle) IsClosed() bool {
|
||||||
|
return h.yamuxSession.IsClosed()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary 回傳 session 的資訊快照。
|
||||||
|
//
|
||||||
|
// 回傳的是 summary 的複本,防止 caller 觀察到中間態,
|
||||||
|
// 也避免 Store.List 與 RecordHeartbeat 對同一 pointer 的並發寫入。
|
||||||
|
func (h *LocalHandle) Summary() *session.Summary {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
cp := h.summary
|
||||||
|
return &cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordHeartbeat 更新 LastHeartbeat;由 Store.Heartbeat 呼叫。
|
||||||
|
func (h *LocalHandle) RecordHeartbeat(t time.Time) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.summary.LastHeartbeat = t
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:LocalHandle 必須實作 session.Handle。
|
||||||
|
var _ session.Handle = (*LocalHandle)(nil)
|
||||||
462
visionA-backend/internal/relay/server.go
Normal file
462
visionA-backend/internal/relay/server.go
Normal file
@ -0,0 +1,462 @@
|
|||||||
|
// Package relay 實作 remote-proxy 端的 tunnel server 與通用代理轉發。
|
||||||
|
//
|
||||||
|
// 核心職責:
|
||||||
|
// - 接受 local agent 的 `/tunnel/connect` WebSocket upgrade,建立 yamux session
|
||||||
|
// - 把 session 註冊到 session.Store(由 remote-proxy 唯一持有)
|
||||||
|
// - 提供通用代理 `handleProxy`:依 token 找到 session → open stream → 轉發 HTTP/WS
|
||||||
|
// - 提供 `/relay/status` 簡易連線狀態查詢
|
||||||
|
//
|
||||||
|
// 從 POC `edge-ai-platform/server/internal/relay/server.go` 複製後改造:
|
||||||
|
// 1. Session map → session.Store interface(由外部注入)
|
||||||
|
// 2. yamux KeepAliveInterval:POC 的 30s → 10s(對齊 tunnel.md §4.2 M-5)
|
||||||
|
// 3. Token 格式驗證(vAs_ + 64 hex 或 vAc_ + 32 hex 雛形可交替)
|
||||||
|
// 4. 使用結構化 JSON logger(log/slog),不再用 `log.Printf`
|
||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
|
||||||
|
"visiona-backend/internal/auth"
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
"visiona-backend/internal/wsconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DefaultKeepAliveInterval 是 yamux 的心跳間隔(對齊 tunnel.md §4.2 M-5)。
|
||||||
|
// 連續 3 次未收到 pong(= 30s)即判定掉線。
|
||||||
|
const DefaultKeepAliveInterval = 10 * time.Second
|
||||||
|
|
||||||
|
// DefaultConnectionWriteTimeout 是單一 yamux 寫入的最大等待時間。
|
||||||
|
const DefaultConnectionWriteTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// Options 提供 NewServer 可選設定。
|
||||||
|
type Options struct {
|
||||||
|
// KeepAliveInterval:yamux keep-alive 心跳;0 → 採用 DefaultKeepAliveInterval。
|
||||||
|
KeepAliveInterval time.Duration
|
||||||
|
|
||||||
|
// ConnectionWriteTimeout:yamux 寫入 timeout;0 → 採用 DefaultConnectionWriteTimeout。
|
||||||
|
ConnectionWriteTimeout time.Duration
|
||||||
|
|
||||||
|
// AllowedOrigins:WebSocket upgrade 的 Origin 白名單;
|
||||||
|
// 空 slice → 接受任意 Origin(對齊 tunnel.md §4.1;local agent 非瀏覽器無 origin 風險)。
|
||||||
|
AllowedOrigins []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// defaultOptions 建立含預設值的 Options。
|
||||||
|
func defaultOptions() Options {
|
||||||
|
return Options{
|
||||||
|
KeepAliveInterval: DefaultKeepAliveInterval,
|
||||||
|
ConnectionWriteTimeout: DefaultConnectionWriteTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Server 是 remote-proxy 端的 tunnel relay server。
|
||||||
|
//
|
||||||
|
// 與 POC 的差異:
|
||||||
|
// - 不自己維護 session map;全部委託 session.Store
|
||||||
|
// - 允許注入 logger,行為與生產環境一致
|
||||||
|
type Server struct {
|
||||||
|
store session.Store
|
||||||
|
logger *slog.Logger
|
||||||
|
opts Options
|
||||||
|
upgrader websocket.Upgrader
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
shutdown bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewServer 建立一個 relay.Server。
|
||||||
|
//
|
||||||
|
// store:session.Store 實作(remote-proxy 端通常為 *session.InMemoryStore)。
|
||||||
|
// logger:結構化 logger;nil 時使用 slog.Default。
|
||||||
|
func NewServer(store session.Store, logger *slog.Logger, opts ...Options) *Server {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
o := defaultOptions()
|
||||||
|
if len(opts) > 0 {
|
||||||
|
// 覆寫非零欄位
|
||||||
|
if opts[0].KeepAliveInterval > 0 {
|
||||||
|
o.KeepAliveInterval = opts[0].KeepAliveInterval
|
||||||
|
}
|
||||||
|
if opts[0].ConnectionWriteTimeout > 0 {
|
||||||
|
o.ConnectionWriteTimeout = opts[0].ConnectionWriteTimeout
|
||||||
|
}
|
||||||
|
o.AllowedOrigins = opts[0].AllowedOrigins
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Server{
|
||||||
|
store: store,
|
||||||
|
logger: logger,
|
||||||
|
opts: o,
|
||||||
|
upgrader: websocket.Upgrader{
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
// local agent 不跑在瀏覽器;預設放行任意 Origin。
|
||||||
|
// Phase 1 若有需要可對 AllowedOrigins 做比對。
|
||||||
|
if len(o.AllowedOrigins) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
for _, a := range o.AllowedOrigins {
|
||||||
|
if strings.EqualFold(a, origin) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// yamuxConfig 建構 yamux.Config — 套用我們統一的 10s keepalive。
|
||||||
|
func (s *Server) yamuxConfig(logOutput io.Writer) *yamux.Config {
|
||||||
|
cfg := yamux.DefaultConfig()
|
||||||
|
cfg.EnableKeepAlive = true
|
||||||
|
cfg.KeepAliveInterval = s.opts.KeepAliveInterval
|
||||||
|
cfg.ConnectionWriteTimeout = s.opts.ConnectionWriteTimeout
|
||||||
|
if logOutput != nil {
|
||||||
|
cfg.LogOutput = logOutput
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleTunnelConnect 處理 local agent 的 WebSocket upgrade 請求。
|
||||||
|
//
|
||||||
|
// Route: `GET /tunnel/connect?token=<token>`(亦接受 `X-Relay-Token` header)。
|
||||||
|
// 流程:
|
||||||
|
// 1. 取出 + 驗證 token 格式(vAs_ / vAc_)
|
||||||
|
// 2. WebSocket upgrade
|
||||||
|
// 3. 包成 net.Conn → yamux.Server
|
||||||
|
// 4. 建 LocalHandle + store.Register(舊 session 會自動被 Close)
|
||||||
|
// 5. 阻塞於 session.CloseChan;斷線後 Unregister
|
||||||
|
func (s *Server) HandleTunnelConnect(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// B3 Review Minor #2 修補:shutdown 期間拒絕新的 tunnel upgrade,
|
||||||
|
// 避免 graceful shutdown 過程中又有新 session 註冊進來。
|
||||||
|
s.mu.Lock()
|
||||||
|
isShutdown := s.shutdown
|
||||||
|
s.mu.Unlock()
|
||||||
|
if isShutdown {
|
||||||
|
writeJSONError(w, http.StatusServiceUnavailable, "SHUTTING_DOWN", "server is shutting down")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tok := getToken(r)
|
||||||
|
if tok == "" {
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "NO_TOKEN", "token required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isAcceptableToken(tok) {
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "INVALID_TOKEN_FORMAT", "token format invalid")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
wsConn, err := s.upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
// Upgrader 失敗時已經寫了 HTTP 回應
|
||||||
|
s.logger.Warn("tunnel upgrade failed",
|
||||||
|
"error", err,
|
||||||
|
"remote_addr", r.RemoteAddr,
|
||||||
|
"token_prefix", tokenPrefix(tok))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
netConn := wsconn.New(wsConn)
|
||||||
|
|
||||||
|
ymCfg := s.yamuxConfig(nil)
|
||||||
|
ym, err := yamux.Server(netConn, ymCfg)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Error("yamux server creation failed",
|
||||||
|
"error", err,
|
||||||
|
"token_prefix", tokenPrefix(tok))
|
||||||
|
_ = wsConn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
handle := NewLocalHandle(ym, tok, r.RemoteAddr)
|
||||||
|
|
||||||
|
// Register 會 Close 同 token 舊連線(Q5 裁決:後連覆蓋前連)
|
||||||
|
if err := s.store.Register(r.Context(), tok, handle); err != nil {
|
||||||
|
s.logger.Error("session register failed",
|
||||||
|
"error", err,
|
||||||
|
"token_prefix", tokenPrefix(tok))
|
||||||
|
_ = ym.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("tunnel connected",
|
||||||
|
"token_prefix", tokenPrefix(tok),
|
||||||
|
"remote_addr", r.RemoteAddr,
|
||||||
|
"keepalive_interval", s.opts.KeepAliveInterval.String())
|
||||||
|
|
||||||
|
// 阻塞到 yamux session 關閉
|
||||||
|
<-ym.CloseChan()
|
||||||
|
|
||||||
|
// 只有在「當前還是這個 handle」時才移除,避免覆蓋後的舊流程意外刪了新的。
|
||||||
|
if cur, lookupErr := s.store.Lookup(r.Context(), tok); lookupErr == nil {
|
||||||
|
if cur == handle {
|
||||||
|
_ = s.store.Unregister(r.Context(), tok)
|
||||||
|
}
|
||||||
|
} else if errors.Is(lookupErr, session.ErrSessionNotFound) {
|
||||||
|
// 已被清掉或已被新連線取代,無動作
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("tunnel disconnected",
|
||||||
|
"token_prefix", tokenPrefix(tok),
|
||||||
|
"remote_addr", r.RemoteAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleRelayStatus 回報指定 token 的連線狀態(debug / health 用)。
|
||||||
|
//
|
||||||
|
// Route: `GET /relay/status?token=<token>`(或無 token → 全體線上數量)。
|
||||||
|
func (s *Server) HandleRelayStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
tok := getToken(r)
|
||||||
|
if tok != "" {
|
||||||
|
h, err := s.store.Lookup(r.Context(), tok)
|
||||||
|
if err != nil {
|
||||||
|
// 不存在 → online=false
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"online": false,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sum := h.Summary()
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"online": !h.IsClosed(),
|
||||||
|
"connected_at": sum.ConnectedAt,
|
||||||
|
"last_heartbeat": sum.LastHeartbeat,
|
||||||
|
"remote_addr": sum.RemoteAddr,
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
summaries, err := s.store.List(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
writeJSONError(w, http.StatusInternalServerError, "LIST_FAILED", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"online": len(summaries) > 0,
|
||||||
|
"tunnel_count": len(summaries),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleProxy 是通用的 HTTP 反向代理 handler。
|
||||||
|
//
|
||||||
|
// 主要給 debug / 舊相容路徑用(POC 原本讓瀏覽器直接打 proxy);
|
||||||
|
// 雛形正式流量走 `/internal/forward/http`(見 remote-proxy 的 main.go)。
|
||||||
|
//
|
||||||
|
// Route: `Any /*`(或自行綁在其他 path)。
|
||||||
|
func (s *Server) HandleProxy(w http.ResponseWriter, r *http.Request) {
|
||||||
|
tok := getToken(r)
|
||||||
|
if tok == "" {
|
||||||
|
writeJSONError(w, http.StatusUnauthorized, "NO_TOKEN", "X-Relay-Token header or token query param required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h, err := s.store.Lookup(r.Context(), tok)
|
||||||
|
if err != nil || h.IsClosed() {
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "TUNNEL_DISCONNECTED", "edge agent is not connected")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := h.OpenStream(r.Context())
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("open stream failed", "error", err, "token_prefix", tokenPrefix(tok))
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "TUNNEL_ERROR", "failed to open tunnel stream")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer stream.Close()
|
||||||
|
|
||||||
|
// 不把 token / internal header 轉給 local agent
|
||||||
|
r.Header.Del("X-Relay-Token")
|
||||||
|
|
||||||
|
if isWebSocketUpgrade(r) {
|
||||||
|
s.proxyWebSocket(w, r, stream)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.Write(stream); err != nil {
|
||||||
|
s.logger.Warn("write request to tunnel failed", "error", err)
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "TUNNEL_WRITE_ERROR", "failed to write request to tunnel")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("read response from tunnel failed", "error", err)
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "TUNNEL_READ_ERROR", "failed to read response from tunnel")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
for key, vals := range resp.Header {
|
||||||
|
for _, v := range vals {
|
||||||
|
w.Header().Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
|
||||||
|
// 串流支援(MJPEG / SSE):有 Flusher 就每塊 flush 一次
|
||||||
|
if flusher, ok := w.(http.Flusher); ok {
|
||||||
|
buf := make([]byte, 32*1024)
|
||||||
|
for {
|
||||||
|
n, rerr := resp.Body.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
if _, werr := w.Write(buf[:n]); werr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
if rerr != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, _ = io.Copy(w, resp.Body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// proxyWebSocket 處理瀏覽器(或 api-server 內部)的 WebSocket upgrade:
|
||||||
|
// 把 upgrade request 透過 yamux stream 送到 local agent,
|
||||||
|
// 再 Hijack 當前連線做雙向 pipe(POC 原始邏輯)。
|
||||||
|
func (s *Server) proxyWebSocket(w http.ResponseWriter, r *http.Request, stream net.Conn) {
|
||||||
|
if err := r.Write(stream); err != nil {
|
||||||
|
s.logger.Warn("ws: write upgrade request failed", "error", err)
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "TUNNEL_WRITE_ERROR", "failed to write upgrade request to tunnel")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(stream), r)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("ws: read upgrade response failed", "error", err)
|
||||||
|
writeJSONError(w, http.StatusBadGateway, "TUNNEL_READ_ERROR", "failed to read upgrade response from tunnel")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusSwitchingProtocols {
|
||||||
|
for key, vals := range resp.Header {
|
||||||
|
for _, v := range vals {
|
||||||
|
w.Header().Add(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w.WriteHeader(resp.StatusCode)
|
||||||
|
_, _ = io.Copy(w, resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientConn, clientBuf, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("ws: hijack failed", "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
// 把 101 回傳給 caller
|
||||||
|
_ = resp.Write(clientBuf)
|
||||||
|
_ = clientBuf.Flush()
|
||||||
|
|
||||||
|
// 雙向 pipe
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(stream, clientConn)
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _ = io.Copy(clientConn, stream)
|
||||||
|
_ = clientConn.Close()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shutdown 關閉 Server 所管理的所有 session(通常在程序結束時呼叫)。
|
||||||
|
//
|
||||||
|
// Store 不需要在此被關閉(Store 由 caller 注入);只通知:不再接受新 tunnel。
|
||||||
|
func (s *Server) Shutdown() {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.shutdown = true
|
||||||
|
s.mu.Unlock()
|
||||||
|
// 實際的 session close 會由 CleanupExpired / cmd/remote-proxy 主迴圈處理。
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// getToken 從 `X-Relay-Token` header 或 `token` query 參數取出 token。
|
||||||
|
// Header 優先於 query,行為與 POC 一致。
|
||||||
|
func getToken(r *http.Request) string {
|
||||||
|
if tok := r.Header.Get("X-Relay-Token"); tok != "" {
|
||||||
|
return tok
|
||||||
|
}
|
||||||
|
return r.URL.Query().Get("token")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isAcceptableToken 檢查 token 是否為 pairing 或 session 任一合法格式。
|
||||||
|
//
|
||||||
|
// 雛形階段 local agent 仍用 pairing token 連線(見 tunnel.md §2.2.1);
|
||||||
|
// Phase 1 升級兩階段 token 後仍然是 session token 為主。此處兩者皆接受。
|
||||||
|
func isAcceptableToken(tok string) bool {
|
||||||
|
return auth.IsValidPairingToken(tok) || auth.IsValidSessionToken(tok)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenPrefix 回傳 token 的前 8 字元(log 用,避免 log 完整 token)。
|
||||||
|
func tokenPrefix(tok string) string {
|
||||||
|
if len(tok) <= 8 {
|
||||||
|
return tok
|
||||||
|
}
|
||||||
|
return tok[:8]
|
||||||
|
}
|
||||||
|
|
||||||
|
// isWebSocketUpgrade 判斷 request 是否為 WebSocket upgrade。
|
||||||
|
//
|
||||||
|
// B3 Review Minor #4 修補:同時檢查 Upgrade 與 Connection header,
|
||||||
|
// 避免只看 Upgrade 在極端 case(curl 手工送單一 header)時誤判。
|
||||||
|
// RFC 6455 §4.1 規定合法的 WS upgrade 要同時包含兩個 header。
|
||||||
|
func isWebSocketUpgrade(r *http.Request) bool {
|
||||||
|
if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Connection header 可能是 "upgrade" 或 "keep-alive, Upgrade" 等組合,
|
||||||
|
// 用 Contains 不區分大小寫即可。
|
||||||
|
return strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeJSONError 寫回統一格式的 JSON error(對齊 API error schema)。
|
||||||
|
func writeJSONError(w http.ResponseWriter, status int, code, message string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"error": map[string]any{
|
||||||
|
"code": code,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatAddr 把 port 格式化為 ":{port}",供 http.Server.Addr 使用。
|
||||||
|
func FormatAddr(port int) string {
|
||||||
|
return fmt.Sprintf(":%d", port)
|
||||||
|
}
|
||||||
365
visionA-backend/internal/relay/server_test.go
Normal file
365
visionA-backend/internal/relay/server_test.go
Normal file
@ -0,0 +1,365 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/hashicorp/yamux"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"visiona-backend/internal/session"
|
||||||
|
"visiona-backend/internal/wsconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testPairingToken 是一個格式合法的 pairing token,用於測試。
|
||||||
|
const testPairingToken = "vAc_0123456789abcdef0123456789abcdef"
|
||||||
|
|
||||||
|
// startFakeLocalAgent 啟動一個「假 local agent」:
|
||||||
|
// - 對指定 relay URL 開 WebSocket
|
||||||
|
// - 在 WS 上建立 yamux Client
|
||||||
|
// - 對每一個 stream 做 http.ReadRequest → 回傳 handler 提供的 response
|
||||||
|
//
|
||||||
|
// 這模擬 POC edge-ai-server 的 tunnel client 角色,用於驗證 relay forwarding 路徑。
|
||||||
|
//
|
||||||
|
// handler 的 http.Handler 對應「local server(127.0.0.1:3721)」;此函式會在
|
||||||
|
// tunnel stream 之上直接用 http.ReadRequest 把請求轉給 handler 並寫回 response。
|
||||||
|
func startFakeLocalAgent(t *testing.T, relayURL string, token string, handler http.Handler) (stop func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
u, err := url.Parse(relayURL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("token", token)
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
rawWS, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
netConn := wsconn.New(rawWS)
|
||||||
|
ym, err := yamux.Client(netConn, yamux.DefaultConfig())
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
for {
|
||||||
|
stream, aerr := ym.Accept()
|
||||||
|
if aerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func(s net.Conn) {
|
||||||
|
defer s.Close()
|
||||||
|
req, rerr := http.ReadRequest(bufio.NewReader(s))
|
||||||
|
if rerr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// handler 需要一個 ResponseWriter 能寫回 raw stream;
|
||||||
|
// 用 httptest.NewRecorder 收集 response 再自己寫回。
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rec, req)
|
||||||
|
result := rec.Result()
|
||||||
|
defer result.Body.Close()
|
||||||
|
_ = result.Write(s)
|
||||||
|
}(stream)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return func() {
|
||||||
|
_ = ym.Close()
|
||||||
|
_ = rawWS.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 以 stdlib net.Conn alias(避免再 import 一次)。
|
||||||
|
// yamux.Client.Accept() 回傳 net.Conn,此 alias 只為測試可讀性。
|
||||||
|
// 注意:這裡沒有實際 type 定義,直接使用 stdlib 的 net.Conn。
|
||||||
|
|
||||||
|
// TestServer_TunnelConnect_RejectsMissingToken 驗證沒帶 token 的 upgrade 會被拒。
|
||||||
|
func TestServer_TunnelConnect_RejectsMissingToken(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default())
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/tunnel/connect", srv.HandleTunnelConnect)
|
||||||
|
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(ts.URL + "/tunnel/connect")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_TunnelConnect_RejectsInvalidTokenFormat 驗證 token 格式錯誤會被拒。
|
||||||
|
func TestServer_TunnelConnect_RejectsInvalidTokenFormat(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default())
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/tunnel/connect", srv.HandleTunnelConnect)
|
||||||
|
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(ts.URL + "/tunnel/connect?token=garbage")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_TunnelConnect_RegistersAndUnregisters 驗證:
|
||||||
|
// - 合法 token → upgrade 成功 → session 註冊進 store
|
||||||
|
// - local agent 斷開 → session 從 store 移除
|
||||||
|
func TestServer_TunnelConnect_RegistersAndUnregisters(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/tunnel/connect", srv.HandleTunnelConnect)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/tunnel/connect"
|
||||||
|
stop := startFakeLocalAgent(t, wsURL, testPairingToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// 等 register 完成
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), testPairingToken)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
// 斷線
|
||||||
|
stop()
|
||||||
|
|
||||||
|
// 等 unregister
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), testPairingToken)
|
||||||
|
return !ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_HandleProxy_ForwardsRequest 驗證:
|
||||||
|
// - 透過 session store 找到 handle
|
||||||
|
// - OpenStream + 轉發 HTTP request
|
||||||
|
// - local agent 回的 response 可寫回 caller
|
||||||
|
func TestServer_HandleProxy_ForwardsRequest(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/tunnel/connect", srv.HandleTunnelConnect)
|
||||||
|
mux.HandleFunc("/proxy/", srv.HandleProxy)
|
||||||
|
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
// fake local agent:回 JSON {"ok": true, "path": <收到的 path>}
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/tunnel/connect"
|
||||||
|
stop := startFakeLocalAgent(t, wsURL, testPairingToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(`{"ok":true,"path":"` + r.URL.Path + `"}`))
|
||||||
|
}))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), testPairingToken)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
// 透過 HandleProxy 轉發
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, ts.URL+"/proxy/api/devices", nil)
|
||||||
|
req.Header.Set("X-Relay-Token", testPairingToken)
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
assert.Contains(t, string(body), `"ok":true`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_HandleProxy_NoTunnel 當指定 token 沒 session 時,回 502。
|
||||||
|
func TestServer_HandleProxy_NoTunnel(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default())
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/proxy/", srv.HandleProxy)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, ts.URL+"/proxy/api/anything", nil)
|
||||||
|
req.Header.Set("X-Relay-Token", testPairingToken)
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusBadGateway, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServer_HandleRelayStatus_ReportsOnline 驗證 /relay/status?token=... 能報告連線狀態。
|
||||||
|
func TestServer_HandleRelayStatus_ReportsOnline(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/tunnel/connect", srv.HandleTunnelConnect)
|
||||||
|
mux.HandleFunc("/relay/status", srv.HandleRelayStatus)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(ts.URL, "http") + "/tunnel/connect"
|
||||||
|
stop := startFakeLocalAgent(t, wsURL, testPairingToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), testPairingToken)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
resp, err := http.Get(ts.URL + "/relay/status?token=" + testPairingToken)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.Equal(t, true, body["online"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInternalServer_ForwardHTTP 驗證 internal forward JSON API 可以轉發 HTTP 請求。
|
||||||
|
// 這是 api-server → remote-proxy 的 Phase 0 關鍵路徑。
|
||||||
|
func TestInternalServer_ForwardHTTP(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
internal := NewInternalServer(store, slog.Default())
|
||||||
|
|
||||||
|
// Tunnel server
|
||||||
|
tunnelMux := http.NewServeMux()
|
||||||
|
tunnelMux.HandleFunc("/tunnel/connect", srv.HandleTunnelConnect)
|
||||||
|
tunnelSrv := httptest.NewServer(tunnelMux)
|
||||||
|
defer tunnelSrv.Close()
|
||||||
|
|
||||||
|
// Internal server
|
||||||
|
internalMux := http.NewServeMux()
|
||||||
|
internal.Routes(internalMux)
|
||||||
|
internalSrv := httptest.NewServer(internalMux)
|
||||||
|
defer internalSrv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(tunnelSrv.URL, "http") + "/tunnel/connect"
|
||||||
|
stop := startFakeLocalAgent(t, wsURL, testPairingToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("X-Request-Path", r.URL.Path)
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = io.WriteString(w, `{"forwarded":true}`)
|
||||||
|
}))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), testPairingToken)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
// 打 internal forward
|
||||||
|
payload := ForwardHTTPRequest{
|
||||||
|
SessionToken: testPairingToken,
|
||||||
|
Method: http.MethodGet,
|
||||||
|
Path: "/api/devices",
|
||||||
|
Headers: map[string]string{"X-Test": "1"},
|
||||||
|
}
|
||||||
|
bb, _ := json.Marshal(payload)
|
||||||
|
resp, err := http.Post(internalSrv.URL+"/internal/forward/http",
|
||||||
|
"application/json", bytes.NewReader(bb))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
|
||||||
|
var fr ForwardHTTPResponse
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&fr))
|
||||||
|
require.Nil(t, fr.Error, "error: %+v", fr.Error)
|
||||||
|
assert.Equal(t, http.StatusOK, fr.Status)
|
||||||
|
|
||||||
|
decoded, err := base64.StdEncoding.DecodeString(fr.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, string(decoded), `"forwarded":true`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInternalServer_GetSession 驗證 GET /internal/session/:token 能回傳 session 摘要。
|
||||||
|
func TestInternalServer_GetSession(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
srv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
|
||||||
|
internal := NewInternalServer(store, slog.Default())
|
||||||
|
|
||||||
|
tunnelMux := http.NewServeMux()
|
||||||
|
tunnelMux.HandleFunc("/tunnel/connect", srv.HandleTunnelConnect)
|
||||||
|
tunnelSrv := httptest.NewServer(tunnelMux)
|
||||||
|
defer tunnelSrv.Close()
|
||||||
|
|
||||||
|
internalMux := http.NewServeMux()
|
||||||
|
internal.Routes(internalMux)
|
||||||
|
internalSrv := httptest.NewServer(internalMux)
|
||||||
|
defer internalSrv.Close()
|
||||||
|
|
||||||
|
wsURL := "ws" + strings.TrimPrefix(tunnelSrv.URL, "http") + "/tunnel/connect"
|
||||||
|
stop := startFakeLocalAgent(t, wsURL, testPairingToken, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
ok, _ := store.Exists(context.Background(), testPairingToken)
|
||||||
|
return ok
|
||||||
|
}, 2*time.Second, 20*time.Millisecond)
|
||||||
|
|
||||||
|
resp, err := http.Get(internalSrv.URL + "/internal/session/" + testPairingToken)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||||
|
var body map[string]any
|
||||||
|
require.NoError(t, json.NewDecoder(resp.Body).Decode(&body))
|
||||||
|
assert.Equal(t, testPairingToken, body["token"])
|
||||||
|
assert.Equal(t, true, body["connected"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInternalServer_GetSession_NotFound
|
||||||
|
func TestInternalServer_GetSession_NotFound(t *testing.T) {
|
||||||
|
store := session.NewInMemoryStore()
|
||||||
|
internal := NewInternalServer(store, slog.Default())
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
internal.Routes(mux)
|
||||||
|
ts := httptest.NewServer(mux)
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(ts.URL + "/internal/session/vAc_ffffffffffffffffffffffffffffffff")
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
assert.Equal(t, http.StatusNotFound, resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenHelpers 驗證小工具函式。
|
||||||
|
func TestTokenHelpers(t *testing.T) {
|
||||||
|
assert.True(t, isAcceptableToken(testPairingToken))
|
||||||
|
assert.False(t, isAcceptableToken("not-a-token"))
|
||||||
|
assert.Equal(t, "vAc_0123", tokenPrefix(testPairingToken))
|
||||||
|
assert.Equal(t, "short", tokenPrefix("short"))
|
||||||
|
}
|
||||||
|
|
||||||
332
visionA-backend/internal/session/forwarder.go
Normal file
332
visionA-backend/internal/session/forwarder.go
Normal file
@ -0,0 +1,332 @@
|
|||||||
|
// forwarder.go — api-server → remote-proxy 的 raw forward client。
|
||||||
|
//
|
||||||
|
// 這是雛形雙 binary 架構下「api-server 把前端 HTTP 請求轉發到 local agent」
|
||||||
|
// 的核心元件。
|
||||||
|
//
|
||||||
|
// 整條路徑:
|
||||||
|
//
|
||||||
|
// browser ─HTTP─► api-server handler
|
||||||
|
// │
|
||||||
|
// │ Forwarder.ForwardHTTP / OpenStream
|
||||||
|
// ▼
|
||||||
|
// raw TCP dial remote-proxy: POST /internal/forward/raw?token=...
|
||||||
|
// │ (B3 Major-1 修復後新增的 hijack endpoint)
|
||||||
|
// ▼
|
||||||
|
// remote-proxy hijack 自己的連線 → yamux.OpenStream → 雙向 io.Copy
|
||||||
|
// │
|
||||||
|
// ▼
|
||||||
|
// local agent (yamux client) 把 stream 上的 HTTP request
|
||||||
|
// ▼
|
||||||
|
// 轉到本地 127.0.0.1:3721(local-tool)回 response
|
||||||
|
//
|
||||||
|
// 對齊 `.autoflow/04-architecture/api/api-internal.md` §POST /internal/forward/raw
|
||||||
|
// 與 `.autoflow/04-architecture/tunnel.md` §3.3。
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultDialTimeout 是 raw TCP dial remote-proxy 的最大等待時間。
|
||||||
|
const defaultDialTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// defaultHandshakeTimeout 是讀取「HTTP/1.1 200 Connected」握手的最大等待時間。
|
||||||
|
const defaultHandshakeTimeout = 10 * time.Second
|
||||||
|
|
||||||
|
// Forwarder 把 api-server 的 HTTP 請求 forward 到 remote-proxy。
|
||||||
|
//
|
||||||
|
// 並發安全:本 struct 的方法不共享可變狀態,每個 OpenStream 走獨立 net.Conn;
|
||||||
|
// 多個 goroutine 可同時呼叫。
|
||||||
|
type Forwarder struct {
|
||||||
|
// proxyHost 是從 baseURL 解析出來的 host:port,供 net.Dial 用。
|
||||||
|
proxyHost string
|
||||||
|
|
||||||
|
// dialer 用於 raw TCP dial。獨立成欄位以利測試 / 未來換成 TLS dial。
|
||||||
|
dialer net.Dialer
|
||||||
|
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewForwarder 從 baseURL(例:http://localhost:3801)建立 Forwarder。
|
||||||
|
//
|
||||||
|
// baseURL 必須是 http:// 或 https:// 開頭;其他 scheme 視為錯誤但延遲到
|
||||||
|
// 第一次呼叫時才回(保持建構簽章簡單)。
|
||||||
|
//
|
||||||
|
// **注意**:雛形 internal port 是純 HTTP(network policy 阻擋外部存取,見
|
||||||
|
// api-internal.md §安全)。Phase 1 加 mTLS 時,本 Forwarder 需擴充支援 TLS。
|
||||||
|
func NewForwarder(baseURL string, logger *slog.Logger) *Forwarder {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
host := parseHostFromBaseURL(baseURL)
|
||||||
|
return &Forwarder{
|
||||||
|
proxyHost: host,
|
||||||
|
dialer: net.Dialer{Timeout: defaultDialTimeout},
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseHostFromBaseURL 從 baseURL 取出 host:port,失敗時回傳空字串
|
||||||
|
// (後續 OpenStream 會拒絕並回明確錯誤)。
|
||||||
|
func parseHostFromBaseURL(baseURL string) string {
|
||||||
|
if baseURL == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
u, err := url.Parse(baseURL)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return u.Host
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStream 對 remote-proxy 開一條 raw TCP 連線,完成 hijack 握手,並回傳
|
||||||
|
// 一條可以直接用 net.Conn 語意操作的連線(底層是 yamux stream)。
|
||||||
|
//
|
||||||
|
// 用法(典型 api-server handler):
|
||||||
|
//
|
||||||
|
// conn, err := forwarder.OpenStream(ctx, sessionToken)
|
||||||
|
// if err != nil { ... }
|
||||||
|
// defer conn.Close()
|
||||||
|
//
|
||||||
|
// httpReq.Write(conn) // 送 HTTP request
|
||||||
|
// resp, _ := http.ReadResponse(bufio.NewReader(conn), httpReq)
|
||||||
|
// io.Copy(browserResponseWriter, resp.Body) // streaming friendly
|
||||||
|
//
|
||||||
|
// 失敗回傳的 error:
|
||||||
|
// - ErrSessionNotFound:remote-proxy 在 hijack 前回 502 JSON
|
||||||
|
// - 其他 wrapped error:dial / 握手 / 解析錯誤
|
||||||
|
//
|
||||||
|
// 注意:caller 拿到 conn 後**必須自己負責 Close**;本函式內部不會 set deadline,
|
||||||
|
// 因為 streaming 場景(MJPEG / SSE)需要無限長的存活時間。
|
||||||
|
func (f *Forwarder) OpenStream(ctx context.Context, sessionToken string) (net.Conn, error) {
|
||||||
|
if f.proxyHost == "" {
|
||||||
|
return nil, errors.New("session: forwarder has no proxy host (check VISIONA_PROXY_INTERNAL_URL)")
|
||||||
|
}
|
||||||
|
if sessionToken == "" {
|
||||||
|
return nil, errors.New("session: forwarder.OpenStream requires non-empty sessionToken")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1. raw TCP dial
|
||||||
|
conn, err := f.dialer.DialContext(ctx, "tcp", f.proxyHost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("session: dial remote-proxy %s: %w", f.proxyHost, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 寫 POST /internal/forward/raw?token=...
|
||||||
|
// 仿 dialRawForward 測試 helper 的格式(見 internal/relay/integration_raw_test.go)。
|
||||||
|
reqLine := fmt.Sprintf(
|
||||||
|
"POST /internal/forward/raw?token=%s HTTP/1.1\r\n"+
|
||||||
|
"Host: %s\r\n"+
|
||||||
|
"Content-Length: 0\r\n"+
|
||||||
|
"\r\n",
|
||||||
|
url.QueryEscape(sessionToken), f.proxyHost,
|
||||||
|
)
|
||||||
|
// 設一個短的握手 deadline,避免 remote-proxy 假死時 hang 住。
|
||||||
|
if err := conn.SetWriteDeadline(time.Now().Add(defaultHandshakeTimeout)); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: set write deadline: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := conn.Write([]byte(reqLine)); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: write forward request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 讀握手 — 預期 "HTTP/1.1 200 Connected\r\n\r\n"
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(defaultHandshakeTimeout)); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
statusLine, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: read handshake status: %w", err)
|
||||||
|
}
|
||||||
|
statusLine = strings.TrimRight(statusLine, "\r\n")
|
||||||
|
|
||||||
|
// 解析 status code
|
||||||
|
// 格式:HTTP/1.1 200 Connected 或 HTTP/1.1 502 Bad Gateway
|
||||||
|
if !strings.HasPrefix(statusLine, "HTTP/1.1 200") {
|
||||||
|
// 非 200 → 把 body 讀出來幫 debug;常見:502 = TUNNEL_DISCONNECTED
|
||||||
|
bodyHint := drainAndPeek(reader)
|
||||||
|
_ = conn.Close()
|
||||||
|
|
||||||
|
// session 不存在的明確錯誤對應 ErrSessionNotFound
|
||||||
|
if strings.Contains(statusLine, "502") {
|
||||||
|
return nil, fmt.Errorf("%w: remote-proxy responded %q (body hint: %s)",
|
||||||
|
ErrSessionNotFound, statusLine, bodyHint)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("session: forward handshake failed: %q (body hint: %s)",
|
||||||
|
statusLine, bodyHint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. 把握手後的 header 讀完(一直讀到空行)
|
||||||
|
for {
|
||||||
|
line, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: read handshake headers: %w", err)
|
||||||
|
}
|
||||||
|
if line == "\r\n" || line == "\n" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 清掉 deadline,因為後續 streaming 場景不該再 timeout
|
||||||
|
if err := conn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: clear deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. 如果 reader 裡還有預讀資料(bufio.NewReader 可能讀超過一行),
|
||||||
|
// 回傳一個包裝 conn 把預讀的 byte 接回 stream。
|
||||||
|
// 這個情境在 raw forward 上理論上不會發生(remote-proxy 在發出
|
||||||
|
// "200 Connected\r\n\r\n" 之後不會主動寫資料 — 它要等 caller 寫
|
||||||
|
// request 才會從 yamux stream 收 response);但保險起見處理。
|
||||||
|
if buffered := reader.Buffered(); buffered > 0 {
|
||||||
|
peek, _ := reader.Peek(buffered)
|
||||||
|
f.logger.Warn("forwarder: unexpected bytes after handshake; wrapping conn",
|
||||||
|
"bytes", buffered)
|
||||||
|
return newPrefixConn(conn, append([]byte(nil), peek...)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardHTTP 是「給定 http.Request,回傳 *http.Response」的高階 helper。
|
||||||
|
//
|
||||||
|
// 內部實作:
|
||||||
|
// 1. OpenStream 拿 raw TCP(已 hijack)連線
|
||||||
|
// 2. req.Write(conn) 把完整 HTTP request 寫進去
|
||||||
|
// 3. http.ReadResponse 讀出 response(不消耗 body)
|
||||||
|
//
|
||||||
|
// 重要:response.Body **包住 conn 本身**(所以 caller 必須在用完後 Close
|
||||||
|
// response.Body);這允許 streaming body(MJPEG / SSE / chunked)原樣轉發。
|
||||||
|
//
|
||||||
|
// req 的 URL.Host / Scheme 會被覆寫成 "127.0.0.1" / "http",因為 local agent
|
||||||
|
// 收到的是「打到自己 localhost」的請求;caller 設定的 Host header 會被保留。
|
||||||
|
func (f *Forwarder) ForwardHTTP(ctx context.Context, sessionToken string, req *http.Request) (*http.Response, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, errors.New("session: ForwardHTTP requires non-nil req")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := f.OpenStream(ctx, sessionToken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 改寫 req 為「打給 local agent」格式:
|
||||||
|
// - URL.Scheme = http,URL.Host = 127.0.0.1 → req.Write 才不會報錯
|
||||||
|
// - RequestURI 必須清空(client 端不能設)
|
||||||
|
// - 不覆寫 req.Host:caller 自行決定要不要保留 browser 的 Host
|
||||||
|
//
|
||||||
|
// 注意:req 本身可能已被外部使用,這裡複製 URL 避免副作用。
|
||||||
|
outReq := req.Clone(ctx)
|
||||||
|
if outReq.URL == nil {
|
||||||
|
outReq.URL = &url.URL{}
|
||||||
|
}
|
||||||
|
outReq.URL.Scheme = "http"
|
||||||
|
outReq.URL.Host = "127.0.0.1"
|
||||||
|
outReq.RequestURI = ""
|
||||||
|
if outReq.Host == "" {
|
||||||
|
outReq.Host = "127.0.0.1"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 把 request 寫到 conn
|
||||||
|
if err := outReq.Write(conn); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: write request to forwarded conn: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 讀 response — 不可以 close conn,因為 response.Body 還會用到
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(conn), outReq)
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("session: read response from forwarded conn: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 把 conn 包進 response.Body 的 close chain:caller close body 時連 conn 一起關
|
||||||
|
resp.Body = &bodyWithConn{ReadCloser: resp.Body, conn: conn}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardWebSocket 預留 — B5 接前端 WS 時實作。
|
||||||
|
//
|
||||||
|
// 預期實作(草稿):
|
||||||
|
// - OpenStream 拿到 raw conn
|
||||||
|
// - 把 WS upgrade request 透過 conn 寫過去
|
||||||
|
// - 等 101 response 回來
|
||||||
|
// - Hijack browser 端連線,與 conn 雙向 pipe
|
||||||
|
//
|
||||||
|
// 雛形先回 ErrNotImplemented,避免被誤用。
|
||||||
|
func (f *Forwarder) ForwardWebSocket(ctx context.Context, sessionToken string, req *http.Request) (net.Conn, error) {
|
||||||
|
return nil, errors.New("session: ForwardWebSocket not implemented yet (TODO B5)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// drainAndPeek 嘗試讀少量 byte 給 error message 加上 context;
|
||||||
|
// 不阻塞太久,最多 256 byte。
|
||||||
|
//
|
||||||
|
// 呼叫前提:caller 必須已經對 underlying conn 設過 ReadDeadline(這個函式只
|
||||||
|
// 在 OpenStream 握手失敗的 error path 被呼叫,該路徑已經 SetReadDeadline
|
||||||
|
// 到 defaultHandshakeTimeout),所以 Read 不會 hang 住;若 deadline 已過,
|
||||||
|
// Read 會立刻回 0 + deadline error,行為仍然是「不阻塞」。
|
||||||
|
func drainAndPeek(reader *bufio.Reader) string {
|
||||||
|
buf := make([]byte, 256)
|
||||||
|
n, _ := reader.Read(buf)
|
||||||
|
return strings.TrimSpace(string(buf[:n]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// bodyWithConn 把 ReadCloser 與底層 net.Conn 綁在一起,
|
||||||
|
// caller close body 時順便關 conn(避免 leak)。
|
||||||
|
type bodyWithConn struct {
|
||||||
|
io.ReadCloser
|
||||||
|
conn net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 同時關閉 body 與底層 conn;以最後一個非 nil 的 error 回傳。
|
||||||
|
func (b *bodyWithConn) Close() error {
|
||||||
|
bodyErr := b.ReadCloser.Close()
|
||||||
|
connErr := b.conn.Close()
|
||||||
|
if bodyErr != nil {
|
||||||
|
return bodyErr
|
||||||
|
}
|
||||||
|
return connErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// prefixConn 把預讀的 byte 接回 net.Conn 開頭,供 caller 透明使用。
|
||||||
|
//
|
||||||
|
// 並發說明:net.Conn 本身對單一 goroutine 讀 + 單一 goroutine 寫是安全的。
|
||||||
|
// prefixConn 只包裝 Read;prefix 的讀取不會跨 goroutine 共享(Read 慣例上
|
||||||
|
// 只由 reader goroutine 呼叫),所以這裡不需要額外的 mutex。
|
||||||
|
type prefixConn struct {
|
||||||
|
net.Conn
|
||||||
|
prefix []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newPrefixConn(c net.Conn, prefix []byte) *prefixConn {
|
||||||
|
return &prefixConn{Conn: c, prefix: prefix}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *prefixConn) Read(b []byte) (int, error) {
|
||||||
|
if len(p.prefix) > 0 {
|
||||||
|
n := copy(b, p.prefix)
|
||||||
|
p.prefix = p.prefix[n:]
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
return p.Conn.Read(b)
|
||||||
|
}
|
||||||
90
visionA-backend/internal/session/forwarder_test.go
Normal file
90
visionA-backend/internal/session/forwarder_test.go
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestForwarder_OpenStream_NoProxyHost 驗證 baseURL 為空時直接拒絕。
|
||||||
|
func TestForwarder_OpenStream_NoProxyHost(t *testing.T) {
|
||||||
|
f := NewForwarder("", nil)
|
||||||
|
_, err := f.OpenStream(context.Background(), "vAc_x")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestForwarder_OpenStream_EmptyToken 驗證空 token 拒絕。
|
||||||
|
func TestForwarder_OpenStream_EmptyToken(t *testing.T) {
|
||||||
|
f := NewForwarder("http://localhost:9999", nil)
|
||||||
|
_, err := f.OpenStream(context.Background(), "")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestForwarder_ForwardWebSocket_NotImplemented 驗證 ForwardWebSocket 仍是 stub。
|
||||||
|
func TestForwarder_ForwardWebSocket_NotImplemented(t *testing.T) {
|
||||||
|
f := NewForwarder("http://localhost:9999", nil)
|
||||||
|
req, _ := http.NewRequest(http.MethodGet, "/ws", nil)
|
||||||
|
_, err := f.ForwardWebSocket(context.Background(), "vAc_x", req)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestForwarder_OpenStream_502_TreatedAsNotFound 驗證當 remote-proxy 回 502
|
||||||
|
// (session 不存在時的雛形行為)→ 包裝成 ErrSessionNotFound。
|
||||||
|
//
|
||||||
|
// 用 httptest 起一個假的 internal endpoint,回 502 JSON。
|
||||||
|
func TestForwarder_OpenStream_502_TreatedAsNotFound(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
_, _ = w.Write([]byte(`{"error":{"code":"TUNNEL_DISCONNECTED","message":"session not connected"}}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
f := NewForwarder(ts.URL, nil)
|
||||||
|
_, err := f.OpenStream(context.Background(), "vAc_dead")
|
||||||
|
if !errors.Is(err, ErrSessionNotFound) {
|
||||||
|
t.Fatalf("expected ErrSessionNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestForwarder_OpenStream_HandshakeRead 驗證能正確讀「HTTP/1.1 200 Connected\r\n\r\n」
|
||||||
|
// 握手;用一個假 server 回正確握手後立刻 close — 期望我們的 OpenStream 成功,
|
||||||
|
// 後續 Read 拿 EOF(這對 forwarder 而言是合法情境,由 caller 處理)。
|
||||||
|
//
|
||||||
|
// 此 case 直接驗證 happy-path 握手解析;真正的端對端轉發由 integration test 涵蓋。
|
||||||
|
func TestForwarder_OpenStream_HandshakeRead(t *testing.T) {
|
||||||
|
// 為了保證 server 端在 200 Connected 後不再寫 body(讓 forwarder 結束 header 讀
|
||||||
|
// 不被預讀干擾),用一個 raw TCP listener 而非 httptest.NewServer。
|
||||||
|
// 但 raw listener 會增加測試複雜度;在 unit test 用 httptest 已足以驗證
|
||||||
|
// 「能 parse 200 Connected + 兩個 \r\n」的路徑——讀 body 結束會回 EOF,
|
||||||
|
// 後續 caller 用該 conn 才會發現問題,這裡僅驗證 OpenStream 不 error。
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 不能直接寫 raw "HTTP/1.1 200 Connected\r\n\r\n" — httptest 會額外加
|
||||||
|
// content-length 等 header。改用 hijack 模擬真實 raw forward 行為。
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "no hijacker", 500)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, _, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
_, _ = conn.Write([]byte("HTTP/1.1 200 Connected\r\n\r\n"))
|
||||||
|
// 不再寫;讓 forwarder 拿到 conn 後若 read 會 EOF
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
f := NewForwarder(ts.URL, nil)
|
||||||
|
conn, err := f.OpenStream(context.Background(), "vAc_x")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("OpenStream should succeed: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
// 不再做 read 驗證(行為由 integration test 涵蓋)
|
||||||
|
}
|
||||||
132
visionA-backend/internal/session/inmemory_store.go
Normal file
132
visionA-backend/internal/session/inmemory_store.go
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// InMemoryStore 是 Store 的單節點記憶體實作,只部署在 remote-proxy binary。
|
||||||
|
//
|
||||||
|
// 語意重點(對齊 tunnel.md §2.3 / §5.2 + Q5 裁決):
|
||||||
|
// - 同 token 後連覆蓋前連 — Register 時若已存在,先 Close() 舊 handle 再寫入。
|
||||||
|
// - Heartbeat 僅更新 Summary.LastHeartbeat 時間戳。
|
||||||
|
// - CleanupExpired 掃描並移除逾時者(同時 Close 對應 handle)。
|
||||||
|
// - 所有操作並發安全(sync.RWMutex)。
|
||||||
|
type InMemoryStore struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
sessions map[string]Handle
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewInMemoryStore 建立一個空的記憶體 session store。
|
||||||
|
func NewInMemoryStore() *InMemoryStore {
|
||||||
|
return &InMemoryStore{
|
||||||
|
sessions: make(map[string]Handle),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register 註冊一個 session;同 token 的舊 session 會先被 Close 再覆蓋(Q5 裁決)。
|
||||||
|
func (s *InMemoryStore) Register(ctx context.Context, token string, h Handle) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if old, ok := s.sessions[token]; ok {
|
||||||
|
// 後連覆蓋前連:關閉舊的 handle 以釋放 yamux / WS 資源。
|
||||||
|
// Close 錯誤忽略 — 舊的可能已經斷線,這不影響新連線的註冊。
|
||||||
|
_ = old.Close()
|
||||||
|
}
|
||||||
|
s.sessions[token] = h
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister 移除指定 token;不存在為 no-op。
|
||||||
|
func (s *InMemoryStore) Unregister(ctx context.Context, token string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
delete(s.sessions, token)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup 回傳指定 token 的 handle;不存在回 ErrSessionNotFound。
|
||||||
|
func (s *InMemoryStore) Lookup(ctx context.Context, token string) (Handle, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
h, ok := s.sessions[token]
|
||||||
|
if !ok {
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
}
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists 判斷 token 是否有 active session。
|
||||||
|
func (s *InMemoryStore) Exists(ctx context.Context, token string) (bool, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
_, ok := s.sessions[token]
|
||||||
|
return ok, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 回傳所有 active session 的 summary。
|
||||||
|
func (s *InMemoryStore) List(ctx context.Context) ([]*Summary, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
out := make([]*Summary, 0, len(s.sessions))
|
||||||
|
for _, h := range s.sessions {
|
||||||
|
if sum := h.Summary(); sum != nil {
|
||||||
|
// 複製 Summary 避免 caller 誤改內部狀態
|
||||||
|
cp := *sum
|
||||||
|
out = append(out, &cp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat 更新 session 的 LastHeartbeat 時間。
|
||||||
|
//
|
||||||
|
// 修 B2 Review M1(race condition):改為呼叫 Handle.RecordHeartbeat,
|
||||||
|
// 由各實作自行用 mutex / atomic 保護 LastHeartbeat 欄位,
|
||||||
|
// 避免 Store.Heartbeat 與 Store.CleanupExpired / Store.List 對同一 Summary pointer
|
||||||
|
// 的並發讀寫被 race detector 捕捉。
|
||||||
|
func (s *InMemoryStore) Heartbeat(ctx context.Context, token string) error {
|
||||||
|
s.mu.RLock()
|
||||||
|
h, ok := s.sessions[token]
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return ErrSessionNotFound
|
||||||
|
}
|
||||||
|
h.RecordHeartbeat(time.Now().UTC())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpired 清除 LastHeartbeat 超過 expireAfter 的 session。
|
||||||
|
//
|
||||||
|
// 實作步驟:
|
||||||
|
// 1. 在讀鎖下找出所有過期 token(避免長時間持寫鎖)
|
||||||
|
// 2. 升級為寫鎖,逐一移除(二次檢查避免 race)
|
||||||
|
// 3. Close 對應 handle 釋放資源
|
||||||
|
func (s *InMemoryStore) CleanupExpired(ctx context.Context, expireAfter time.Duration) (int, error) {
|
||||||
|
cutoff := time.Now().UTC().Add(-expireAfter)
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
removed := 0
|
||||||
|
for token, h := range s.sessions {
|
||||||
|
sum := h.Summary()
|
||||||
|
if sum == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if sum.LastHeartbeat.Before(cutoff) {
|
||||||
|
_ = h.Close()
|
||||||
|
delete(s.sessions, token)
|
||||||
|
removed++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return removed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 InMemoryStore 實作 Store。
|
||||||
|
var _ Store = (*InMemoryStore)(nil)
|
||||||
293
visionA-backend/internal/session/inmemory_store_test.go
Normal file
293
visionA-backend/internal/session/inmemory_store_test.go
Normal file
@ -0,0 +1,293 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeHandle 是測試用 Handle 實作,不涉及真實網路。
|
||||||
|
//
|
||||||
|
// 為配合 B2 Review M1 修補,fakeHandle 以 mutex 保護 summary 的
|
||||||
|
// LastHeartbeat 欄位(Summary() 回傳快照、RecordHeartbeat 在 lock 下寫入)。
|
||||||
|
type fakeHandle struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
summary Summary
|
||||||
|
closed atomic.Bool
|
||||||
|
closeErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeHandle(token, userID, deviceID string) *fakeHandle {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
return &fakeHandle{
|
||||||
|
summary: Summary{
|
||||||
|
Token: token,
|
||||||
|
UserID: userID,
|
||||||
|
DeviceID: deviceID,
|
||||||
|
ConnectedAt: now,
|
||||||
|
LastHeartbeat: now,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *fakeHandle) OpenStream(ctx context.Context) (net.Conn, error) {
|
||||||
|
if h.closed.Load() {
|
||||||
|
return nil, ErrSessionClosed
|
||||||
|
}
|
||||||
|
return nil, errors.New("fakeHandle: OpenStream not implemented for tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *fakeHandle) Close() error {
|
||||||
|
h.closed.Store(true)
|
||||||
|
return h.closeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *fakeHandle) IsClosed() bool {
|
||||||
|
return h.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *fakeHandle) Summary() *Summary {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
cp := h.summary
|
||||||
|
return &cp
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *fakeHandle) RecordHeartbeat(t time.Time) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.summary.LastHeartbeat = t
|
||||||
|
}
|
||||||
|
|
||||||
|
// setLastHeartbeatForTest 僅供測試直接覆寫 LastHeartbeat(CleanupExpired 測試用)。
|
||||||
|
func (h *fakeHandle) setLastHeartbeatForTest(t time.Time) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.summary.LastHeartbeat = t
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
// Tests
|
||||||
|
// ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestInMemoryStore_RegisterAndLookup(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
h := newFakeHandle("tok-1", "user-1", "dev-1")
|
||||||
|
|
||||||
|
require.NoError(t, s.Register(ctx, "tok-1", h))
|
||||||
|
|
||||||
|
got, err := s.Lookup(ctx, "tok-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, h, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_Lookup_NotFound(t *testing.T) {
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
_, err := s.Lookup(context.Background(), "tok-unknown")
|
||||||
|
assert.ErrorIs(t, err, ErrSessionNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_Register_OverwritesAndClosesOld(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
|
||||||
|
old := newFakeHandle("tok-1", "user-1", "dev-1")
|
||||||
|
require.NoError(t, s.Register(ctx, "tok-1", old))
|
||||||
|
|
||||||
|
// 後連覆蓋前連(Q5)
|
||||||
|
newHandle := newFakeHandle("tok-1", "user-1", "dev-1")
|
||||||
|
require.NoError(t, s.Register(ctx, "tok-1", newHandle))
|
||||||
|
|
||||||
|
// 舊 handle 應被 Close
|
||||||
|
assert.True(t, old.IsClosed(), "舊 handle 應該被 Close")
|
||||||
|
assert.False(t, newHandle.IsClosed(), "新 handle 不應被 Close")
|
||||||
|
|
||||||
|
// Lookup 回傳新的
|
||||||
|
got, err := s.Lookup(ctx, "tok-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, newHandle, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_Exists(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
|
||||||
|
ok, err := s.Exists(ctx, "tok-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
require.NoError(t, s.Register(ctx, "tok-1", newFakeHandle("tok-1", "u", "d")))
|
||||||
|
|
||||||
|
ok, err = s.Exists(ctx, "tok-1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_Unregister(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
require.NoError(t, s.Register(ctx, "tok-1", newFakeHandle("tok-1", "u", "d")))
|
||||||
|
|
||||||
|
require.NoError(t, s.Unregister(ctx, "tok-1"))
|
||||||
|
|
||||||
|
ok, _ := s.Exists(ctx, "tok-1")
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
// 不存在的 token 不應回錯
|
||||||
|
assert.NoError(t, s.Unregister(ctx, "tok-unknown"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_Heartbeat_UpdatesLastHeartbeat(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
|
||||||
|
h := newFakeHandle("tok-1", "u", "d")
|
||||||
|
start := h.Summary().LastHeartbeat
|
||||||
|
|
||||||
|
require.NoError(t, s.Register(ctx, "tok-1", h))
|
||||||
|
|
||||||
|
// 確保時間差
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
|
||||||
|
require.NoError(t, s.Heartbeat(ctx, "tok-1"))
|
||||||
|
|
||||||
|
after := h.Summary().LastHeartbeat
|
||||||
|
assert.True(t, after.After(start), "LastHeartbeat 應該被更新:%v > %v", after, start)
|
||||||
|
// (修 B2 M1)Heartbeat 走 RecordHeartbeat;race detector 必須通過。
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_Heartbeat_NotFound(t *testing.T) {
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
err := s.Heartbeat(context.Background(), "tok-unknown")
|
||||||
|
assert.ErrorIs(t, err, ErrSessionNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_List(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
|
||||||
|
require.NoError(t, s.Register(ctx, "a", newFakeHandle("a", "u1", "d1")))
|
||||||
|
require.NoError(t, s.Register(ctx, "b", newFakeHandle("b", "u2", "d2")))
|
||||||
|
|
||||||
|
summaries, err := s.List(ctx)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, summaries, 2)
|
||||||
|
|
||||||
|
tokens := map[string]bool{}
|
||||||
|
for _, sum := range summaries {
|
||||||
|
tokens[sum.Token] = true
|
||||||
|
}
|
||||||
|
assert.True(t, tokens["a"])
|
||||||
|
assert.True(t, tokens["b"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_CleanupExpired(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
|
||||||
|
// 手動設定 LastHeartbeat 為過去時間
|
||||||
|
old := newFakeHandle("expired", "u", "d")
|
||||||
|
old.setLastHeartbeatForTest(time.Now().UTC().Add(-1 * time.Minute))
|
||||||
|
|
||||||
|
fresh := newFakeHandle("fresh", "u", "d")
|
||||||
|
// fresh.summary.LastHeartbeat 已在 newFakeHandle 設為 now
|
||||||
|
|
||||||
|
require.NoError(t, s.Register(ctx, "expired", old))
|
||||||
|
require.NoError(t, s.Register(ctx, "fresh", fresh))
|
||||||
|
|
||||||
|
// 以 30s 為 expireAfter,expired 超過 60s 應被清
|
||||||
|
removed, err := s.CleanupExpired(ctx, 30*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 1, removed)
|
||||||
|
|
||||||
|
assert.True(t, old.IsClosed(), "逾時的 handle 應該被 Close")
|
||||||
|
assert.False(t, fresh.IsClosed())
|
||||||
|
|
||||||
|
ok, _ := s.Exists(ctx, "expired")
|
||||||
|
assert.False(t, ok)
|
||||||
|
ok, _ = s.Exists(ctx, "fresh")
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInMemoryStore_CleanupExpired_Empty(t *testing.T) {
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
removed, err := s.CleanupExpired(context.Background(), 30*time.Second)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestInMemoryStore_Heartbeat_CleanupExpired_NoRace 驗證 B2 Review M1 修補:
|
||||||
|
// 並發執行 Heartbeat(寫)與 CleanupExpired / List(讀)時 race detector 不應捕捉到衝突。
|
||||||
|
// 本測試應在 `go test -race` 下通過。
|
||||||
|
func TestInMemoryStore_Heartbeat_CleanupExpired_NoRace(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := NewInMemoryStore()
|
||||||
|
|
||||||
|
// 註冊 20 個 session
|
||||||
|
const n = 20
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
tok := "tok-" + string(rune('a'+i))
|
||||||
|
require.NoError(t, s.Register(ctx, tok, newFakeHandle(tok, "u", "d")))
|
||||||
|
}
|
||||||
|
|
||||||
|
stop := make(chan struct{})
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// 並發跑 Heartbeat
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
tok := "tok-" + string(rune('a'+i))
|
||||||
|
_ = s.Heartbeat(ctx, tok)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 並發跑 CleanupExpired(不真的清掉,因為 expireAfter 很大)
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
_, _ = s.CleanupExpired(ctx, 1*time.Hour)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 並發跑 List
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stop:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
_, _ = s.List(ctx)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 跑 100ms 讓 race detector 有足夠機會採樣
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
close(stop)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
209
visionA-backend/internal/session/proxy_client.go
Normal file
209
visionA-backend/internal/session/proxy_client.go
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
// proxy_client.go — HTTPProxyClient 實作 ProxyClient interface。
|
||||||
|
//
|
||||||
|
// HTTPProxyClient 是 api-server 端透過 internal HTTP API 存取 remote-proxy 的客戶端。
|
||||||
|
// 對齊 `.autoflow/04-architecture/api/api-internal.md` 的端點規格:
|
||||||
|
//
|
||||||
|
// - GET /internal/session/:token → GetSession
|
||||||
|
// - GET /internal/sessions → ListSessions
|
||||||
|
// - POST /internal/session/:token/close → CloseSession
|
||||||
|
//
|
||||||
|
// 實際的「打開 stream 並轉發 HTTP request」走 raw forward 路徑(見 forwarder.go),
|
||||||
|
// 不在此 client 範圍內 — 這個 client 只負責純 metadata 操作。
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// defaultProxyClientTimeout 是 internal HTTP 呼叫的預設 timeout。
|
||||||
|
//
|
||||||
|
// 30s 對 internal 網路(同機 / 同 VPC)已綽綽有餘;
|
||||||
|
// 真正的 streaming 走 forward/raw 不走此 client,所以不需要無限長。
|
||||||
|
const defaultProxyClientTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// HTTPProxyClient 是 ProxyClient 的 HTTP 實作。
|
||||||
|
//
|
||||||
|
// 並發安全:依賴 http.Client,本身為 stateless(baseURL / logger / timeout 在建構時固定)。
|
||||||
|
type HTTPProxyClient struct {
|
||||||
|
baseURL string // remote-proxy internal URL,例:http://localhost:3801
|
||||||
|
http *http.Client // 共用一個 http.Client 以重用 keep-alive 連線
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPProxyClient 建立一個新的 HTTPProxyClient。
|
||||||
|
//
|
||||||
|
// baseURL 必須為合法 URL,否則 caller 在第一次呼叫時才會發現錯誤;
|
||||||
|
// 為避免「沉默失敗」,這裡會在建構時 trim 尾端 "/"。
|
||||||
|
//
|
||||||
|
// logger 為 nil 時使用 slog.Default。
|
||||||
|
func NewHTTPProxyClient(baseURL string, logger *slog.Logger) *HTTPProxyClient {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
return &HTTPProxyClient{
|
||||||
|
baseURL: strings.TrimRight(baseURL, "/"),
|
||||||
|
http: &http.Client{
|
||||||
|
Timeout: defaultProxyClientTimeout,
|
||||||
|
},
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BaseURL 回傳建構時設定的 remote-proxy internal URL(給 forwarder 共用)。
|
||||||
|
func (c *HTTPProxyClient) BaseURL() string {
|
||||||
|
return c.baseURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSession 對應 GET /internal/session/:token。
|
||||||
|
//
|
||||||
|
// 對 remote-proxy 的回應格式(由 internal_forward.go.getSession 寫入):
|
||||||
|
//
|
||||||
|
// {
|
||||||
|
// "token": "vAc_...",
|
||||||
|
// "connected": true,
|
||||||
|
// "connected_at": "RFC3339",
|
||||||
|
// "last_heartbeat": "RFC3339",
|
||||||
|
// "remote_addr": "1.2.3.4:5678",
|
||||||
|
// "user_id": "demo-user",
|
||||||
|
// "device_id": ""
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// 回傳 *Summary;session 不存在時回 ErrSessionNotFound(HTTP 404)。
|
||||||
|
func (c *HTTPProxyClient) GetSession(ctx context.Context, token string) (*Summary, error) {
|
||||||
|
if token == "" {
|
||||||
|
return nil, errors.New("session: GetSession requires non-empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := c.baseURL + "/internal/session/" + url.PathEscape(token)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("session: build GetSession request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("session: GetSession call remote-proxy failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusOK:
|
||||||
|
// 解析 remote-proxy 的 JSON
|
||||||
|
var raw struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
Connected bool `json:"connected"`
|
||||||
|
ConnectedAt time.Time `json:"connected_at"`
|
||||||
|
LastHeartbeat time.Time `json:"last_heartbeat"`
|
||||||
|
RemoteAddr string `json:"remote_addr"`
|
||||||
|
UserID string `json:"user_id"`
|
||||||
|
DeviceID string `json:"device_id"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("session: GetSession decode response: %w", err)
|
||||||
|
}
|
||||||
|
// connected=false 視為 NotFound(已斷線或正在清理)
|
||||||
|
if !raw.Connected {
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
}
|
||||||
|
return &Summary{
|
||||||
|
Token: raw.Token,
|
||||||
|
UserID: raw.UserID,
|
||||||
|
DeviceID: raw.DeviceID,
|
||||||
|
ConnectedAt: raw.ConnectedAt,
|
||||||
|
LastHeartbeat: raw.LastHeartbeat,
|
||||||
|
RemoteAddr: raw.RemoteAddr,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case http.StatusNotFound:
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
|
||||||
|
default:
|
||||||
|
// 不是已知 status — 帶上 body 讓使用端 debug
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||||
|
return nil, fmt.Errorf("session: GetSession unexpected status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListSessions 對應 GET /internal/sessions。
|
||||||
|
//
|
||||||
|
// remote-proxy 回應格式(internal_forward.go.HandleListSessions):
|
||||||
|
//
|
||||||
|
// { "sessions": [ Summary, ... ], "total": N }
|
||||||
|
func (c *HTTPProxyClient) ListSessions(ctx context.Context) ([]*Summary, error) {
|
||||||
|
endpoint := c.baseURL + "/internal/sessions"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("session: build ListSessions request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("session: ListSessions call remote-proxy failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||||
|
return nil, fmt.Errorf("session: ListSessions unexpected status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var raw struct {
|
||||||
|
Sessions []*Summary `json:"sessions"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
if err := json.NewDecoder(resp.Body).Decode(&raw); err != nil {
|
||||||
|
return nil, fmt.Errorf("session: ListSessions decode response: %w", err)
|
||||||
|
}
|
||||||
|
if raw.Sessions == nil {
|
||||||
|
// 空 list 統一用 non-nil empty slice,呼叫方好處理
|
||||||
|
return []*Summary{}, nil
|
||||||
|
}
|
||||||
|
return raw.Sessions, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseSession 對應 POST /internal/session/:token/close。
|
||||||
|
//
|
||||||
|
// 用於管理動作(使用者 revoke token、後台運維強制斷線)。
|
||||||
|
// session 不存在回 ErrSessionNotFound;其他錯誤直接 wrap。
|
||||||
|
func (c *HTTPProxyClient) CloseSession(ctx context.Context, token string) error {
|
||||||
|
if token == "" {
|
||||||
|
return errors.New("session: CloseSession requires non-empty token")
|
||||||
|
}
|
||||||
|
|
||||||
|
endpoint := c.baseURL + "/internal/session/" + url.PathEscape(token) + "/close"
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("session: build CloseSession request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.http.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("session: CloseSession call remote-proxy failed: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
switch resp.StatusCode {
|
||||||
|
case http.StatusOK:
|
||||||
|
// 消化 body 讓底層連線可以被 keep-alive 重用
|
||||||
|
_, _ = io.Copy(io.Discard, resp.Body)
|
||||||
|
return nil
|
||||||
|
case http.StatusNotFound:
|
||||||
|
return ErrSessionNotFound
|
||||||
|
default:
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||||
|
return fmt.Errorf("session: CloseSession unexpected status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 HTTPProxyClient 實作 ProxyClient。
|
||||||
|
var _ ProxyClient = (*HTTPProxyClient)(nil)
|
||||||
145
visionA-backend/internal/session/proxy_client_test.go
Normal file
145
visionA-backend/internal/session/proxy_client_test.go
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_GetSession_OK 驗證能正確解析 remote-proxy 的
|
||||||
|
// /internal/session/:token 回應 → Summary。
|
||||||
|
func TestHTTPProxyClient_GetSession_OK(t *testing.T) {
|
||||||
|
now := time.Now().UTC().Truncate(time.Second)
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "/internal/session/vAc_abc", r.URL.Path)
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"token": "vAc_abc",
|
||||||
|
"connected": true,
|
||||||
|
"connected_at": now,
|
||||||
|
"last_heartbeat": now,
|
||||||
|
"remote_addr": "1.2.3.4:5678",
|
||||||
|
"user_id": "demo-user",
|
||||||
|
"device_id": "dev-1",
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewHTTPProxyClient(ts.URL, nil)
|
||||||
|
sum, err := c.GetSession(context.Background(), "vAc_abc")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, sum)
|
||||||
|
assert.Equal(t, "vAc_abc", sum.Token)
|
||||||
|
assert.Equal(t, "demo-user", sum.UserID)
|
||||||
|
assert.Equal(t, "dev-1", sum.DeviceID)
|
||||||
|
assert.Equal(t, "1.2.3.4:5678", sum.RemoteAddr)
|
||||||
|
assert.True(t, sum.LastHeartbeat.Equal(now))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_GetSession_NotFound 驗證 404 → ErrSessionNotFound。
|
||||||
|
func TestHTTPProxyClient_GetSession_NotFound(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, `{"error":"NOT_FOUND"}`, http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewHTTPProxyClient(ts.URL, nil)
|
||||||
|
_, err := c.GetSession(context.Background(), "vAc_xxx")
|
||||||
|
assert.ErrorIs(t, err, ErrSessionNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_GetSession_ConnectedFalse_TreatedAsNotFound
|
||||||
|
// 驗證 remote-proxy 回 connected=false(session 已被排隊清除)→ NotFound。
|
||||||
|
func TestHTTPProxyClient_GetSession_ConnectedFalse_TreatedAsNotFound(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"token": "vAc_dead",
|
||||||
|
"connected": false,
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewHTTPProxyClient(ts.URL, nil)
|
||||||
|
_, err := c.GetSession(context.Background(), "vAc_dead")
|
||||||
|
assert.ErrorIs(t, err, ErrSessionNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_GetSession_EmptyToken 驗證空 token 直接被本地拒絕。
|
||||||
|
func TestHTTPProxyClient_GetSession_EmptyToken(t *testing.T) {
|
||||||
|
c := NewHTTPProxyClient("http://localhost:9999", nil)
|
||||||
|
_, err := c.GetSession(context.Background(), "")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_ListSessions_OK 驗證能正確 parse sessions array。
|
||||||
|
func TestHTTPProxyClient_ListSessions_OK(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, "/internal/sessions", r.URL.Path)
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"sessions": []map[string]any{
|
||||||
|
{"token": "vAc_a", "userId": "u1"},
|
||||||
|
{"token": "vAc_b", "userId": "u2"},
|
||||||
|
},
|
||||||
|
"total": 2,
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewHTTPProxyClient(ts.URL, nil)
|
||||||
|
sums, err := c.ListSessions(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, sums, 2)
|
||||||
|
assert.Equal(t, "vAc_a", sums[0].Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_ListSessions_Empty 驗證空 sessions 回 non-nil empty slice。
|
||||||
|
func TestHTTPProxyClient_ListSessions_Empty(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{"sessions": nil, "total": 0})
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewHTTPProxyClient(ts.URL, nil)
|
||||||
|
sums, err := c.ListSessions(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, sums)
|
||||||
|
assert.Empty(t, sums)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_CloseSession_OK 驗證 200 → nil error。
|
||||||
|
func TestHTTPProxyClient_CloseSession_OK(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
assert.Equal(t, http.MethodPost, r.Method)
|
||||||
|
assert.Equal(t, "/internal/session/vAc_x/close", r.URL.Path)
|
||||||
|
_, _ = w.Write([]byte(`{"closed":true}`))
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewHTTPProxyClient(ts.URL, nil)
|
||||||
|
err := c.CloseSession(context.Background(), "vAc_x")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_CloseSession_NotFound 驗證 404 → ErrSessionNotFound。
|
||||||
|
func TestHTTPProxyClient_CloseSession_NotFound(t *testing.T) {
|
||||||
|
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
}))
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
c := NewHTTPProxyClient(ts.URL, nil)
|
||||||
|
err := c.CloseSession(context.Background(), "vAc_x")
|
||||||
|
assert.ErrorIs(t, err, ErrSessionNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestHTTPProxyClient_BaseURL_TrimsTrailingSlash 驗證 baseURL 結尾的 / 會被移除。
|
||||||
|
func TestHTTPProxyClient_BaseURL_TrimsTrailingSlash(t *testing.T) {
|
||||||
|
c := NewHTTPProxyClient("http://localhost:3801/", nil)
|
||||||
|
assert.Equal(t, "http://localhost:3801", c.BaseURL())
|
||||||
|
}
|
||||||
104
visionA-backend/internal/session/proxy_store.go
Normal file
104
visionA-backend/internal/session/proxy_store.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
// proxy_store.go — ProxyClientStore 實作 Store interface(api-server 端使用)。
|
||||||
|
//
|
||||||
|
// 雛形雙 binary 架構下:
|
||||||
|
// - remote-proxy 持有 InMemoryStore,是唯一的 session state 來源
|
||||||
|
// - api-server 持有 ProxyClientStore,內部透過 ProxyClient 走 internal HTTP 查 remote-proxy
|
||||||
|
//
|
||||||
|
// 因為 api-server 是無狀態,所以「寫入類」操作(Register / Heartbeat / CleanupExpired)
|
||||||
|
// 對 ProxyClientStore 都不適用 — 全部回 ErrNotSupported。
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyClientStore 是 Store 的 HTTP-client 實作,部署在 api-server 端。
|
||||||
|
//
|
||||||
|
// 它把所有讀取類操作 delegate 到 ProxyClient(HTTP 呼叫 remote-proxy)。
|
||||||
|
// 寫入類操作(Register / Unregister / Heartbeat / CleanupExpired)一律回
|
||||||
|
// ErrNotSupported — 因為 session lifecycle 由 remote-proxy 唯一管理。
|
||||||
|
//
|
||||||
|
// Lookup 回傳的 Handle 是 RemoteHandle(見下方),它的 OpenStream 會走
|
||||||
|
// `forwarder.go` 的 raw forward 流程。
|
||||||
|
type ProxyClientStore struct {
|
||||||
|
client ProxyClient
|
||||||
|
forwarder *Forwarder // 用於建立 RemoteHandle(OpenStream 時使用)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProxyClientStore 建立一個 api-server 端的 SessionStore。
|
||||||
|
//
|
||||||
|
// 入參:
|
||||||
|
// - client:用於 metadata 操作(GetSession / ListSessions / CloseSession)
|
||||||
|
// - forwarder:用於 RemoteHandle.OpenStream 走 raw forward
|
||||||
|
//
|
||||||
|
// forwarder 可為 nil(不需要 OpenStream,只要 metadata 查詢時);但實務上
|
||||||
|
// api-server 必定需要轉發,所以呼叫方應同時注入兩者。
|
||||||
|
func NewProxyClientStore(client ProxyClient, forwarder *Forwarder) *ProxyClientStore {
|
||||||
|
return &ProxyClientStore{client: client, forwarder: forwarder}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register — ProxyClientStore 不支援;session 註冊由 remote-proxy 在 tunnel
|
||||||
|
// upgrade 時完成。
|
||||||
|
func (s *ProxyClientStore) Register(ctx context.Context, token string, h Handle) error {
|
||||||
|
return ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unregister — 在 api-server 端等同於「強制關閉 session」,實際走
|
||||||
|
// CloseSession HTTP endpoint;不存在時為 no-op(對齊 InMemoryStore 行為)。
|
||||||
|
func (s *ProxyClientStore) Unregister(ctx context.Context, token string) error {
|
||||||
|
if err := s.client.CloseSession(ctx, token); err != nil {
|
||||||
|
// 不存在當作 no-op,與 InMemoryStore 一致
|
||||||
|
if errors.Is(err, ErrSessionNotFound) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Lookup 對應 ProxyClient.GetSession,回傳 RemoteHandle。
|
||||||
|
//
|
||||||
|
// RemoteHandle 不持有 yamux session(它在 remote-proxy 那邊);
|
||||||
|
// 它的 OpenStream 會透過 Forwarder 走 raw forward。
|
||||||
|
func (s *ProxyClientStore) Lookup(ctx context.Context, token string) (Handle, error) {
|
||||||
|
sum, err := s.client.GetSession(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return newRemoteHandle(s.client, s.forwarder, sum), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists 透過 GetSession 判斷;不存在回 (false, nil),其他錯誤回 (false, err)。
|
||||||
|
func (s *ProxyClientStore) Exists(ctx context.Context, token string) (bool, error) {
|
||||||
|
_, err := s.client.GetSession(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, ErrSessionNotFound) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 對應 ProxyClient.ListSessions。
|
||||||
|
func (s *ProxyClientStore) List(ctx context.Context) ([]*Summary, error) {
|
||||||
|
return s.client.ListSessions(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Heartbeat — ProxyClientStore 不支援;心跳由 yamux 的 keep-alive 自動維持,
|
||||||
|
// 並由 remote-proxy 在實體 tunnel 上更新 LastHeartbeat。
|
||||||
|
func (s *ProxyClientStore) Heartbeat(ctx context.Context, token string) error {
|
||||||
|
return ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpired — ProxyClientStore 不支援;清理由 remote-proxy 的
|
||||||
|
// background goroutine 執行。
|
||||||
|
func (s *ProxyClientStore) CleanupExpired(ctx context.Context, expireAfter time.Duration) (int, error) {
|
||||||
|
return 0, ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 ProxyClientStore 實作 Store。
|
||||||
|
var _ Store = (*ProxyClientStore)(nil)
|
||||||
180
visionA-backend/internal/session/proxy_store_test.go
Normal file
180
visionA-backend/internal/session/proxy_store_test.go
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// fakeProxyClient 是測試用 ProxyClient mock;以 lambda 注入行為。
|
||||||
|
type fakeProxyClient struct {
|
||||||
|
getSessionFn func(ctx context.Context, token string) (*Summary, error)
|
||||||
|
listFn func(ctx context.Context) ([]*Summary, error)
|
||||||
|
closeSessionFn func(ctx context.Context, token string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeProxyClient) GetSession(ctx context.Context, token string) (*Summary, error) {
|
||||||
|
return f.getSessionFn(ctx, token)
|
||||||
|
}
|
||||||
|
func (f *fakeProxyClient) ListSessions(ctx context.Context) ([]*Summary, error) {
|
||||||
|
return f.listFn(ctx)
|
||||||
|
}
|
||||||
|
func (f *fakeProxyClient) CloseSession(ctx context.Context, token string) error {
|
||||||
|
return f.closeSessionFn(ctx, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyClientStore_WriteOps_Unsupported 驗證所有寫入類操作回 ErrNotSupported。
|
||||||
|
func TestProxyClientStore_WriteOps_Unsupported(t *testing.T) {
|
||||||
|
store := NewProxyClientStore(&fakeProxyClient{}, nil)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
err := store.Register(ctx, "vAc_x", nil)
|
||||||
|
assert.ErrorIs(t, err, ErrNotSupported, "Register 必須回 ErrNotSupported")
|
||||||
|
|
||||||
|
err = store.Heartbeat(ctx, "vAc_x")
|
||||||
|
assert.ErrorIs(t, err, ErrNotSupported, "Heartbeat 必須回 ErrNotSupported")
|
||||||
|
|
||||||
|
_, err = store.CleanupExpired(ctx, time.Minute)
|
||||||
|
assert.ErrorIs(t, err, ErrNotSupported, "CleanupExpired 必須回 ErrNotSupported")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyClientStore_Lookup_OK 驗證 Lookup 走 client.GetSession 並回 RemoteHandle。
|
||||||
|
func TestProxyClientStore_Lookup_OK(t *testing.T) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
getSessionFn: func(ctx context.Context, token string) (*Summary, error) {
|
||||||
|
assert.Equal(t, "vAc_x", token)
|
||||||
|
return &Summary{Token: token, ConnectedAt: now, LastHeartbeat: now}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
store := NewProxyClientStore(client, nil)
|
||||||
|
|
||||||
|
h, err := store.Lookup(context.Background(), "vAc_x")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, h)
|
||||||
|
assert.False(t, h.IsClosed())
|
||||||
|
assert.Equal(t, "vAc_x", h.Summary().Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyClientStore_Lookup_NotFound 驗證 ErrSessionNotFound 透傳。
|
||||||
|
func TestProxyClientStore_Lookup_NotFound(t *testing.T) {
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
getSessionFn: func(ctx context.Context, token string) (*Summary, error) {
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
},
|
||||||
|
}
|
||||||
|
store := NewProxyClientStore(client, nil)
|
||||||
|
_, err := store.Lookup(context.Background(), "vAc_x")
|
||||||
|
assert.ErrorIs(t, err, ErrSessionNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyClientStore_Exists 驗證 Exists 的兩種狀態。
|
||||||
|
func TestProxyClientStore_Exists(t *testing.T) {
|
||||||
|
t.Run("exists", func(t *testing.T) {
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
getSessionFn: func(ctx context.Context, token string) (*Summary, error) {
|
||||||
|
return &Summary{Token: token}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
store := NewProxyClientStore(client, nil)
|
||||||
|
ok, err := store.Exists(context.Background(), "vAc_x")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not_found_returns_false_no_error", func(t *testing.T) {
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
getSessionFn: func(ctx context.Context, token string) (*Summary, error) {
|
||||||
|
return nil, ErrSessionNotFound
|
||||||
|
},
|
||||||
|
}
|
||||||
|
store := NewProxyClientStore(client, nil)
|
||||||
|
ok, err := store.Exists(context.Background(), "vAc_x")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, ok)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("other_error_propagates", func(t *testing.T) {
|
||||||
|
boom := errors.New("network down")
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
getSessionFn: func(ctx context.Context, token string) (*Summary, error) {
|
||||||
|
return nil, boom
|
||||||
|
},
|
||||||
|
}
|
||||||
|
store := NewProxyClientStore(client, nil)
|
||||||
|
ok, err := store.Exists(context.Background(), "vAc_x")
|
||||||
|
assert.False(t, ok)
|
||||||
|
require.Error(t, err)
|
||||||
|
// Store 直接 propagate(不 wrap),所以 errors.Is 對 sentinel 應為 true。
|
||||||
|
assert.ErrorIs(t, err, boom)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestProxyClientStore_Unregister_DelegatesToClose 驗證 Unregister 走 CloseSession,
|
||||||
|
// 並且 NotFound 視為 no-op。
|
||||||
|
func TestProxyClientStore_Unregister_DelegatesToClose(t *testing.T) {
|
||||||
|
t.Run("delegates", func(t *testing.T) {
|
||||||
|
var called bool
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
closeSessionFn: func(ctx context.Context, token string) error {
|
||||||
|
called = true
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
store := NewProxyClientStore(client, nil)
|
||||||
|
err := store.Unregister(context.Background(), "vAc_x")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, called, "CloseSession 應被呼叫")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("not_found_is_noop", func(t *testing.T) {
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
closeSessionFn: func(ctx context.Context, token string) error {
|
||||||
|
return ErrSessionNotFound
|
||||||
|
},
|
||||||
|
}
|
||||||
|
store := NewProxyClientStore(client, nil)
|
||||||
|
err := store.Unregister(context.Background(), "vAc_x")
|
||||||
|
assert.NoError(t, err, "NotFound 應視為 no-op")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoteHandle_OpenStream_NoForwarder 驗證沒注入 forwarder 時回 ErrNotSupported。
|
||||||
|
func TestRemoteHandle_OpenStream_NoForwarder(t *testing.T) {
|
||||||
|
h := newRemoteHandle(&fakeProxyClient{}, nil, &Summary{Token: "vAc_x"})
|
||||||
|
_, err := h.OpenStream(context.Background())
|
||||||
|
assert.ErrorIs(t, err, ErrNotSupported)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoteHandle_Close_Idempotent 驗證 Close 多次只 trigger 一次 CloseSession。
|
||||||
|
func TestRemoteHandle_Close_Idempotent(t *testing.T) {
|
||||||
|
var calls int
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
closeSessionFn: func(ctx context.Context, token string) error {
|
||||||
|
calls++
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
h := newRemoteHandle(client, nil, &Summary{Token: "vAc_x"})
|
||||||
|
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
require.NoError(t, h.Close())
|
||||||
|
assert.Equal(t, 1, calls, "Close 多次應冪等:CloseSession 只被呼叫一次")
|
||||||
|
assert.True(t, h.IsClosed())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRemoteHandle_OpenStream_AfterClose 驗證 close 後 OpenStream 回 ErrSessionClosed。
|
||||||
|
func TestRemoteHandle_OpenStream_AfterClose(t *testing.T) {
|
||||||
|
client := &fakeProxyClient{
|
||||||
|
closeSessionFn: func(ctx context.Context, token string) error { return nil },
|
||||||
|
}
|
||||||
|
h := newRemoteHandle(client, nil, &Summary{Token: "vAc_x"})
|
||||||
|
_ = h.Close()
|
||||||
|
_, err := h.OpenStream(context.Background())
|
||||||
|
assert.ErrorIs(t, err, ErrSessionClosed)
|
||||||
|
}
|
||||||
101
visionA-backend/internal/session/remote_handle.go
Normal file
101
visionA-backend/internal/session/remote_handle.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
// remote_handle.go — RemoteHandle 是 ProxyClientStore.Lookup 回傳的 Handle 實作。
|
||||||
|
//
|
||||||
|
// 它代表「session 在 remote-proxy 那邊」。OpenStream 走 Forwarder 的 raw forward。
|
||||||
|
//
|
||||||
|
// 注意:RemoteHandle 不持有 yamux session(那在 remote-proxy 的記憶體裡),
|
||||||
|
// 所以許多語意(Close / IsClosed)行為上跟 LocalHandle 不太一樣:
|
||||||
|
// - Close 走 ProxyClient.CloseSession(HTTP 通知 remote-proxy 關閉)
|
||||||
|
// - IsClosed 只能根據 Lookup 時的 Summary 判斷,無即時感知能力
|
||||||
|
//
|
||||||
|
// 對 api-server handler 來說,這些差異是透明的 — 只要拿 handle.OpenStream
|
||||||
|
// 來開 stream 就好。
|
||||||
|
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RemoteHandle 是 api-server 端的 Handle 實作。
|
||||||
|
type RemoteHandle struct {
|
||||||
|
client ProxyClient
|
||||||
|
forwarder *Forwarder
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
summary Summary
|
||||||
|
|
||||||
|
// closed 用 atomic 避免每次 IsClosed 都 lock
|
||||||
|
closed atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// newRemoteHandle 建立一個 RemoteHandle;package-internal,由 ProxyClientStore 用。
|
||||||
|
func newRemoteHandle(client ProxyClient, forwarder *Forwarder, sum *Summary) *RemoteHandle {
|
||||||
|
h := &RemoteHandle{
|
||||||
|
client: client,
|
||||||
|
forwarder: forwarder,
|
||||||
|
}
|
||||||
|
if sum != nil {
|
||||||
|
h.summary = *sum
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenStream 走 Forwarder 開一條 raw TCP(hijack)連線。
|
||||||
|
//
|
||||||
|
// 若 Forwarder 為 nil(僅 metadata-only 場景)回明確錯誤。
|
||||||
|
func (h *RemoteHandle) OpenStream(ctx context.Context) (net.Conn, error) {
|
||||||
|
if h.closed.Load() {
|
||||||
|
return nil, ErrSessionClosed
|
||||||
|
}
|
||||||
|
if h.forwarder == nil {
|
||||||
|
return nil, ErrNotSupported
|
||||||
|
}
|
||||||
|
return h.forwarder.OpenStream(ctx, h.summary.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 透過 ProxyClient 通知 remote-proxy 關閉這個 session。
|
||||||
|
//
|
||||||
|
// 冪等:多次呼叫只會打一次 HTTP(後續直接回 nil)。
|
||||||
|
func (h *RemoteHandle) Close() error {
|
||||||
|
if !h.closed.CompareAndSwap(false, true) {
|
||||||
|
return nil // 已經 close 過
|
||||||
|
}
|
||||||
|
if h.client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
// 用獨立的 ctx 避免 caller 取消後 close 半路斷掉
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
return h.client.CloseSession(ctx, h.summary.Token)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsClosed 回報是否已被本地 Close。
|
||||||
|
//
|
||||||
|
// 注意:不會即時感知 remote-proxy 那邊的斷線;
|
||||||
|
// 真正想確認 session 還在的話,應該再呼叫一次 Lookup。
|
||||||
|
func (h *RemoteHandle) IsClosed() bool {
|
||||||
|
return h.closed.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Summary 回傳 session metadata 的 snapshot。
|
||||||
|
func (h *RemoteHandle) Summary() *Summary {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
cp := h.summary
|
||||||
|
return &cp
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordHeartbeat — RemoteHandle 不主動記錄心跳(心跳由 yamux 在 remote-proxy
|
||||||
|
// 端維護);本方法純粹更新本地 summary 的 LastHeartbeat 欄位以維持 interface 契約。
|
||||||
|
func (h *RemoteHandle) RecordHeartbeat(t time.Time) {
|
||||||
|
h.mu.Lock()
|
||||||
|
defer h.mu.Unlock()
|
||||||
|
h.summary.LastHeartbeat = t
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查
|
||||||
|
var _ Handle = (*RemoteHandle)(nil)
|
||||||
146
visionA-backend/internal/session/session.go
Normal file
146
visionA-backend/internal/session/session.go
Normal file
@ -0,0 +1,146 @@
|
|||||||
|
// Package session 定義 tunnel session 管理介面與記憶體實作,對齊 tunnel.md §5。
|
||||||
|
//
|
||||||
|
// 在雛形雙 binary 架構下:
|
||||||
|
// - remote-proxy 端持有唯一的 InMemoryStore(真正 own *yamux.Session)。
|
||||||
|
// - api-server 端用 ProxyClientStore(透過 internal HTTP 查詢 remote-proxy)— 留給 B4 實作。
|
||||||
|
//
|
||||||
|
// 此 package 僅定義 interface 與 in-memory 實作;HTTP client 實作留給 B4。
|
||||||
|
package session
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Errors
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrSessionNotFound 表示指定 token 對應的 session 不存在(從未註冊或已移除)。
|
||||||
|
ErrSessionNotFound = errors.New("session: not found")
|
||||||
|
|
||||||
|
// ErrSessionExpired 表示 session 雖存在但 LastHeartbeat 已超過 IdleTimeout。
|
||||||
|
// 大多情境下 CleanupExpired 會先行移除,Lookup 會直接回 ErrSessionNotFound;
|
||||||
|
// 少數時序下 caller 可能碰到,保留明確語意便於除錯。
|
||||||
|
ErrSessionExpired = errors.New("session: expired")
|
||||||
|
|
||||||
|
// ErrNotSupported 表示此 Store 實作不支援該操作(例:ProxyClientStore.Register)。
|
||||||
|
ErrNotSupported = errors.New("session: operation not supported by this store")
|
||||||
|
|
||||||
|
// ErrSessionClosed 表示 session 底層連線已關閉,OpenStream 無法再開。
|
||||||
|
ErrSessionClosed = errors.New("session: underlying connection closed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Summary & Handle
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Summary 是 session 的可序列化描述,對 List / internal HTTP API 回傳。
|
||||||
|
type Summary struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
UserID string `json:"userId"`
|
||||||
|
DeviceID string `json:"deviceId,omitempty"`
|
||||||
|
ConnectedAt time.Time `json:"connectedAt"`
|
||||||
|
LastHeartbeat time.Time `json:"lastHeartbeat"`
|
||||||
|
RemoteAddr string `json:"remoteAddr,omitempty"`
|
||||||
|
ProxyNodeID string `json:"proxyNodeId,omitempty"` // Phase 1 多節點使用
|
||||||
|
ProxyInternalURL string `json:"proxyInternalUrl,omitempty"` // Phase 1 多節點使用
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle 是實際可操作的 session(綁在某個 proxy 節點的記憶體)。
|
||||||
|
//
|
||||||
|
// 雛形單節點:LocalHandle wrap *yamux.Session(B3 實作)。
|
||||||
|
// api-server 端:RemoteHandle wrap internal HTTP client(B4 實作)。
|
||||||
|
//
|
||||||
|
// 並發安全:
|
||||||
|
// - 實作必須自行保護 Summary().LastHeartbeat 的讀寫,
|
||||||
|
// 因為 Store.Heartbeat 會呼叫 RecordHeartbeat,而 Store.List / CleanupExpired
|
||||||
|
// 會透過 Summary() 讀取 LastHeartbeat。這是為了修 B2 Review M1 race。
|
||||||
|
type Handle interface {
|
||||||
|
// OpenStream 在此 session 上開一條新的雙向 stream。
|
||||||
|
// 若底層連線已關閉,回 ErrSessionClosed。
|
||||||
|
OpenStream(ctx context.Context) (net.Conn, error)
|
||||||
|
|
||||||
|
// Close 主動關閉此 session(通常由 remote-proxy 在 CleanupExpired 呼叫)。
|
||||||
|
Close() error
|
||||||
|
|
||||||
|
// IsClosed 回報底層連線是否已斷。
|
||||||
|
IsClosed() bool
|
||||||
|
|
||||||
|
// Summary 回傳 session 的可讀資訊(log / List 用)。
|
||||||
|
//
|
||||||
|
// 實作應回傳「內部 Summary 的副本」或「lock 保護下 snapshot」,
|
||||||
|
// 以避免 caller 觀察到中間態(例如 LastHeartbeat 正在被更新)。
|
||||||
|
Summary() *Summary
|
||||||
|
|
||||||
|
// RecordHeartbeat 更新此 session 的 LastHeartbeat 時間。
|
||||||
|
// 實作應以 mutex / atomic 保護,確保與 Summary() 的並發讀取安全(修 B2 M1)。
|
||||||
|
RecordHeartbeat(t time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Store interface
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Store 管理所有 active tunnel session。
|
||||||
|
//
|
||||||
|
// 對齊 tunnel.md §5.1 + Minor-4(CleanupExpired)+ interface-contracts.md §8.3。
|
||||||
|
// 實作必須是並發安全的。
|
||||||
|
type Store interface {
|
||||||
|
// Register 註冊一個 session handle;若 token 已存在,實作應**關閉舊 handle 並覆蓋**(Q5 裁決)。
|
||||||
|
Register(ctx context.Context, token string, h Handle) error
|
||||||
|
|
||||||
|
// Unregister 移除指定 token 的 session(通常 tunnel 斷線時呼叫)。
|
||||||
|
// 若不存在為 no-op,不回 error。
|
||||||
|
Unregister(ctx context.Context, token string) error
|
||||||
|
|
||||||
|
// Lookup 查詢 token 對應的 session handle;不存在回 ErrSessionNotFound。
|
||||||
|
Lookup(ctx context.Context, token string) (Handle, error)
|
||||||
|
|
||||||
|
// Exists 判斷指定 token 是否有 active session;不存在回 (false, nil),非 error。
|
||||||
|
Exists(ctx context.Context, token string) (bool, error)
|
||||||
|
|
||||||
|
// List 回傳所有 active session 的 summary。
|
||||||
|
List(ctx context.Context) ([]*Summary, error)
|
||||||
|
|
||||||
|
// Heartbeat 更新 session 的 LastHeartbeat 時間。
|
||||||
|
// 若 session 不存在回 ErrSessionNotFound。
|
||||||
|
Heartbeat(ctx context.Context, token string) error
|
||||||
|
|
||||||
|
// CleanupExpired 移除所有 LastHeartbeat 超過 expireAfter 的 session(Minor-4)。
|
||||||
|
//
|
||||||
|
// 實作須 Close() 對應 Handle 以釋放 yamux.Session / WS conn。
|
||||||
|
// 回傳被清理的 session 數量,供觀測。
|
||||||
|
//
|
||||||
|
// 由 remote-proxy 的 background goroutine 每 30s 呼叫一次(對齊 tunnel.md §4.2)。
|
||||||
|
CleanupExpired(ctx context.Context, expireAfter time.Duration) (removed int, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// ProxyClient interface(api-server 端 → remote-proxy 內部 HTTP)
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// ProxyClient 是 api-server 端呼叫 remote-proxy internal HTTP API 的抽象。
|
||||||
|
//
|
||||||
|
// 雛形只定義 interface;實際 HTTP 呼叫留給 B4 的 proxy_client.go。
|
||||||
|
// Store 的 ProxyClientStore 實作會 delegate 到此 client。
|
||||||
|
//
|
||||||
|
// 相關 internal HTTP 端點(見 tunnel.md §7.1):
|
||||||
|
// - GET /internal/session/:token → GetSession / Exists
|
||||||
|
// - POST /internal/forward/http?token=… → ForwardHTTP
|
||||||
|
// - GET /internal/forward/ws?token=… → ForwardWebSocket
|
||||||
|
// - POST /internal/session/:token/close → CloseSession
|
||||||
|
type ProxyClient interface {
|
||||||
|
// GetSession 從 remote-proxy 查詢指定 token 的 session summary;
|
||||||
|
// 不存在回 ErrSessionNotFound。
|
||||||
|
GetSession(ctx context.Context, token string) (*Summary, error)
|
||||||
|
|
||||||
|
// ListSessions 列出 remote-proxy 當前所有 active session。
|
||||||
|
ListSessions(ctx context.Context) ([]*Summary, error)
|
||||||
|
|
||||||
|
// CloseSession 主動要求 remote-proxy 關閉指定 session(管理動作)。
|
||||||
|
CloseSession(ctx context.Context, token string) error
|
||||||
|
}
|
||||||
352
visionA-backend/internal/storage/localfs.go
Normal file
352
visionA-backend/internal/storage/localfs.go
Normal file
@ -0,0 +1,352 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/fs"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LocalFSStore 是以本地 filesystem 為後端的 Store 實作(Phase 0 雛形)。
|
||||||
|
//
|
||||||
|
// 特性:
|
||||||
|
// - 檔案存於 root + key 組成的路徑下(root 預設 ./data/storage)
|
||||||
|
// - meta 存為 sidecar 檔:`{path}.meta.json`(雛形簡化;S3 原生支援 metadata)
|
||||||
|
// - Presigned URL 使用 HMAC-SHA256 簽名,api-server 的 /storage handler 驗證
|
||||||
|
//
|
||||||
|
// Phase 1:S3Store 會實作同 interface 取代之。
|
||||||
|
type LocalFSStore struct {
|
||||||
|
root string
|
||||||
|
baseURL string
|
||||||
|
signer *Signer
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLocalFSStore 建立一個 LocalFSStore。
|
||||||
|
//
|
||||||
|
// root 為儲存根目錄(不存在會自動建立);baseURL 用於 presigned URL 前綴;
|
||||||
|
// signingSecret 為 HMAC 簽名 secret(生產環境必須由 env 提供,不可使用預設值)。
|
||||||
|
func NewLocalFSStore(root, baseURL, signingSecret string) (*LocalFSStore, error) {
|
||||||
|
if root == "" {
|
||||||
|
return nil, errors.New("storage: root must not be empty")
|
||||||
|
}
|
||||||
|
absRoot, err := filepath.Abs(root)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("storage: resolve root abs path: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(absRoot, 0o755); err != nil {
|
||||||
|
return nil, fmt.Errorf("storage: mkdir root: %w", err)
|
||||||
|
}
|
||||||
|
if signingSecret == "" {
|
||||||
|
signingSecret = "dev-signing-secret-do-not-use-in-prod"
|
||||||
|
}
|
||||||
|
return &LocalFSStore{
|
||||||
|
root: absRoot,
|
||||||
|
baseURL: strings.TrimRight(baseURL, "/"),
|
||||||
|
signer: NewSigner([]byte(signingSecret)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveKey 將 key 轉成絕對路徑,並驗證未逃出 root(防止 path traversal)。
|
||||||
|
//
|
||||||
|
// 允許 key 為空字串(回 root 本身,供 List 使用全量掃描)。
|
||||||
|
func (s *LocalFSStore) resolveKey(key string) (string, error) {
|
||||||
|
if strings.Contains(key, "\x00") {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
// 明確拒絕任何包含 ".." 的 path segment — 防止絕對路徑逃出 root
|
||||||
|
// (即使 filepath.Clean 會 normalize,保險起見先在此層阻擋)。
|
||||||
|
if containsParentDir(key) {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
// 拒絕絕對路徑開頭
|
||||||
|
if strings.HasPrefix(key, "/") || strings.HasPrefix(key, string(filepath.Separator)) {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
|
||||||
|
if key == "" {
|
||||||
|
return s.root, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
full := filepath.Join(s.root, key)
|
||||||
|
absFull, err := filepath.Abs(full)
|
||||||
|
if err != nil {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
// 確保最終路徑仍在 root 底下(雙重保險)
|
||||||
|
if absFull != s.root && !strings.HasPrefix(absFull, s.root+string(filepath.Separator)) {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
return absFull, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// containsParentDir 回報 key 是否含有 ".." segment(用 / 與 OS separator 兩種分隔符)。
|
||||||
|
func containsParentDir(key string) bool {
|
||||||
|
for _, sep := range []string{"/", string(filepath.Separator)} {
|
||||||
|
for _, seg := range strings.Split(key, sep) {
|
||||||
|
if seg == ".." {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Put 寫入 object;路徑不存在會自動建立父目錄。
|
||||||
|
func (s *LocalFSStore) Put(ctx context.Context, key string, r io.Reader, size int64, meta map[string]string) error {
|
||||||
|
if key == "" {
|
||||||
|
return ErrInvalidKey
|
||||||
|
}
|
||||||
|
fullPath, err := s.resolveKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(fullPath), 0o755); err != nil {
|
||||||
|
return fmt.Errorf("storage: mkdir: %w", err)
|
||||||
|
}
|
||||||
|
f, err := os.Create(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("storage: create file: %w", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if _, err := io.Copy(f, r); err != nil {
|
||||||
|
return fmt.Errorf("storage: write file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 雛形暫不寫 meta sidecar(Phase 1 按需要實作)。
|
||||||
|
// 若要寫,對齊 storage.md §3:{path}.meta.json
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get 開啟一個 object 讀取 reader 並回傳 metadata。
|
||||||
|
func (s *LocalFSStore) Get(ctx context.Context, key string) (io.ReadCloser, *Object, error) {
|
||||||
|
if key == "" {
|
||||||
|
return nil, nil, ErrInvalidKey
|
||||||
|
}
|
||||||
|
fullPath, err := s.resolveKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
info, err := os.Stat(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return nil, nil, ErrNotFound
|
||||||
|
}
|
||||||
|
return nil, nil, fmt.Errorf("storage: stat: %w", err)
|
||||||
|
}
|
||||||
|
f, err := os.Open(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("storage: open file: %w", err)
|
||||||
|
}
|
||||||
|
obj := &Object{
|
||||||
|
Key: key,
|
||||||
|
Size: info.Size(),
|
||||||
|
ContentType: "application/octet-stream", // 雛形預設;Phase 1 讀 sidecar
|
||||||
|
LastModified: info.ModTime().UTC(),
|
||||||
|
}
|
||||||
|
return f, obj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stat 回傳 metadata,不開啟內容。
|
||||||
|
func (s *LocalFSStore) Stat(ctx context.Context, key string) (*Object, error) {
|
||||||
|
if key == "" {
|
||||||
|
return nil, ErrInvalidKey
|
||||||
|
}
|
||||||
|
fullPath, err := s.resolveKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
info, err := os.Stat(fullPath)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return nil, ErrNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("storage: stat: %w", err)
|
||||||
|
}
|
||||||
|
return &Object{
|
||||||
|
Key: key,
|
||||||
|
Size: info.Size(),
|
||||||
|
ContentType: "application/octet-stream",
|
||||||
|
LastModified: info.ModTime().UTC(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exists 判斷 object 是否存在。
|
||||||
|
//
|
||||||
|
// 語意對齊 storage.md §1 的 ObjectStorage.Exists:
|
||||||
|
// - 檔案存在 → (true, nil)
|
||||||
|
// - 檔案不存在 → (false, nil)(非 error)
|
||||||
|
// - 其他 IO 錯誤 → (false, err)
|
||||||
|
func (s *LocalFSStore) Exists(ctx context.Context, key string) (bool, error) {
|
||||||
|
if key == "" {
|
||||||
|
return false, ErrInvalidKey
|
||||||
|
}
|
||||||
|
fullPath, err := s.resolveKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
_, err = os.Stat(fullPath)
|
||||||
|
if err == nil {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, fmt.Errorf("storage: stat: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete 刪除 object;不存在視為成功(no-op)。
|
||||||
|
func (s *LocalFSStore) Delete(ctx context.Context, key string) error {
|
||||||
|
if key == "" {
|
||||||
|
return ErrInvalidKey
|
||||||
|
}
|
||||||
|
fullPath, err := s.resolveKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := os.Remove(fullPath); err != nil {
|
||||||
|
if errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("storage: remove: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// List 列出指定 prefix 下的所有 object(遞迴)。
|
||||||
|
//
|
||||||
|
// 雛形實作簡化:以 filepath.Walk 掃描 root + prefix 資料夾。
|
||||||
|
// Phase 1 的 S3Store 可直接用原生 ListObjects API。
|
||||||
|
func (s *LocalFSStore) List(ctx context.Context, prefix string) ([]*Object, error) {
|
||||||
|
// prefix 允許為空(列全部)。
|
||||||
|
base, err := s.resolveKey(prefix)
|
||||||
|
if err != nil {
|
||||||
|
// prefix 不合法
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// 如果 prefix 指向的路徑不存在,回空 list(不是錯)。
|
||||||
|
if fi, statErr := os.Stat(base); statErr != nil || !fi.IsDir() {
|
||||||
|
if statErr != nil && errors.Is(statErr, fs.ErrNotExist) {
|
||||||
|
return []*Object{}, nil
|
||||||
|
}
|
||||||
|
// 若 prefix 指向檔案 → 回單筆
|
||||||
|
if statErr == nil && !fi.IsDir() {
|
||||||
|
rel, _ := filepath.Rel(s.root, base)
|
||||||
|
return []*Object{{
|
||||||
|
Key: filepath.ToSlash(rel),
|
||||||
|
Size: fi.Size(),
|
||||||
|
ContentType: "application/octet-stream",
|
||||||
|
LastModified: fi.ModTime().UTC(),
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
out := make([]*Object, 0)
|
||||||
|
err = filepath.Walk(base, func(path string, info os.FileInfo, walkErr error) error {
|
||||||
|
if walkErr != nil {
|
||||||
|
return walkErr
|
||||||
|
}
|
||||||
|
if info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rel, relErr := filepath.Rel(s.root, path)
|
||||||
|
if relErr != nil {
|
||||||
|
return relErr
|
||||||
|
}
|
||||||
|
out = append(out, &Object{
|
||||||
|
Key: filepath.ToSlash(rel),
|
||||||
|
Size: info.Size(),
|
||||||
|
ContentType: "application/octet-stream",
|
||||||
|
LastModified: info.ModTime().UTC(),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||||
|
return nil, fmt.Errorf("storage: walk: %w", err)
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PresignedGetURL 產生一個附 HMAC 簽名的下載 URL(雛形 LocalFS 實作)。
|
||||||
|
//
|
||||||
|
// 格式:{baseURL}/{escaped-key}?expires={unix}&signature={base64url-hmac}
|
||||||
|
// api-server 的 /storage/*filepath handler 負責驗證(見 storage.md §3.1)。
|
||||||
|
func (s *LocalFSStore) PresignedGetURL(ctx context.Context, key string, ttl time.Duration) (string, error) {
|
||||||
|
return s.presignedURL("GET", key, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PresignedPutURL 產生一個附簽名的上傳 URL。
|
||||||
|
func (s *LocalFSStore) PresignedPutURL(ctx context.Context, key string, ttl time.Duration) (string, error) {
|
||||||
|
return s.presignedURL("PUT", key, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *LocalFSStore) presignedURL(method, key string, ttl time.Duration) (string, error) {
|
||||||
|
if key == "" {
|
||||||
|
return "", ErrInvalidKey
|
||||||
|
}
|
||||||
|
if _, err := s.resolveKey(key); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
expiresAt := time.Now().UTC().Add(ttl).Unix()
|
||||||
|
sig := s.signer.Sign(fmt.Sprintf("%s\n%s\n%d", method, key, expiresAt))
|
||||||
|
|
||||||
|
escaped := url.PathEscape(key)
|
||||||
|
u := fmt.Sprintf("%s/%s?expires=%d&signature=%s", s.baseURL, escaped, expiresAt, sig)
|
||||||
|
if method == "PUT" {
|
||||||
|
u += "&mode=put"
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySignature 供 api-server 的 /storage handler 呼叫(LocalFS 專用)。
|
||||||
|
//
|
||||||
|
// 參數:
|
||||||
|
// - method:HTTP method("GET" / "PUT")
|
||||||
|
// - key:storage key(已 urldecode 過)
|
||||||
|
// - expires:URL 裡的 expires 參數
|
||||||
|
// - signature:URL 裡的 signature 參數(base64url)
|
||||||
|
//
|
||||||
|
// 回 nil 表驗證通過;否則回 ErrInvalidSignature(或過期)。
|
||||||
|
func (s *LocalFSStore) VerifySignature(method, key string, expires int64, signature string) error {
|
||||||
|
if time.Now().UTC().Unix() > expires {
|
||||||
|
return ErrInvalidSignature
|
||||||
|
}
|
||||||
|
expected := s.signer.Sign(fmt.Sprintf("%s\n%s\n%d", method, key, expires))
|
||||||
|
if !hmac.Equal([]byte(expected), []byte(signature)) {
|
||||||
|
return ErrInvalidSignature
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Signer — HMAC-SHA256 for LocalFS presigned URL
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Signer 封裝 HMAC-SHA256 簽名流程;輸出為 base64url(無 padding)。
|
||||||
|
type Signer struct {
|
||||||
|
secret []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSigner 建立簽名器;secret 應具備足夠長度(建議 >= 32 bytes)。
|
||||||
|
func NewSigner(secret []byte) *Signer {
|
||||||
|
return &Signer{secret: secret}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign 對 payload 產出 base64url-nopad 的 HMAC-SHA256 簽名。
|
||||||
|
func (s *Signer) Sign(payload string) string {
|
||||||
|
mac := hmac.New(sha256.New, s.secret)
|
||||||
|
mac.Write([]byte(payload))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 編譯時檢查:確保 LocalFSStore 實作 Store。
|
||||||
|
var _ Store = (*LocalFSStore)(nil)
|
||||||
174
visionA-backend/internal/storage/localfs_test.go
Normal file
174
visionA-backend/internal/storage/localfs_test.go
Normal file
@ -0,0 +1,174 @@
|
|||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newTestStore(t *testing.T) *LocalFSStore {
|
||||||
|
t.Helper()
|
||||||
|
s, err := NewLocalFSStore(t.TempDir(), "http://localhost:3001/storage", "test-secret")
|
||||||
|
require.NoError(t, err)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_PutGetStat(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
payload := []byte("hello-visiona")
|
||||||
|
key := "models/user-1/m1.nef"
|
||||||
|
|
||||||
|
require.NoError(t, s.Put(ctx, key, bytes.NewReader(payload), int64(len(payload)), nil))
|
||||||
|
|
||||||
|
// Stat
|
||||||
|
info, err := s.Stat(ctx, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(len(payload)), info.Size)
|
||||||
|
assert.Equal(t, key, info.Key)
|
||||||
|
|
||||||
|
// Get
|
||||||
|
rc, obj, err := s.Get(ctx, key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rc.Close()
|
||||||
|
got, err := io.ReadAll(rc)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, payload, got)
|
||||||
|
assert.Equal(t, int64(len(payload)), obj.Size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_Exists(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
ok, err := s.Exists(ctx, "nope.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, ok, "不存在應回 (false, nil)")
|
||||||
|
|
||||||
|
require.NoError(t, s.Put(ctx, "a.txt", strings.NewReader("x"), 1, nil))
|
||||||
|
|
||||||
|
ok, err = s.Exists(ctx, "a.txt")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, ok)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_Get_NotFound(t *testing.T) {
|
||||||
|
s := newTestStore(t)
|
||||||
|
_, _, err := s.Get(context.Background(), "missing.txt")
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_Stat_NotFound(t *testing.T) {
|
||||||
|
s := newTestStore(t)
|
||||||
|
_, err := s.Stat(context.Background(), "missing.txt")
|
||||||
|
assert.ErrorIs(t, err, ErrNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_Delete(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
require.NoError(t, s.Put(ctx, "tmp.txt", strings.NewReader("x"), 1, nil))
|
||||||
|
require.NoError(t, s.Delete(ctx, "tmp.txt"))
|
||||||
|
|
||||||
|
ok, _ := s.Exists(ctx, "tmp.txt")
|
||||||
|
assert.False(t, ok)
|
||||||
|
|
||||||
|
// 刪除不存在的 key 不應回錯
|
||||||
|
assert.NoError(t, s.Delete(ctx, "never.txt"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_List(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
require.NoError(t, s.Put(ctx, "models/u1/a.nef", strings.NewReader("A"), 1, nil))
|
||||||
|
require.NoError(t, s.Put(ctx, "models/u1/b.nef", strings.NewReader("B"), 1, nil))
|
||||||
|
require.NoError(t, s.Put(ctx, "models/u2/c.nef", strings.NewReader("C"), 1, nil))
|
||||||
|
|
||||||
|
listU1, err := s.List(ctx, "models/u1")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, listU1, 2)
|
||||||
|
|
||||||
|
listAll, err := s.List(ctx, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, listAll, 3)
|
||||||
|
|
||||||
|
listEmpty, err := s.List(ctx, "not-exist-prefix")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, listEmpty)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_PathTraversal_Rejected(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
// 嘗試逃出 root
|
||||||
|
err := s.Put(ctx, "../../etc/passwd", strings.NewReader("pwned"), 5, nil)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidKey)
|
||||||
|
|
||||||
|
_, _, err = s.Get(ctx, "../secret.txt")
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidKey)
|
||||||
|
|
||||||
|
_, err = s.Stat(ctx, "../secret.txt")
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidKey)
|
||||||
|
|
||||||
|
// 空 key
|
||||||
|
err = s.Put(ctx, "", strings.NewReader("x"), 1, nil)
|
||||||
|
assert.ErrorIs(t, err, ErrInvalidKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_PresignedGetURL_AndVerify(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
s := newTestStore(t)
|
||||||
|
|
||||||
|
key := "models/u1/m.nef"
|
||||||
|
require.NoError(t, s.Put(ctx, key, strings.NewReader("X"), 1, nil))
|
||||||
|
|
||||||
|
u, err := s.PresignedGetURL(ctx, key, 5*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, u, "http://localhost:3001/storage/")
|
||||||
|
assert.Contains(t, u, "expires=")
|
||||||
|
assert.Contains(t, u, "signature=")
|
||||||
|
|
||||||
|
// 解析並驗證
|
||||||
|
parsed, err := url.Parse(u)
|
||||||
|
require.NoError(t, err)
|
||||||
|
expires, err := strconv.ParseInt(parsed.Query().Get("expires"), 10, 64)
|
||||||
|
require.NoError(t, err)
|
||||||
|
sig := parsed.Query().Get("signature")
|
||||||
|
|
||||||
|
assert.NoError(t, s.VerifySignature("GET", key, expires, sig))
|
||||||
|
|
||||||
|
// 簽名錯誤
|
||||||
|
assert.ErrorIs(t, s.VerifySignature("GET", key, expires, "tampered"), ErrInvalidSignature)
|
||||||
|
|
||||||
|
// 已過期
|
||||||
|
assert.ErrorIs(t,
|
||||||
|
s.VerifySignature("GET", key, time.Now().Add(-1*time.Hour).Unix(), sig),
|
||||||
|
ErrInvalidSignature)
|
||||||
|
|
||||||
|
// method 不符
|
||||||
|
assert.ErrorIs(t, s.VerifySignature("PUT", key, expires, sig), ErrInvalidSignature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalFSStore_PresignedPutURL(t *testing.T) {
|
||||||
|
s := newTestStore(t)
|
||||||
|
u, err := s.PresignedPutURL(context.Background(), "models/u1/new.nef", 10*time.Minute)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Contains(t, u, "mode=put")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewLocalFSStore_EmptyRoot(t *testing.T) {
|
||||||
|
_, err := NewLocalFSStore("", "", "")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
95
visionA-backend/internal/storage/storage.go
Normal file
95
visionA-backend/internal/storage/storage.go
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
// Package storage 定義物件儲存介面與 LocalFS 實作。
|
||||||
|
//
|
||||||
|
// 對齊 storage.md §1 與 PRD interface-contracts.md §8.4:
|
||||||
|
// - 雛形使用 LocalFSStore(檔案系統)+ 假 presigned URL(HMAC 簽名)
|
||||||
|
// - Phase 1 新增 S3Store(同 interface),業務邏輯不用動
|
||||||
|
//
|
||||||
|
// Key 命名規範見 storage.md §2(例:`models/{user_id}/{model_id}.nef`)。
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Errors
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrNotFound 表示指定 key 的 object 不存在。
|
||||||
|
ErrNotFound = errors.New("storage: object not found")
|
||||||
|
|
||||||
|
// ErrAlreadyExists 表示 object 已存在且 Put 被要求不覆蓋(Phase 1)。
|
||||||
|
ErrAlreadyExists = errors.New("storage: object already exists")
|
||||||
|
|
||||||
|
// ErrInvalidKey 表示 key 含非法字元(例:包含 "..", 嘗試 path traversal)。
|
||||||
|
ErrInvalidKey = errors.New("storage: invalid key")
|
||||||
|
|
||||||
|
// ErrInvalidSignature 表示 presigned URL 簽名錯誤或過期。
|
||||||
|
ErrInvalidSignature = errors.New("storage: invalid or expired signature")
|
||||||
|
)
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Types
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Object 是儲存物件的描述(metadata;不含實際內容)。
|
||||||
|
//
|
||||||
|
// 對齊 storage.md §1 的 ObjectInfo。
|
||||||
|
type Object struct {
|
||||||
|
Key string `json:"key"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
ContentType string `json:"contentType"`
|
||||||
|
LastModified time.Time `json:"lastModified"`
|
||||||
|
ETag string `json:"etag,omitempty"`
|
||||||
|
Metadata map[string]string `json:"metadata,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==========================================================================
|
||||||
|
// Store interface
|
||||||
|
// ==========================================================================
|
||||||
|
|
||||||
|
// Store 是物件儲存的抽象。
|
||||||
|
//
|
||||||
|
// 實作必須:
|
||||||
|
// - 並發安全(多 goroutine 同時 Put / Get 不 panic)
|
||||||
|
// - Key 驗證(防止 path traversal)
|
||||||
|
// - 語意對齊:Exists 回 (false, nil) 表不存在,其他錯誤走 err
|
||||||
|
type Store interface {
|
||||||
|
// Put 上傳一個 object;若已存在則覆蓋。
|
||||||
|
// size 為預期大小(bytes);實作可用於早期檢查或配額控制。
|
||||||
|
// meta 可為 nil(無額外 metadata)。
|
||||||
|
Put(ctx context.Context, key string, r io.Reader, size int64, meta map[string]string) error
|
||||||
|
|
||||||
|
// Get 下載一個 object;caller 必須 Close() reader。
|
||||||
|
// 不存在回 ErrNotFound。
|
||||||
|
Get(ctx context.Context, key string) (io.ReadCloser, *Object, error)
|
||||||
|
|
||||||
|
// Stat 取得 object 的 metadata(不下載內容)。
|
||||||
|
// 不存在回 ErrNotFound。
|
||||||
|
Stat(ctx context.Context, key string) (*Object, error)
|
||||||
|
|
||||||
|
// Exists 判斷 key 是否存在(Minor-4 / PRD §8.4 要求)。
|
||||||
|
// 語意:true = 存在可用;false = 不存在(非 error)。
|
||||||
|
// 其他錯誤(權限 / IO)回 (false, err)。
|
||||||
|
Exists(ctx context.Context, key string) (bool, error)
|
||||||
|
|
||||||
|
// Delete 刪除 object;不存在為 no-op,不回 error。
|
||||||
|
Delete(ctx context.Context, key string) error
|
||||||
|
|
||||||
|
// List 列出指定 prefix 下的所有 object。
|
||||||
|
List(ctx context.Context, prefix string) ([]*Object, error)
|
||||||
|
|
||||||
|
// PresignedGetURL 產生一個限時下載 URL。
|
||||||
|
// 對 LocalFS:回傳 baseURL + /key?expires=&signature=(由 api-server 驗證)
|
||||||
|
// 對 S3:回傳原生 AWS presigned URL
|
||||||
|
PresignedGetURL(ctx context.Context, key string, ttl time.Duration) (string, error)
|
||||||
|
|
||||||
|
// PresignedPutURL 產生一個限時上傳 URL。
|
||||||
|
// 對 LocalFS:回傳 baseURL + /key?expires=&signature=&mode=put
|
||||||
|
// 對 S3:回傳原生 AWS presigned PUT URL
|
||||||
|
PresignedPutURL(ctx context.Context, key string, ttl time.Duration) (string, error)
|
||||||
|
}
|
||||||
175
visionA-backend/internal/usersession/cookie.go
Normal file
175
visionA-backend/internal/usersession/cookie.go
Normal file
@ -0,0 +1,175 @@
|
|||||||
|
package usersession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CookieConfig 集中描述 cookie 的所有屬性,避免 helper 函式裡塞太多參數。
|
||||||
|
//
|
||||||
|
// 對齊 oidc-tdd.md §5.1:
|
||||||
|
//
|
||||||
|
// Name = "visiona_session"
|
||||||
|
// Domain = ".visiona.cloud"(prod)/ ""(dev)
|
||||||
|
// Path = "/"
|
||||||
|
// Secure = true(prod HTTPS)/ false(dev HTTP)
|
||||||
|
// HTTPOnly = true(永遠)
|
||||||
|
// SameSite = http.SameSiteLaxMode
|
||||||
|
// MaxAge = 86400(雛形 24h;TDD §5.1 是 7d,雛形先以任務指定值為準)
|
||||||
|
// SigningKey = ≥ 32 bytes 隨機(HMAC-SHA256)
|
||||||
|
type CookieConfig struct {
|
||||||
|
Name string // cookie 名稱;空字串會 fallback 到 DefaultCookieName
|
||||||
|
Domain string // production 設定(如 ".visiona.cloud"),dev 留空
|
||||||
|
Path string // cookie 範圍;空字串會 fallback 到 "/"
|
||||||
|
Secure bool // 是否要求 HTTPS(dev=false, prod=true)
|
||||||
|
HTTPOnly bool // 是否禁止 JS 讀取(永遠應為 true)
|
||||||
|
SameSite http.SameSite // 預設 http.SameSiteLaxMode
|
||||||
|
MaxAge int // cookie 存活秒數;0 = session cookie;負值 = 立即刪除
|
||||||
|
|
||||||
|
// SigningKey 是 HMAC-SHA256 的金鑰,**必填**,至少 32 bytes 才安全(caller 自行確認)。
|
||||||
|
// 預設應由 env var VISIONA_SESSION_SECRET 注入,在 process startup 階段檢查長度。
|
||||||
|
SigningKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultCookieName 與 oidc-tdd.md §5.1 對齊。
|
||||||
|
const DefaultCookieName = "visiona_session"
|
||||||
|
|
||||||
|
// validate 檢查 CookieConfig 必填欄位。
|
||||||
|
//
|
||||||
|
// 不檢查 SigningKey 長度(由 caller 在 startup 階段確保 ≥ 32 bytes,
|
||||||
|
// 此處不重複檢查避免每次 read/write 都做一次)。
|
||||||
|
func (c CookieConfig) validate() error {
|
||||||
|
if len(c.SigningKey) == 0 {
|
||||||
|
return ErrInvalidConfig
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvedName 回傳 c.Name 或 DefaultCookieName。
|
||||||
|
func (c CookieConfig) resolvedName() string {
|
||||||
|
if c.Name == "" {
|
||||||
|
return DefaultCookieName
|
||||||
|
}
|
||||||
|
return c.Name
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvedPath 回傳 c.Path 或 "/"。
|
||||||
|
func (c CookieConfig) resolvedPath() string {
|
||||||
|
if c.Path == "" {
|
||||||
|
return "/"
|
||||||
|
}
|
||||||
|
return c.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolvedSameSite 回傳 c.SameSite 或預設 Lax。
|
||||||
|
func (c CookieConfig) resolvedSameSite() http.SameSite {
|
||||||
|
if c.SameSite == 0 {
|
||||||
|
return http.SameSiteLaxMode
|
||||||
|
}
|
||||||
|
return c.SameSite
|
||||||
|
}
|
||||||
|
|
||||||
|
// signSessionID 用 HMAC-SHA256 產生簽章,回傳 base64url 編碼。
|
||||||
|
func signSessionID(sessionID string, key []byte) string {
|
||||||
|
h := hmac.New(sha256.New, key)
|
||||||
|
h.Write([]byte(sessionID))
|
||||||
|
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeCookieValue 將 sessionID 與 HMAC 簽章組成 cookie value。
|
||||||
|
//
|
||||||
|
// Format:<sessionID>.<base64url(HMAC-SHA256(SigningKey, sessionID))>
|
||||||
|
//
|
||||||
|
// sessionID 必須是 base64url(NewInMemoryStore.Create 產生的格式),不可含 "."(會撞 separator)。
|
||||||
|
func EncodeCookieValue(sessionID string, key []byte) string {
|
||||||
|
return sessionID + "." + signSessionID(sessionID, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeCookieValue 解析 cookie value,驗 HMAC 後回傳 sessionID。
|
||||||
|
//
|
||||||
|
// 任何 parse / sig 失敗都統一回 ErrInvalidCookie 或 ErrSignatureMismatch,
|
||||||
|
// 避免攻擊者從錯誤訊息推斷 SigningKey 結構。
|
||||||
|
func DecodeCookieValue(value string, key []byte) (string, error) {
|
||||||
|
if value == "" {
|
||||||
|
return "", ErrInvalidCookie
|
||||||
|
}
|
||||||
|
// SplitN(2):sessionID 內絕對無 "."(base64url 字元集為 A-Z a-z 0-9 - _),所以唯一的 "." 就是 separator。
|
||||||
|
parts := strings.SplitN(value, ".", 2)
|
||||||
|
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||||
|
return "", ErrInvalidCookie
|
||||||
|
}
|
||||||
|
sessionID, providedSig := parts[0], parts[1]
|
||||||
|
|
||||||
|
expectedSig := signSessionID(sessionID, key)
|
||||||
|
// 用常數時間比較避免 timing attack。
|
||||||
|
if subtle.ConstantTimeCompare([]byte(providedSig), []byte(expectedSig)) != 1 {
|
||||||
|
return "", ErrSignatureMismatch
|
||||||
|
}
|
||||||
|
return sessionID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteCookie 將 sessionID 簽章後寫入 Set-Cookie header。
|
||||||
|
//
|
||||||
|
// 失敗(cfg.SigningKey 缺)回 ErrInvalidConfig。
|
||||||
|
// 為了讓 caller 在 handler 中乾淨呼叫,不對 w 做任何 status / body 操作。
|
||||||
|
func WriteCookie(w http.ResponseWriter, cfg CookieConfig, sessionID string) error {
|
||||||
|
if err := cfg.validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if sessionID == "" {
|
||||||
|
return ErrInvalidCookie
|
||||||
|
}
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: cfg.resolvedName(),
|
||||||
|
Value: EncodeCookieValue(sessionID, cfg.SigningKey),
|
||||||
|
Path: cfg.resolvedPath(),
|
||||||
|
Domain: cfg.Domain,
|
||||||
|
MaxAge: cfg.MaxAge,
|
||||||
|
Secure: cfg.Secure,
|
||||||
|
HttpOnly: cfg.HTTPOnly,
|
||||||
|
SameSite: cfg.resolvedSameSite(),
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadCookie 從 request 取出 cookie,驗簽,回傳 sessionID。
|
||||||
|
//
|
||||||
|
// 找不到 cookie 回 (空, false),**不**當成 error。
|
||||||
|
// cookie 存在但 parse / sig 失敗回 (空, false)(同樣不揭露細節給 caller,避免被當成 oracle)。
|
||||||
|
//
|
||||||
|
// 內部錯誤(cfg.SigningKey 缺)回 (空, false),但 caller 應在 startup 時就避免。
|
||||||
|
func ReadCookie(r *http.Request, cfg CookieConfig) (sessionID string, ok bool) {
|
||||||
|
if cfg.validate() != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
c, err := r.Cookie(cfg.resolvedName())
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
sid, err := DecodeCookieValue(c.Value, cfg.SigningKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return sid, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearCookie 寫一個過期的同名 cookie,瀏覽器會刪除它。
|
||||||
|
//
|
||||||
|
// 必須使用與設定時「相同的 Name / Path / Domain」否則瀏覽器不會刪到正確的那一份
|
||||||
|
// (RFC 6265)。
|
||||||
|
func ClearCookie(w http.ResponseWriter, cfg CookieConfig) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: cfg.resolvedName(),
|
||||||
|
Value: "",
|
||||||
|
Path: cfg.resolvedPath(),
|
||||||
|
Domain: cfg.Domain,
|
||||||
|
MaxAge: -1, // 過期,立即刪除
|
||||||
|
Secure: cfg.Secure,
|
||||||
|
HttpOnly: cfg.HTTPOnly,
|
||||||
|
SameSite: cfg.resolvedSameSite(),
|
||||||
|
})
|
||||||
|
}
|
||||||
250
visionA-backend/internal/usersession/cookie_test.go
Normal file
250
visionA-backend/internal/usersession/cookie_test.go
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
package usersession
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testKey 是測試用的 32 byte HMAC key。
|
||||||
|
var testKey = []byte("test-key-test-key-test-key-1234!") // 32 bytes
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────
|
||||||
|
// Encode / Decode roundtrip
|
||||||
|
// ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestEncodeDecode_Roundtrip(t *testing.T) {
|
||||||
|
sid := "abc123XYZ_-test"
|
||||||
|
encoded := EncodeCookieValue(sid, testKey)
|
||||||
|
if !strings.Contains(encoded, ".") {
|
||||||
|
t.Fatalf("encoded should contain separator '.', got %q", encoded)
|
||||||
|
}
|
||||||
|
got, err := DecodeCookieValue(encoded, testKey)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Decode: %v", err)
|
||||||
|
}
|
||||||
|
if got != sid {
|
||||||
|
t.Fatalf("roundtrip mismatch: got %q want %q", got, sid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecode_TamperedSessionID(t *testing.T) {
|
||||||
|
encoded := EncodeCookieValue("realsid", testKey)
|
||||||
|
parts := strings.SplitN(encoded, ".", 2)
|
||||||
|
tampered := "fakesid." + parts[1]
|
||||||
|
_, err := DecodeCookieValue(tampered, testKey)
|
||||||
|
if !errors.Is(err, ErrSignatureMismatch) {
|
||||||
|
t.Fatalf("expected ErrSignatureMismatch when sessionID tampered, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecode_TamperedSignature(t *testing.T) {
|
||||||
|
encoded := EncodeCookieValue("realsid", testKey)
|
||||||
|
parts := strings.SplitN(encoded, ".", 2)
|
||||||
|
// 換個簽章(不同長度的 base64url 也會失敗)
|
||||||
|
tampered := parts[0] + ".AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA"
|
||||||
|
_, err := DecodeCookieValue(tampered, testKey)
|
||||||
|
if !errors.Is(err, ErrSignatureMismatch) {
|
||||||
|
t.Fatalf("expected ErrSignatureMismatch, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecode_DifferentKeyFails(t *testing.T) {
|
||||||
|
encoded := EncodeCookieValue("sid", testKey)
|
||||||
|
other := []byte("other-key-other-key-other-key-1!")
|
||||||
|
_, err := DecodeCookieValue(encoded, other)
|
||||||
|
if !errors.Is(err, ErrSignatureMismatch) {
|
||||||
|
t.Fatalf("expected ErrSignatureMismatch with different key, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDecode_MalformedValues(t *testing.T) {
|
||||||
|
cases := map[string]string{
|
||||||
|
"empty": "",
|
||||||
|
"no separator": "noseparator",
|
||||||
|
"only sep": ".",
|
||||||
|
"empty sid": ".sigonly",
|
||||||
|
"empty sig": "sidonly.",
|
||||||
|
}
|
||||||
|
for name, val := range cases {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
_, err := DecodeCookieValue(val, testKey)
|
||||||
|
if !errors.Is(err, ErrInvalidCookie) {
|
||||||
|
t.Fatalf("%s: expected ErrInvalidCookie, got %v", name, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────
|
||||||
|
// WriteCookie / ReadCookie / ClearCookie via httptest
|
||||||
|
// ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func newCookieCfg() CookieConfig {
|
||||||
|
return CookieConfig{
|
||||||
|
Name: DefaultCookieName,
|
||||||
|
Path: "/",
|
||||||
|
HTTPOnly: true,
|
||||||
|
Secure: false,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
MaxAge: 86400,
|
||||||
|
SigningKey: testKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteAndReadCookie_Roundtrip(t *testing.T) {
|
||||||
|
cfg := newCookieCfg()
|
||||||
|
sid := "session-abc-123"
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
if err := WriteCookie(w, cfg, sid); err != nil {
|
||||||
|
t.Fatalf("WriteCookie: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 從 response 取出 Set-Cookie,做成 request cookie 模擬 browser 回傳
|
||||||
|
resp := w.Result()
|
||||||
|
cookies := resp.Cookies()
|
||||||
|
if len(cookies) != 1 {
|
||||||
|
t.Fatalf("expected 1 Set-Cookie, got %d", len(cookies))
|
||||||
|
}
|
||||||
|
c := cookies[0]
|
||||||
|
if c.Name != DefaultCookieName {
|
||||||
|
t.Fatalf("cookie name: got %q want %q", c.Name, DefaultCookieName)
|
||||||
|
}
|
||||||
|
if !c.HttpOnly {
|
||||||
|
t.Fatalf("HttpOnly should be true")
|
||||||
|
}
|
||||||
|
if c.SameSite != http.SameSiteLaxMode {
|
||||||
|
t.Fatalf("SameSite should be Lax")
|
||||||
|
}
|
||||||
|
if c.MaxAge != 86400 {
|
||||||
|
t.Fatalf("MaxAge: got %d want 86400", c.MaxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.AddCookie(c)
|
||||||
|
got, ok := ReadCookie(r, cfg)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("ReadCookie: ok=false")
|
||||||
|
}
|
||||||
|
if got != sid {
|
||||||
|
t.Fatalf("ReadCookie: got %q want %q", got, sid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadCookie_NoCookie(t *testing.T) {
|
||||||
|
cfg := newCookieCfg()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
if _, ok := ReadCookie(r, cfg); ok {
|
||||||
|
t.Fatalf("ReadCookie should return ok=false when no cookie present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadCookie_TamperedValue(t *testing.T) {
|
||||||
|
cfg := newCookieCfg()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
// 故意放一個簽章錯的 cookie
|
||||||
|
r.AddCookie(&http.Cookie{Name: cfg.Name, Value: "tampered.sig"})
|
||||||
|
if _, ok := ReadCookie(r, cfg); ok {
|
||||||
|
t.Fatalf("ReadCookie should return ok=false for tampered cookie")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearCookie_SetsExpiration(t *testing.T) {
|
||||||
|
cfg := newCookieCfg()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
ClearCookie(w, cfg)
|
||||||
|
cookies := w.Result().Cookies()
|
||||||
|
if len(cookies) != 1 {
|
||||||
|
t.Fatalf("expected 1 Set-Cookie, got %d", len(cookies))
|
||||||
|
}
|
||||||
|
c := cookies[0]
|
||||||
|
if c.MaxAge >= 0 {
|
||||||
|
t.Fatalf("ClearCookie MaxAge should be < 0, got %d", c.MaxAge)
|
||||||
|
}
|
||||||
|
if c.Value != "" {
|
||||||
|
t.Fatalf("ClearCookie value should be empty, got %q", c.Value)
|
||||||
|
}
|
||||||
|
if c.Name != cfg.Name || c.Path != cfg.Path {
|
||||||
|
t.Fatalf("ClearCookie must use same Name/Path; got name=%q path=%q", c.Name, c.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearCookie_BrowserCannotRead(t *testing.T) {
|
||||||
|
cfg := newCookieCfg()
|
||||||
|
|
||||||
|
// 1. WriteCookie → 得到一個 cookie
|
||||||
|
w1 := httptest.NewRecorder()
|
||||||
|
if err := WriteCookie(w1, cfg, "sid-X"); err != nil {
|
||||||
|
t.Fatalf("WriteCookie: %v", err)
|
||||||
|
}
|
||||||
|
original := w1.Result().Cookies()[0]
|
||||||
|
|
||||||
|
// 2. ClearCookie → 得到一個 expiration cookie
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
ClearCookie(w2, cfg)
|
||||||
|
expirationCookie := w2.Result().Cookies()[0]
|
||||||
|
|
||||||
|
// 3. 模擬 browser 在收到 expiration cookie 後立刻發 request — 此時應該沒有 cookie
|
||||||
|
// (這裡無法直接模擬 browser 的 cookie jar 邏輯,但能驗證 expiration cookie 內容是空的、
|
||||||
|
// 若 browser 真的把它存下來,後續 ReadCookie 會失敗)。
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.AddCookie(&http.Cookie{Name: expirationCookie.Name, Value: expirationCookie.Value})
|
||||||
|
if _, ok := ReadCookie(r, cfg); ok {
|
||||||
|
t.Fatalf("after ClearCookie, ReadCookie of cleared value must fail")
|
||||||
|
}
|
||||||
|
// sanity check:原本的 cookie 仍然能讀(驗證 ReadCookie 本身沒壞)
|
||||||
|
r2 := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r2.AddCookie(original)
|
||||||
|
if _, ok := ReadCookie(r2, cfg); !ok {
|
||||||
|
t.Fatalf("baseline: original cookie should still read OK")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────
|
||||||
|
// CookieConfig validation
|
||||||
|
// ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
func TestWriteCookie_MissingSigningKey(t *testing.T) {
|
||||||
|
cfg := CookieConfig{} // SigningKey 為 nil
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
err := WriteCookie(w, cfg, "sid")
|
||||||
|
if !errors.Is(err, ErrInvalidConfig) {
|
||||||
|
t.Fatalf("expected ErrInvalidConfig, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteCookie_EmptySessionID(t *testing.T) {
|
||||||
|
cfg := newCookieCfg()
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
err := WriteCookie(w, cfg, "")
|
||||||
|
if !errors.Is(err, ErrInvalidCookie) {
|
||||||
|
t.Fatalf("expected ErrInvalidCookie, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReadCookie_MissingSigningKey(t *testing.T) {
|
||||||
|
cfg := CookieConfig{} // SigningKey 為 nil
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
r.AddCookie(&http.Cookie{Name: DefaultCookieName, Value: "anything.thing"})
|
||||||
|
if _, ok := ReadCookie(r, cfg); ok {
|
||||||
|
t.Fatalf("ReadCookie should fail when SigningKey missing")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCookieConfig_Defaults(t *testing.T) {
|
||||||
|
cfg := CookieConfig{SigningKey: testKey} // 其他欄位都用預設
|
||||||
|
|
||||||
|
if cfg.resolvedName() != DefaultCookieName {
|
||||||
|
t.Fatalf("resolvedName default mismatch")
|
||||||
|
}
|
||||||
|
if cfg.resolvedPath() != "/" {
|
||||||
|
t.Fatalf("resolvedPath default should be '/'")
|
||||||
|
}
|
||||||
|
if cfg.resolvedSameSite() != http.SameSiteLaxMode {
|
||||||
|
t.Fatalf("resolvedSameSite default should be Lax")
|
||||||
|
}
|
||||||
|
}
|
||||||
27
visionA-backend/internal/usersession/errors.go
Normal file
27
visionA-backend/internal/usersession/errors.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package usersession
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
// 公開 sentinel errors,便於 caller 用 errors.Is 比對。
|
||||||
|
var (
|
||||||
|
// ErrNoSession 表示指定 ID 對應的 session 不存在(從未 Create、已 Delete、或已被 CleanupExpired 清掉)。
|
||||||
|
ErrNoSession = errors.New("usersession: not found")
|
||||||
|
|
||||||
|
// ErrSessionExpired 表示 session 雖存在但已逾時(idle 或 absolute)。
|
||||||
|
// 大多情境下 CleanupExpired 會先行移除,Get 直接回 ErrNoSession;
|
||||||
|
// 少數時序下 Manager.GetSession 可能在比對 timeout 後主動回此 error,便於 caller 區分。
|
||||||
|
ErrSessionExpired = errors.New("usersession: expired")
|
||||||
|
|
||||||
|
// ErrInvalidCookie 表示 cookie value 格式錯誤(缺 separator、欄位空、編碼失敗)。
|
||||||
|
ErrInvalidCookie = errors.New("usersession: invalid cookie")
|
||||||
|
|
||||||
|
// ErrSignatureMismatch 表示 cookie HMAC 簽章驗證失敗,可能被竄改或使用錯的 SigningKey。
|
||||||
|
ErrSignatureMismatch = errors.New("usersession: signature mismatch")
|
||||||
|
|
||||||
|
// ErrInvalidConfig 表示 CookieConfig 必填欄位缺漏(例如 SigningKey 為空)。
|
||||||
|
ErrInvalidConfig = errors.New("usersession: invalid config")
|
||||||
|
|
||||||
|
// ErrSigningKeyTooShort 表示 SigningKey 長度不足 MinSigningKeyBytes(32 bytes)。
|
||||||
|
// HMAC-SHA256 安全建議 key 長度至少等於 hash output(32 bytes / 256 bits)。
|
||||||
|
ErrSigningKeyTooShort = errors.New("usersession: signing key must be at least 32 bytes")
|
||||||
|
)
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user