對齊 ADR-017 v1.2:模型庫下載走 visionA 簽 MC delegated token → Client 直連 FAA。
B2 — MC download token client(internal/fileaccess):
- DownloadTokenIssuer: GetServiceToken(打 MC /oauth/token,client_credentials +
scope files:download.delegate,含 token cache)+ IssueDownloadToken(打 MC Issue 簽 fdt_)
- secret / service token / fdt token 三層全程用 hashShort 遮罩不 log
- FileAccessConfig + VISIONA_FILE_ACCESS_* env + main.go wire(Enabled() 才接)
B1 — object_key 斷層:
- model.Model 加 FAAObjectKey(json:"-" 不揭露前端)
- PromoteToModels 寫入(用 promote response TargetObjectKey = models/{userID}/{jobID}.nef)
- 三方對映天然一致(visionA Issue / FAA path / MC validate)
- 第一階段框死只 Source=converted 類 model,上傳類 download 回 501
download endpoint:
- GET /api/models/:id/download(owner-only)→ {download_url, token, expires_at}
- 前端帶 Authorization: Bearer 直連 FAA(不經 visionA、不經 AWS)
- 401/403/404/501/502 分明,502 對外 mask 不洩漏 MC 內部狀態
測試: 13 + 8 unit test(mock MC + fake issuer,httptest 驗真 HTTP);go build/vet/test 全綠。
Reviewer: 0 Critical / 0 Major / 3 Minor / 4 Suggestion,通過。
技術債(正式上線前): 第一階段 PoC 共用 FAA service client,MC 規範禁止 client 混用
usage、secret 不共用,須 MC 配發 visionA 專屬 usage=file_api client。
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
361 lines
12 KiB
Go
361 lines
12 KiB
Go
package fileaccess
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"io"
|
||
"net/http"
|
||
"net/http/httptest"
|
||
"net/url"
|
||
"strings"
|
||
"sync/atomic"
|
||
"testing"
|
||
"time"
|
||
)
|
||
|
||
// ==========================================================================
// Test helpers
// ==========================================================================
|
||
|
||
// newTestClient 用 httptest server base URL 建一個 *client(拿具體型別,方便測 cache)。
|
||
func newTestClient(t *testing.T, baseURL string, ttlSec int) *client {
|
||
t.Helper()
|
||
iss, err := NewClient(Opts{
|
||
MCBaseURL: baseURL,
|
||
ServiceClientID: "test-client-id",
|
||
ServiceClientSecret: "test-secret",
|
||
TenantID: "test-tenant",
|
||
DownloadTokenTTLSeconds: ttlSec,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewClient: %v", err)
|
||
}
|
||
c, ok := iss.(*client)
|
||
if !ok {
|
||
t.Fatalf("NewClient did not return *client, got %T", iss)
|
||
}
|
||
return c
|
||
}
|
||
|
||
// ==========================================================================
// NewClient
// ==========================================================================
|
||
|
||
func TestNewClient_returnsErrWhenConfigIncomplete(t *testing.T) {
|
||
cases := map[string]Opts{
|
||
"missing MCBaseURL": {ServiceClientID: "c", ServiceClientSecret: "s", TenantID: "t"},
|
||
"missing clientID": {MCBaseURL: "http://mc", ServiceClientSecret: "s", TenantID: "t"},
|
||
"missing clientSecret": {MCBaseURL: "http://mc", ServiceClientID: "c", TenantID: "t"},
|
||
"missing tenant": {MCBaseURL: "http://mc", ServiceClientID: "c", ServiceClientSecret: "s"},
|
||
}
|
||
for name, opts := range cases {
|
||
t.Run(name, func(t *testing.T) {
|
||
_, err := NewClient(opts)
|
||
if !errors.Is(err, ErrConfigIncomplete) {
|
||
t.Fatalf("want ErrConfigIncomplete, got %v", err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestNewClient_defaultsTTLWhenNonPositive(t *testing.T) {
|
||
iss, err := NewClient(Opts{
|
||
MCBaseURL: "http://mc", ServiceClientID: "c", ServiceClientSecret: "s", TenantID: "t",
|
||
DownloadTokenTTLSeconds: 0,
|
||
})
|
||
if err != nil {
|
||
t.Fatalf("NewClient: %v", err)
|
||
}
|
||
if got := iss.(*client).ttlSeconds; got != defaultDownloadTokenTTLSeconds {
|
||
t.Fatalf("ttl default: want %d, got %d", defaultDownloadTokenTTLSeconds, got)
|
||
}
|
||
}
|
||
|
||
// ==========================================================================
// GetServiceToken
// ==========================================================================
|
||
|
||
func TestGetServiceToken_successAndSendsClientCredentials(t *testing.T) {
|
||
var gotGrant, gotScope, gotClientID, gotSecret string
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
if r.URL.Path != oauthTokenPath {
|
||
t.Errorf("unexpected path %s", r.URL.Path)
|
||
}
|
||
if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") {
|
||
t.Errorf("want form content-type, got %s", ct)
|
||
}
|
||
body, _ := io.ReadAll(r.Body)
|
||
form, _ := url.ParseQuery(string(body))
|
||
gotGrant = form.Get("grant_type")
|
||
gotScope = form.Get("scope")
|
||
gotClientID = form.Get("client_id")
|
||
gotSecret = form.Get("client_secret")
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{
|
||
AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 3600,
|
||
Scope: downloadDelegateScope,
|
||
})
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
tok, err := c.GetServiceToken(context.Background())
|
||
if err != nil {
|
||
t.Fatalf("GetServiceToken: %v", err)
|
||
}
|
||
if tok != "svc-access-token" {
|
||
t.Fatalf("token: want svc-access-token, got %s", tok)
|
||
}
|
||
if gotGrant != "client_credentials" {
|
||
t.Errorf("grant_type: want client_credentials, got %s", gotGrant)
|
||
}
|
||
if gotScope != downloadDelegateScope {
|
||
t.Errorf("scope: want %s, got %s", downloadDelegateScope, gotScope)
|
||
}
|
||
if gotClientID != "test-client-id" || gotSecret != "test-secret" {
|
||
t.Errorf("client creds not sent: id=%s secret=%s", gotClientID, gotSecret)
|
||
}
|
||
}
|
||
|
||
func TestGetServiceToken_cachesTokenAcrossCalls(t *testing.T) {
|
||
var hits int32
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
atomic.AddInt32(&hits, 1)
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{
|
||
AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 3600,
|
||
})
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
for i := 0; i < 3; i++ {
|
||
if _, err := c.GetServiceToken(context.Background()); err != nil {
|
||
t.Fatalf("call %d: %v", i, err)
|
||
}
|
||
}
|
||
if got := atomic.LoadInt32(&hits); got != 1 {
|
||
t.Fatalf("want 1 oauth hit (cached), got %d", got)
|
||
}
|
||
}
|
||
|
||
func TestGetServiceToken_refetchesWhenExpired(t *testing.T) {
|
||
var hits int32
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
atomic.AddInt32(&hits, 1)
|
||
// expires_in 短到一定要重拿
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{
|
||
AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 10,
|
||
})
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
// 用可控時鐘:第一次 now=t0;第二次 now=t0+1h(必然過期)
|
||
base := time.Date(2026, 6, 7, 0, 0, 0, 0, time.UTC)
|
||
var cur time.Time = base
|
||
c.now = func() time.Time { return cur }
|
||
|
||
if _, err := c.GetServiceToken(context.Background()); err != nil {
|
||
t.Fatalf("first call: %v", err)
|
||
}
|
||
cur = base.Add(time.Hour)
|
||
if _, err := c.GetServiceToken(context.Background()); err != nil {
|
||
t.Fatalf("second call: %v", err)
|
||
}
|
||
if got := atomic.LoadInt32(&hits); got != 2 {
|
||
t.Fatalf("want 2 oauth hits (refetch after expiry), got %d", got)
|
||
}
|
||
}
|
||
|
||
func TestGetServiceToken_failsOnNon200(t *testing.T) {
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
w.WriteHeader(http.StatusUnauthorized)
|
||
_, _ = w.Write([]byte(`{"error":"invalid_client"}`))
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
_, err := c.GetServiceToken(context.Background())
|
||
if !errors.Is(err, ErrServiceTokenFailed) {
|
||
t.Fatalf("want ErrServiceTokenFailed, got %v", err)
|
||
}
|
||
}
|
||
|
||
func TestGetServiceToken_failsOnMissingAccessToken(t *testing.T) {
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{TokenType: "Bearer", ExpiresIn: 3600})
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
_, err := c.GetServiceToken(context.Background())
|
||
if !errors.Is(err, ErrServiceTokenFailed) {
|
||
t.Fatalf("want ErrServiceTokenFailed, got %v", err)
|
||
}
|
||
}
|
||
|
||
// ==========================================================================
// IssueDownloadToken
// ==========================================================================
|
||
|
||
func TestIssueDownloadToken_successFullChain(t *testing.T) {
|
||
var issueAuth string
|
||
var issueBody issueRequest
|
||
expiresAt := time.Date(2026, 6, 7, 12, 2, 0, 0, time.UTC)
|
||
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
switch r.URL.Path {
|
||
case oauthTokenPath:
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{
|
||
AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 3600,
|
||
})
|
||
case issuePath:
|
||
issueAuth = r.Header.Get("Authorization")
|
||
_ = json.NewDecoder(r.Body).Decode(&issueBody)
|
||
writeJSON(w, http.StatusOK, issueResponse{
|
||
Token: "fdt_abc123", TokenType: "file_download",
|
||
ExpiresAt: expiresAt.Format(time.RFC3339), Scope: "files:download.read",
|
||
})
|
||
default:
|
||
t.Errorf("unexpected path %s", r.URL.Path)
|
||
}
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
res, err := c.IssueDownloadToken(context.Background(), "user-oidc-sub", "models/u/j.nef")
|
||
if err != nil {
|
||
t.Fatalf("IssueDownloadToken: %v", err)
|
||
}
|
||
|
||
if res.Token != "fdt_abc123" {
|
||
t.Errorf("token: want fdt_abc123, got %s", res.Token)
|
||
}
|
||
if res.TokenType != "file_download" {
|
||
t.Errorf("token_type: want file_download, got %s", res.TokenType)
|
||
}
|
||
if res.Scope != "files:download.read" {
|
||
t.Errorf("scope: want files:download.read, got %s", res.Scope)
|
||
}
|
||
if !res.ExpiresAt.Equal(expiresAt) {
|
||
t.Errorf("expires_at: want %v, got %v", expiresAt, res.ExpiresAt)
|
||
}
|
||
// Issue 必須帶 service token 當 Bearer
|
||
if issueAuth != "Bearer svc-access-token" {
|
||
t.Errorf("issue auth header: want 'Bearer svc-access-token', got %q", issueAuth)
|
||
}
|
||
// Issue body 契約(ADR-017 §10)
|
||
if issueBody.TenantID != "test-tenant" {
|
||
t.Errorf("tenant_id: want test-tenant, got %s", issueBody.TenantID)
|
||
}
|
||
if issueBody.UserID != "user-oidc-sub" {
|
||
t.Errorf("user_id: want user-oidc-sub, got %s", issueBody.UserID)
|
||
}
|
||
if issueBody.ObjectKey != "models/u/j.nef" {
|
||
t.Errorf("object_key: want models/u/j.nef, got %s", issueBody.ObjectKey)
|
||
}
|
||
if issueBody.Method != "GET" {
|
||
t.Errorf("method: want GET, got %s", issueBody.Method)
|
||
}
|
||
if issueBody.ExpiresInSeconds != 120 {
|
||
t.Errorf("expires_in_seconds: want 120, got %d", issueBody.ExpiresInSeconds)
|
||
}
|
||
}
|
||
|
||
func TestIssueDownloadToken_validatesArgs(t *testing.T) {
|
||
c := newTestClient(t, "http://unused", 120)
|
||
if _, err := c.IssueDownloadToken(context.Background(), "", "key"); !errors.Is(err, ErrIssueTokenFailed) {
|
||
t.Errorf("empty userID: want ErrIssueTokenFailed, got %v", err)
|
||
}
|
||
if _, err := c.IssueDownloadToken(context.Background(), "user", ""); !errors.Is(err, ErrIssueTokenFailed) {
|
||
t.Errorf("empty objectKey: want ErrIssueTokenFailed, got %v", err)
|
||
}
|
||
}
|
||
|
||
func TestIssueDownloadToken_failsWhenServiceTokenFails(t *testing.T) {
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
// oauth 永遠 500
|
||
w.WriteHeader(http.StatusInternalServerError)
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
_, err := c.IssueDownloadToken(context.Background(), "user", "key")
|
||
if !errors.Is(err, ErrServiceTokenFailed) {
|
||
t.Fatalf("want ErrServiceTokenFailed (propagated), got %v", err)
|
||
}
|
||
}
|
||
|
||
func TestIssueDownloadToken_failsOnIssueNon200(t *testing.T) {
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
switch r.URL.Path {
|
||
case oauthTokenPath:
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{AccessToken: "tok", ExpiresIn: 3600})
|
||
case issuePath:
|
||
w.WriteHeader(http.StatusForbidden)
|
||
_, _ = w.Write([]byte(`{"error":"forbidden"}`))
|
||
}
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
_, err := c.IssueDownloadToken(context.Background(), "user", "key")
|
||
if !errors.Is(err, ErrIssueTokenFailed) {
|
||
t.Fatalf("want ErrIssueTokenFailed, got %v", err)
|
||
}
|
||
}
|
||
|
||
func TestIssueDownloadToken_failsOnMissingToken(t *testing.T) {
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
switch r.URL.Path {
|
||
case oauthTokenPath:
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{AccessToken: "tok", ExpiresIn: 3600})
|
||
case issuePath:
|
||
writeJSON(w, http.StatusOK, issueResponse{TokenType: "file_download"}) // 無 token
|
||
}
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
_, err := c.IssueDownloadToken(context.Background(), "user", "key")
|
||
if !errors.Is(err, ErrIssueTokenFailed) {
|
||
t.Fatalf("want ErrIssueTokenFailed, got %v", err)
|
||
}
|
||
}
|
||
|
||
func TestIssueDownloadToken_toleratesBadExpiresAt(t *testing.T) {
|
||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||
switch r.URL.Path {
|
||
case oauthTokenPath:
|
||
writeJSON(w, http.StatusOK, oauthTokenResponse{AccessToken: "tok", ExpiresIn: 3600})
|
||
case issuePath:
|
||
writeJSON(w, http.StatusOK, issueResponse{
|
||
Token: "fdt_x", TokenType: "file_download", ExpiresAt: "not-a-date",
|
||
})
|
||
}
|
||
}))
|
||
defer srv.Close()
|
||
|
||
c := newTestClient(t, srv.URL, 120)
|
||
res, err := c.IssueDownloadToken(context.Background(), "user", "key")
|
||
if err != nil {
|
||
t.Fatalf("should tolerate bad expires_at, got err %v", err)
|
||
}
|
||
if res.Token != "fdt_x" {
|
||
t.Errorf("token: want fdt_x, got %s", res.Token)
|
||
}
|
||
if !res.ExpiresAt.IsZero() {
|
||
t.Errorf("bad expires_at should leave zero time, got %v", res.ExpiresAt)
|
||
}
|
||
}
|
||
|
||
// ==========================================================================
// Helpers (test-local)
// ==========================================================================
|
||
|
||
// writeJSON sets the Content-Type header to application/json, writes the given
// status code, and then JSON-encodes v as the response body. The encoder's
// error is deliberately discarded: this helper has no *testing.T to report it
// to.
func writeJSON(w http.ResponseWriter, status int, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
|