package relay

import (
	"bufio"
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"testing"
	"time"

	"github.com/gorilla/websocket"
	"github.com/hashicorp/yamux"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"visiona-backend/internal/session"
	"visiona-backend/internal/wsconn"
)

// TestEndToEnd_ForwardFromInternalToFakeLocal simulates the prototype's
// complete forwarding path:
//
//	api-server (internal HTTP client)
//	  └─► POST /internal/forward/http ──► remote-proxy (this test's HTTP server)
//	        └─► OpenStream over the tunnel ──► fake local agent
//	              └─► forwarded to the fake local server (in-process httptest.Server)
//
// This is the integration test required by the B3 task prompt: when it passes,
// the whole tunnel forwarding path is usable, so B4 (api-server) and B5
// (API handlers) can safely be built on top of it.
//
// What is checked:
//   - a request makes it through internal → tunnel → local agent → local server
//   - the local server's response body can be read back on the api-server side
//     (after base64 decoding)
//   - HTTP headers, status and body are all preserved
func TestEndToEnd_ForwardFromInternalToFakeLocal(t *testing.T) {
	// 1. Start a fake local server (stands in for `127.0.0.1:3721`).
	localSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.Header().Set("X-Test-Route", r.URL.Path)
		w.WriteHeader(http.StatusOK)
		// Best effort: an encode failure surfaces as a decode failure in the
		// assertions below.
		_ = json.NewEncoder(w).Encode(map[string]any{
			"method":         r.Method,
			"path":           r.URL.Path,
			"ok":             true,
			"receivedHeader": r.Header.Get("X-From-Api-Server"),
		})
	}))
	defer localSrv.Close()
	localAddr := strings.TrimPrefix(localSrv.URL, "http://")

	// 2. Start the two remote-proxy servers (tunnel + internal).
	store := session.NewInMemoryStore()
	relaySrv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
	internalSrv := NewInternalServer(store, slog.Default())

	tunnelMux := http.NewServeMux()
	tunnelMux.HandleFunc("/tunnel/connect", relaySrv.HandleTunnelConnect)
	tunnelTS := httptest.NewServer(tunnelMux)
	defer tunnelTS.Close()

	internalMux := http.NewServeMux()
	internalSrv.Routes(internalMux)
	internalTS := httptest.NewServer(internalMux)
	defer internalTS.Close()

	// 3. Start a fake tunnel client (stands in for the POC edge-ai-server or a
	//    future local agent) that forwards every HTTP request received on a
	//    tunnel stream to localSrv.
	const token = "vAc_feedbeeffeedbeeffeedbeeffeedbeef"
	stop := startTunnelClientForwardingTo(t, tunnelTS.URL, token, localAddr)
	defer stop()

	// 4. Wait for the session to be registered.
	require.Eventually(t, func() bool {
		ok, _ := store.Exists(context.Background(), token)
		return ok
	}, 2*time.Second, 20*time.Millisecond)

	// 5. Send a request through /internal/forward/http.
	reqBody := ForwardHTTPRequest{
		SessionToken: token,
		Method:       http.MethodGet,
		Path:         "/api/devices",
		Headers: map[string]string{
			"X-From-Api-Server": "test-value",
		},
	}
	bb, err := json.Marshal(reqBody)
	require.NoError(t, err)

	// Bounded client: a stuck forward must fail this test quickly instead of
	// hanging until the global `go test` timeout.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Post(internalTS.URL+"/internal/forward/http", "application/json", bytes.NewReader(bb))
	require.NoError(t, err)
	defer resp.Body.Close()

	require.Equal(t, http.StatusOK, resp.StatusCode)

	var fr ForwardHTTPResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&fr))
	require.Nil(t, fr.Error, "forward should not error: %+v", fr.Error)
	require.Equal(t, http.StatusOK, fr.Status)

	// 6. Decode the body and verify it.
	decoded, err := base64.StdEncoding.DecodeString(fr.Body)
	require.NoError(t, err)

	var payload map[string]any
	require.NoError(t, json.Unmarshal(decoded, &payload))
	assert.Equal(t, "GET", payload["method"])
	assert.Equal(t, "/api/devices", payload["path"])
	assert.Equal(t, true, payload["ok"])
	assert.Equal(t, "test-value", payload["receivedHeader"],
		"X-From-Api-Server header 應該被保留到 fake local server")

	// 7. Verify the response header was preserved as well. The length guard
	//    turns a present-but-empty value list into a test failure rather than
	//    an index-out-of-range panic.
	if vals, ok := fr.Headers["X-Test-Route"]; ok && len(vals) > 0 {
		assert.Equal(t, "/api/devices", vals[0])
	} else {
		t.Errorf("expected X-Test-Route header to propagate; got: %+v", fr.Headers)
	}
}

