package fileaccess import ( "context" "encoding/json" "errors" "io" "net/http" "net/http/httptest" "net/url" "strings" "sync/atomic" "testing" "time" ) // ========================================================================== // 測試輔助 // ========================================================================== // newTestClient 用 httptest server base URL 建一個 *client(拿具體型別,方便測 cache)。 func newTestClient(t *testing.T, baseURL string, ttlSec int) *client { t.Helper() iss, err := NewClient(Opts{ MCBaseURL: baseURL, ServiceClientID: "test-client-id", ServiceClientSecret: "test-secret", TenantID: "test-tenant", DownloadTokenTTLSeconds: ttlSec, }) if err != nil { t.Fatalf("NewClient: %v", err) } c, ok := iss.(*client) if !ok { t.Fatalf("NewClient did not return *client, got %T", iss) } return c } // ========================================================================== // NewClient // ========================================================================== func TestNewClient_returnsErrWhenConfigIncomplete(t *testing.T) { cases := map[string]Opts{ "missing MCBaseURL": {ServiceClientID: "c", ServiceClientSecret: "s", TenantID: "t"}, "missing clientID": {MCBaseURL: "http://mc", ServiceClientSecret: "s", TenantID: "t"}, "missing clientSecret": {MCBaseURL: "http://mc", ServiceClientID: "c", TenantID: "t"}, "missing tenant": {MCBaseURL: "http://mc", ServiceClientID: "c", ServiceClientSecret: "s"}, } for name, opts := range cases { t.Run(name, func(t *testing.T) { _, err := NewClient(opts) if !errors.Is(err, ErrConfigIncomplete) { t.Fatalf("want ErrConfigIncomplete, got %v", err) } }) } } func TestNewClient_defaultsTTLWhenNonPositive(t *testing.T) { iss, err := NewClient(Opts{ MCBaseURL: "http://mc", ServiceClientID: "c", ServiceClientSecret: "s", TenantID: "t", DownloadTokenTTLSeconds: 0, }) if err != nil { t.Fatalf("NewClient: %v", err) } if got := iss.(*client).ttlSeconds; got != defaultDownloadTokenTTLSeconds { t.Fatalf("ttl default: want %d, got %d", defaultDownloadTokenTTLSeconds, got) } } // ========================================================================== // GetServiceToken // ========================================================================== func TestGetServiceToken_successAndSendsClientCredentials(t *testing.T) { var gotGrant, gotScope, gotClientID, gotSecret string srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != oauthTokenPath { t.Errorf("unexpected path %s", r.URL.Path) } if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/x-www-form-urlencoded") { t.Errorf("want form content-type, got %s", ct) } body, _ := io.ReadAll(r.Body) form, _ := url.ParseQuery(string(body)) gotGrant = form.Get("grant_type") gotScope = form.Get("scope") gotClientID = form.Get("client_id") gotSecret = form.Get("client_secret") writeJSON(w, http.StatusOK, oauthTokenResponse{ AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 3600, Scope: downloadDelegateScope, }) })) defer srv.Close() c := newTestClient(t, srv.URL, 120) tok, err := c.GetServiceToken(context.Background()) if err != nil { t.Fatalf("GetServiceToken: %v", err) } if tok != "svc-access-token" { t.Fatalf("token: want svc-access-token, got %s", tok) } if gotGrant != "client_credentials" { t.Errorf("grant_type: want client_credentials, got %s", gotGrant) } if gotScope != downloadDelegateScope { t.Errorf("scope: want %s, got %s", downloadDelegateScope, gotScope) } if gotClientID != "test-client-id" || gotSecret != "test-secret" { t.Errorf("client creds not sent: id=%s secret=%s", gotClientID, gotSecret) } } func TestGetServiceToken_cachesTokenAcrossCalls(t *testing.T) { var hits int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&hits, 1) writeJSON(w, http.StatusOK, oauthTokenResponse{ AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 3600, }) })) defer srv.Close() c := newTestClient(t, srv.URL, 120) for i := 0; i < 3; i++ { if _, err := c.GetServiceToken(context.Background()); err != nil { t.Fatalf("call %d: %v", i, err) } } if got := atomic.LoadInt32(&hits); got != 1 { t.Fatalf("want 1 oauth hit (cached), got %d", got) } } func TestGetServiceToken_refetchesWhenExpired(t *testing.T) { var hits int32 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt32(&hits, 1) // expires_in 短到一定要重拿 writeJSON(w, http.StatusOK, oauthTokenResponse{ AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 10, }) })) defer srv.Close() c := newTestClient(t, srv.URL, 120) // 用可控時鐘:第一次 now=t0;第二次 now=t0+1h(必然過期) base := time.Date(2026, 6, 7, 0, 0, 0, 0, time.UTC) var cur time.Time = base c.now = func() time.Time { return cur } if _, err := c.GetServiceToken(context.Background()); err != nil { t.Fatalf("first call: %v", err) } cur = base.Add(time.Hour) if _, err := c.GetServiceToken(context.Background()); err != nil { t.Fatalf("second call: %v", err) } if got := atomic.LoadInt32(&hits); got != 2 { t.Fatalf("want 2 oauth hits (refetch after expiry), got %d", got) } } func TestGetServiceToken_failsOnNon200(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusUnauthorized) _, _ = w.Write([]byte(`{"error":"invalid_client"}`)) })) defer srv.Close() c := newTestClient(t, srv.URL, 120) _, err := c.GetServiceToken(context.Background()) if !errors.Is(err, ErrServiceTokenFailed) { t.Fatalf("want ErrServiceTokenFailed, got %v", err) } } func TestGetServiceToken_failsOnMissingAccessToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { writeJSON(w, http.StatusOK, oauthTokenResponse{TokenType: "Bearer", ExpiresIn: 3600}) })) defer srv.Close() c := newTestClient(t, srv.URL, 120) _, err := c.GetServiceToken(context.Background()) if !errors.Is(err, ErrServiceTokenFailed) { t.Fatalf("want ErrServiceTokenFailed, got %v", err) } } // ========================================================================== // IssueDownloadToken // ========================================================================== func TestIssueDownloadToken_successFullChain(t *testing.T) { var issueAuth string var issueBody issueRequest expiresAt := time.Date(2026, 6, 7, 12, 2, 0, 0, time.UTC) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case oauthTokenPath: writeJSON(w, http.StatusOK, oauthTokenResponse{ AccessToken: "svc-access-token", TokenType: "Bearer", ExpiresIn: 3600, }) case issuePath: issueAuth = r.Header.Get("Authorization") _ = json.NewDecoder(r.Body).Decode(&issueBody) writeJSON(w, http.StatusOK, issueResponse{ Token: "fdt_abc123", TokenType: "file_download", ExpiresAt: expiresAt.Format(time.RFC3339), Scope: "files:download.read", }) default: t.Errorf("unexpected path %s", r.URL.Path) } })) defer srv.Close() c := newTestClient(t, srv.URL, 120) res, err := c.IssueDownloadToken(context.Background(), "user-oidc-sub", "models/u/j.nef") if err != nil { t.Fatalf("IssueDownloadToken: %v", err) } if res.Token != "fdt_abc123" { t.Errorf("token: want fdt_abc123, got %s", res.Token) } if res.TokenType != "file_download" { t.Errorf("token_type: want file_download, got %s", res.TokenType) } if res.Scope != "files:download.read" { t.Errorf("scope: want files:download.read, got %s", res.Scope) } if !res.ExpiresAt.Equal(expiresAt) { t.Errorf("expires_at: want %v, got %v", expiresAt, res.ExpiresAt) } // Issue 必須帶 service token 當 Bearer if issueAuth != "Bearer svc-access-token" { t.Errorf("issue auth header: want 'Bearer svc-access-token', got %q", issueAuth) } // Issue body 契約(ADR-017 §10) if issueBody.TenantID != "test-tenant" { t.Errorf("tenant_id: want test-tenant, got %s", issueBody.TenantID) } if issueBody.UserID != "user-oidc-sub" { t.Errorf("user_id: want user-oidc-sub, got %s", issueBody.UserID) } if issueBody.ObjectKey != "models/u/j.nef" { t.Errorf("object_key: want models/u/j.nef, got %s", issueBody.ObjectKey) } if issueBody.Method != "GET" { t.Errorf("method: want GET, got %s", issueBody.Method) } if issueBody.ExpiresInSeconds != 120 { t.Errorf("expires_in_seconds: want 120, got %d", issueBody.ExpiresInSeconds) } } func TestIssueDownloadToken_validatesArgs(t *testing.T) { c := newTestClient(t, "http://unused", 120) if _, err := c.IssueDownloadToken(context.Background(), "", "key"); !errors.Is(err, ErrIssueTokenFailed) { t.Errorf("empty userID: want ErrIssueTokenFailed, got %v", err) } if _, err := c.IssueDownloadToken(context.Background(), "user", ""); !errors.Is(err, ErrIssueTokenFailed) { t.Errorf("empty objectKey: want ErrIssueTokenFailed, got %v", err) } } func TestIssueDownloadToken_failsWhenServiceTokenFails(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // oauth 永遠 500 w.WriteHeader(http.StatusInternalServerError) })) defer srv.Close() c := newTestClient(t, srv.URL, 120) _, err := c.IssueDownloadToken(context.Background(), "user", "key") if !errors.Is(err, ErrServiceTokenFailed) { t.Fatalf("want ErrServiceTokenFailed (propagated), got %v", err) } } func TestIssueDownloadToken_failsOnIssueNon200(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case oauthTokenPath: writeJSON(w, http.StatusOK, oauthTokenResponse{AccessToken: "tok", ExpiresIn: 3600}) case issuePath: w.WriteHeader(http.StatusForbidden) _, _ = w.Write([]byte(`{"error":"forbidden"}`)) } })) defer srv.Close() c := newTestClient(t, srv.URL, 120) _, err := c.IssueDownloadToken(context.Background(), "user", "key") if !errors.Is(err, ErrIssueTokenFailed) { t.Fatalf("want ErrIssueTokenFailed, got %v", err) } } func TestIssueDownloadToken_failsOnMissingToken(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case oauthTokenPath: writeJSON(w, http.StatusOK, oauthTokenResponse{AccessToken: "tok", ExpiresIn: 3600}) case issuePath: writeJSON(w, http.StatusOK, issueResponse{TokenType: "file_download"}) // 無 token } })) defer srv.Close() c := newTestClient(t, srv.URL, 120) _, err := c.IssueDownloadToken(context.Background(), "user", "key") if !errors.Is(err, ErrIssueTokenFailed) { t.Fatalf("want ErrIssueTokenFailed, got %v", err) } } func TestIssueDownloadToken_toleratesBadExpiresAt(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case oauthTokenPath: writeJSON(w, http.StatusOK, oauthTokenResponse{AccessToken: "tok", ExpiresIn: 3600}) case issuePath: writeJSON(w, http.StatusOK, issueResponse{ Token: "fdt_x", TokenType: "file_download", ExpiresAt: "not-a-date", }) } })) defer srv.Close() c := newTestClient(t, srv.URL, 120) res, err := c.IssueDownloadToken(context.Background(), "user", "key") if err != nil { t.Fatalf("should tolerate bad expires_at, got err %v", err) } if res.Token != "fdt_x" { t.Errorf("token: want fdt_x, got %s", res.Token) } if !res.ExpiresAt.IsZero() { t.Errorf("bad expires_at should leave zero time, got %v", res.ExpiresAt) } } // ========================================================================== // helpers(test-local) // ========================================================================== func writeJSON(w http.ResponseWriter, status int, v any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(v) }