// TestEndToEnd_Forward_TunnelDisconnected checks that forwarding with a token
// that has no matching session returns a TUNNEL_DISCONNECTED error.
func TestEndToEnd_Forward_TunnelDisconnected(t *testing.T) {
	store := session.NewInMemoryStore()
	internalSrv := NewInternalServer(store, slog.Default())
	mux := http.NewServeMux()
	internalSrv.Routes(mux)
	ts := httptest.NewServer(mux)
	defer ts.Close()

	reqBody := ForwardHTTPRequest{
		SessionToken: "vAc_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
		Method:       http.MethodGet,
		Path:         "/",
	}
	bb, err := json.Marshal(reqBody)
	require.NoError(t, err)

	// Bounded client so a misbehaving handler cannot hang the test run.
	client := &http.Client{Timeout: 10 * time.Second}
	resp, err := client.Post(ts.URL+"/internal/forward/http", "application/json", bytes.NewReader(bb))
	require.NoError(t, err)
	defer resp.Body.Close()

	assert.Equal(t, http.StatusBadGateway, resp.StatusCode)

	var fr ForwardHTTPResponse
	require.NoError(t, json.NewDecoder(resp.Body).Decode(&fr))
	require.NotNil(t, fr.Error)
	assert.Equal(t, "TUNNEL_DISCONNECTED", fr.Error.Code)
}

// TestSessionCleanup_RemovesOnDisconnect verifies that when the local agent
// disconnects, its session is removed from the store (tunnel.md §2.2,
// disconnect handling).
func TestSessionCleanup_RemovesOnDisconnect(t *testing.T) {
	store := session.NewInMemoryStore()
	relaySrv := NewServer(store, slog.Default(), Options{KeepAliveInterval: 500 * time.Millisecond})
	mux := http.NewServeMux()
	mux.HandleFunc("/tunnel/connect", relaySrv.HandleTunnelConnect)
	ts := httptest.NewServer(mux)
	defer ts.Close()

	const token = "vAc_bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"
	stop := startFakeLocalAgent(t,
		"ws"+strings.TrimPrefix(ts.URL, "http")+"/tunnel/connect",
		token,
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

	// Make sure the fake agent is torn down even when an assertion below
	// aborts the test before the explicit stop() call; the flag keeps stop
	// from running twice on the happy path.
	stopped := false
	defer func() {
		if !stopped {
			stop()
		}
	}()

	require.Eventually(t, func() bool {
		ok, _ := store.Exists(context.Background(), token)
		return ok
	}, 2*time.Second, 20*time.Millisecond)

	stopped = true
	stop() // disconnect

	require.Eventually(t, func() bool {
		ok, _ := store.Exists(context.Background(), token)
		return !ok
	}, 2*time.Second, 20*time.Millisecond, "disconnected session should be unregistered")
}

// ----------------------------------------------------------------------
// Helper: start a fake tunnel client that forwards every HTTP request
// received on a tunnel stream to localAddr over real TCP (a full simulation
// of the local agent's behaviour).
// ----------------------------------------------------------------------

// startTunnelClientForwardingTo dials the relay's /tunnel/connect endpoint
// derived from relayHTTPURL, passing token as the "token" query parameter,
// and runs a yamux client session on top of the websocket. Each accepted
// stream is read as one HTTP request, sent to localAddr, and the response
// (or a 502 when the round trip fails) is written back on the stream.
//
// The returned function closes the yamux session and the websocket, then
// waits for the accept loop to exit. Per-stream goroutines are not awaited.
func startTunnelClientForwardingTo(t *testing.T, relayHTTPURL, token, localAddr string) func() {
	t.Helper()

	wsURL := "ws" + strings.TrimPrefix(relayHTTPURL, "http") + "/tunnel/connect"
	u, err := url.Parse(wsURL)
	require.NoError(t, err)
	q := u.Query()
	q.Set("token", token)
	u.RawQuery = q.Encode()

	rawWS, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
	require.NoError(t, err)

	netConn := wsconn.New(rawWS)
	ym, err := yamux.Client(netConn, yamux.DefaultConfig())
	if err != nil {
		// Do not leak the dialled websocket when the session cannot be set up.
		_ = rawWS.Close()
	}
	require.NoError(t, err)

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			stream, aerr := ym.Accept()
			if aerr != nil {
				// Accept fails once the session is closed; that ends the loop.
				return
			}
			go func(s net.Conn) {
				defer s.Close()
				req, rerr := http.ReadRequest(bufio.NewReader(s))
				if rerr != nil {
					return
				}
				// Forward the request to localAddr (real TCP). RequestURI must
				// be cleared before a server-side request can be reused with a
				// client transport.
				req.URL.Scheme = "http"
				req.URL.Host = localAddr
				req.Host = localAddr
				req.RequestURI = ""

				resp, rerr := http.DefaultTransport.RoundTrip(req)
				if rerr != nil {
					errResp := &http.Response{
						StatusCode: http.StatusBadGateway,
						ProtoMajor: 1,
						ProtoMinor: 1,
						Header:     make(http.Header),
						Body:       io.NopCloser(bytes.NewReader(nil)),
					}
					// Best effort: the stream may already be gone.
					_ = errResp.Write(s)
					return
				}
				defer resp.Body.Close()
				// Best effort: write failures show up as assertion failures on
				// the requesting side of the test.
				_ = resp.Write(s)
			}(stream)
		}
	}()

	return func() {
		// Best-effort teardown; close errors carry no signal for the tests.
		_ = ym.Close()
		_ = rawWS.Close()
		<-done
	}
}