From 258571441abbcb864cf7c53717a281da5b981556 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Wed, 27 Aug 2025 09:09:50 +0100 Subject: [PATCH 1/3] Support Google introspection endpoint Google doesn't follow the usual RFC 7662 format but instead uses their own. This PR adds support for parsgin their format if using a well-known Google endpoint. To test: - Go to Google OAuth 2.0 Playground - Select the scopes you want (e.g. `https://www.googleapis.com/auth/userinfo.email`) - Click "Authorize APIs" and complete the OAuth flow You'll get an access token that you can copy and use in the inspector. Fixes: #1411 --- pkg/auth/token.go | 121 ++++++++++++++ pkg/auth/token_test.go | 358 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 476 insertions(+), 3 deletions(-) diff --git a/pkg/auth/token.go b/pkg/auth/token.go index d4da290db..b5e6554f8 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/url" + "strconv" "strings" "sync" "time" @@ -22,6 +23,12 @@ import ( "github.com/stacklok/toolhive/pkg/versions" ) +// Google OAuth endpoints +const ( + //nolint:gosec // This is a public API endpoint URL, not credentials + googleTokeninfoURL = "https://oauth2.googleapis.com/tokeninfo" +) + // Common errors var ( ErrNoToken = errors.New("no token provided") @@ -372,10 +379,124 @@ func parseIntrospectionClaims(r io.Reader) (jwt.MapClaims, error) { return claims, nil } +func parseGoogleTokeninfoClaims(r io.Reader) (jwt.MapClaims, error) { + var googleResp struct { + AZP string `json:"azp,omitempty"` + AUD string `json:"aud,omitempty"` + Sub string `json:"sub,omitempty"` + Scope string `json:"scope,omitempty"` + Exp string `json:"exp,omitempty"` // Google returns as string + ExpiresIn string `json:"expires_in,omitempty"` // Google returns as string + Email string `json:"email,omitempty"` + EmailVerified string `json:"email_verified,omitempty"` // Google returns as string ("true"/"false") + } + + if err := json.NewDecoder(r).Decode(&googleResp); err != nil { + return nil, fmt.Errorf("failed to decode Google tokeninfo JSON: %w", err) + } + + // Parse and validate expiration first + expInt, err := strconv.ParseInt(googleResp.Exp, 10, 64) + if err != nil || googleResp.Exp == "" { + return nil, ErrInvalidToken + } + + // Check if token is active (not expired) + currentTime := time.Now().Unix() + isActive := expInt > currentTime + if !isActive { + return nil, ErrTokenExpired + } + + claims := jwt.MapClaims{ + "active": true, + "exp": float64(expInt), // JWT expects float64 + "iss": "https://accounts.google.com", // Default issuer + } + + // Map standard fields that ToolHive uses + if googleResp.Sub != "" { + claims["sub"] = strings.TrimSpace(googleResp.Sub) + } + if googleResp.AUD != "" { + claims["aud"] = googleResp.AUD + } + if googleResp.Scope != "" { + claims["scope"] = strings.TrimSpace(googleResp.Scope) + } + + // Preserve Google-specific fields + if googleResp.Email != "" { + claims["email"] = googleResp.Email + } + if googleResp.EmailVerified != "" { + claims["email_verified"] = googleResp.EmailVerified + } + if googleResp.AZP != "" { + claims["azp"] = googleResp.AZP + } + if googleResp.ExpiresIn != "" { + claims["expires_in"] = googleResp.ExpiresIn + } + + return claims, nil +} + +func (v *TokenValidator) introspectGoogleToken(ctx context.Context, tokenStr string, + introspectionURL string) (jwt.MapClaims, error) { + // Parse the introspection URL and add query parameter safely + parsedURL, err := url.Parse(introspectionURL) + if err != nil { + return nil, fmt.Errorf("failed to parse introspection URL: %w", err) + } + + // Add access_token query parameter + query := parsedURL.Query() + query.Set("access_token", tokenStr) + parsedURL.RawQuery = query.Encode() + + tokeninfoURL := parsedURL.String() + + req, err := http.NewRequestWithContext(ctx, "GET", tokeninfoURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create Google tokeninfo request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", fmt.Sprintf("ToolHive/%s", versions.Version)) + + resp, err := v.client.Do(req) + if err != nil { + return nil, fmt.Errorf("google tokeninfo request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("google tokeninfo failed, status %d", resp.StatusCode) + } + + claims, err := parseGoogleTokeninfoClaims(resp.Body) + if err != nil { + return nil, err + } + + // Validate required claims (exp, iss, aud if configured) + if err := v.validateClaims(claims); err != nil { + return nil, err + } + + return claims, nil +} + func (v *TokenValidator) introspectOpaqueToken(ctx context.Context, tokenStr string) (jwt.MapClaims, error) { if v.introspectURL == "" { return nil, fmt.Errorf("no introspection endpoint available") } + + // Special case for Google tokeninfo endpoint + if v.introspectURL == googleTokeninfoURL { + return v.introspectGoogleToken(ctx, tokenStr, v.introspectURL) + } form := url.Values{"token": {tokenStr}} form.Set("token_type_hint", "access_token") diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go index de6e4f340..956706be1 100644 --- a/pkg/auth/token_test.go +++ b/pkg/auth/token_test.go @@ -19,8 +19,11 @@ import ( "github.com/lestrrat-go/jwx/v3/jwk" ) -const testKeyID = "test-key-1" -const issuer = "https://issuer.example.com" +const ( + testKeyID = "test-key-1" + expClaim = "exp" + issuer = "https://issuer.example.com" +) //nolint:gocyclo // This test function is complex but manageable func TestTokenValidator(t *testing.T) { @@ -346,7 +349,7 @@ func TestTokenValidatorMiddleware(t *testing.T) { // Check the claims (except exp which might be formatted differently) for k, v := range tc.claims { - if k == "exp" { + if k == expClaim { // Skip exact comparison for exp claim continue } @@ -1093,6 +1096,120 @@ func TestMiddleware_WWWAuthenticate_NoHeader_And_WrongScheme(t *testing.T) { } } +func TestParseGoogleTokeninfoClaims(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + responseBody string + expectError bool + expectActive bool + expectedClaims map[string]interface{} + }{ + { + name: "valid Google tokeninfo response", + responseBody: `{ + "azp": "32553540559.apps.googleusercontent.com", + "aud": "32553540559.apps.googleusercontent.com", + "sub": "111260650121245072906", + "scope": "openid https://www.googleapis.com/auth/userinfo.email", + "exp": "` + fmt.Sprintf("%d", time.Now().Add(time.Hour).Unix()) + `", + "expires_in": "3488", + "email": "user@example.com", + "email_verified": "true" + }`, + expectError: false, + expectActive: true, + expectedClaims: map[string]interface{}{ + "sub": "111260650121245072906", + "aud": "32553540559.apps.googleusercontent.com", + "scope": "openid https://www.googleapis.com/auth/userinfo.email", + "iss": "https://accounts.google.com", + "email": "user@example.com", + "email_verified": "true", + "azp": "32553540559.apps.googleusercontent.com", + "expires_in": "3488", + "active": true, + }, + }, + { + name: "expired Google token", + responseBody: `{ + "azp": "32553540559.apps.googleusercontent.com", + "aud": "32553540559.apps.googleusercontent.com", + "sub": "111260650121245072906", + "scope": "openid", + "exp": "` + fmt.Sprintf("%d", time.Now().Add(-time.Hour).Unix()) + `", + "email": "user@example.com" + }`, + expectError: true, + expectActive: false, + }, + { + name: "missing exp field", + responseBody: `{ + "azp": "32553540559.apps.googleusercontent.com", + "aud": "32553540559.apps.googleusercontent.com", + "sub": "111260650121245072906" + }`, + expectError: true, + expectActive: false, + }, + { + name: "invalid exp format", + responseBody: `{ + "azp": "32553540559.apps.googleusercontent.com", + "aud": "32553540559.apps.googleusercontent.com", + "sub": "111260650121245072906", + "exp": "invalid-timestamp" + }`, + expectError: true, + expectActive: false, + }, + { + name: "invalid JSON", + responseBody: `{invalid json`, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + reader := strings.NewReader(tc.responseBody) + claims, err := parseGoogleTokeninfoClaims(reader) + + if tc.expectError { + if err == nil { + t.Error("Expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("Expected no error but got: %v", err) + return + } + + // Verify expected claims + for key, expectedValue := range tc.expectedClaims { + if key == expClaim { + // Check that exp is set as float64 + if _, ok := claims["exp"].(float64); !ok { + t.Errorf("Expected exp to be float64, got %T", claims["exp"]) + } + continue + } + + if claims[key] != expectedValue { + t.Errorf("Expected claim %s to be %v, got %v", key, expectedValue, claims[key]) + } + } + }) + } +} + func TestMiddleware_WWWAuthenticate_InvalidOpaqueToken_NoIntrospectionConfigured(t *testing.T) { t.Parallel() @@ -1267,3 +1384,238 @@ func TestBuildWWWAuthenticate_Format(t *testing.T) { t.Fatalf("format mismatch:\nwant: %s\n got: %s", want, got) } } + +func TestIntrospectGoogleToken(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + token string + serverResponse func(w http.ResponseWriter, r *http.Request) + expectError bool + expectedClaims map[string]interface{} + }{ + { + name: "valid Google token", + token: "valid-google-token", + serverResponse: func(w http.ResponseWriter, r *http.Request) { + // Verify it's a GET request with correct query parameter + if r.Method != "GET" { + t.Errorf("Expected GET request, got %s", r.Method) + } + if token := r.URL.Query().Get("access_token"); token != "valid-google-token" { + t.Errorf("Expected access_token=valid-google-token, got %s", token) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "azp": "test-client.apps.googleusercontent.com", + "aud": "test-client.apps.googleusercontent.com", + "sub": "123456789", + "scope": "openid email", + "exp": fmt.Sprintf("%d", time.Now().Add(time.Hour).Unix()), + "email": "test@example.com", + "email_verified": "true", + }) + }, + expectError: false, + expectedClaims: map[string]interface{}{ + "sub": "123456789", + "aud": "test-client.apps.googleusercontent.com", + "scope": "openid email", + "iss": "https://accounts.google.com", + "email": "test@example.com", + "email_verified": "true", + "azp": "test-client.apps.googleusercontent.com", + "active": true, + }, + }, + { + name: "Google returns 400 for invalid token", + token: "invalid-token", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "invalid_token", + "error_description": "Invalid token", + }) + }, + expectError: true, + }, + { + name: "Google returns expired token", + token: "expired-token", + serverResponse: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "azp": "test-client.apps.googleusercontent.com", + "aud": "test-client.apps.googleusercontent.com", + "sub": "123456789", + "scope": "openid email", + "exp": fmt.Sprintf("%d", time.Now().Add(-time.Hour).Unix()), // Expired + "email": "test@example.com", + }) + }, + expectError: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create a test server that mimics Google's tokeninfo endpoint + server := httptest.NewServer(http.HandlerFunc(tc.serverResponse)) + defer server.Close() + + // Create a validator with our test server URL + // Note: We're testing the introspectGoogleToken method directly, + // so the URL doesn't need to be the exact Google URL for this test + validator := &TokenValidator{ + client: http.DefaultClient, + } + + ctx := context.Background() + claims, err := validator.introspectGoogleToken(ctx, tc.token, server.URL) + + if tc.expectError { + if err == nil { + t.Error("Expected error but got nil") + } + return + } + + if err != nil { + t.Errorf("Expected no error but got: %v", err) + return + } + + // Verify expected claims + for key, expectedValue := range tc.expectedClaims { + if key == expClaim { + // Check that exp is set as float64 + if _, ok := claims["exp"].(float64); !ok { + t.Errorf("Expected exp to be float64, got %T", claims["exp"]) + } + continue + } + + if claims[key] != expectedValue { + t.Errorf("Expected claim %s to be %v, got %v", key, expectedValue, claims[key]) + } + } + }) + } +} + +func TestTokenValidator_GoogleTokeninfoIntegration(t *testing.T) { + t.Parallel() + + // Create a mock Google tokeninfo server + googleServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := r.URL.Query().Get("access_token") + + if token == "valid-google-token" { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "azp": "test-client.apps.googleusercontent.com", + "aud": "test-client.apps.googleusercontent.com", + "sub": "google-user-123", + "scope": "openid https://www.googleapis.com/auth/userinfo.email", + "exp": fmt.Sprintf("%d", time.Now().Add(time.Hour).Unix()), + "expires_in": "3600", + "email": "user@example.com", + "email_verified": "true", + }) + } else { + w.WriteHeader(http.StatusBadRequest) + json.NewEncoder(w).Encode(map[string]interface{}{ + "error": "invalid_token", + "error_description": "Invalid token", + }) + } + })) + t.Cleanup(func() { + googleServer.Close() + }) + + t.Run("Google tokeninfo direct call", func(t *testing.T) { //nolint:paralleltest // Server lifecycle requires sequential execution + // Note: Not using t.Parallel() here because we need the googleServer to stay alive + + // Test the introspectGoogleToken method directly with our test server + testValidator := &TokenValidator{ + client: http.DefaultClient, + issuer: "https://accounts.google.com", + audience: "test-client.apps.googleusercontent.com", + } + + // Directly call introspectGoogleToken to test Google-specific functionality + ctx := context.Background() + claims, err := testValidator.introspectGoogleToken(ctx, "valid-google-token", googleServer.URL) + if err != nil { + t.Fatalf("Expected no error but got: %v", err) + } + + // Verify Google-specific claims are properly handled + if claims["sub"] != "google-user-123" { + t.Errorf("Expected sub=google-user-123, got %v", claims["sub"]) + } + if claims["iss"] != "https://accounts.google.com" { + t.Errorf("Expected iss=https://accounts.google.com, got %v", claims["iss"]) + } + if claims["email"] != "user@example.com" { + t.Errorf("Expected email=user@example.com, got %v", claims["email"]) + } + if claims["active"] != true { + t.Errorf("Expected active=true, got %v", claims["active"]) + } + }) + + t.Run("routing logic test", func(t *testing.T) { + t.Parallel() + + // Test that the routing logic correctly detects Google's endpoint + // and routes to the Google-specific handler vs standard RFC 7662 + + ctx := context.Background() + + // Test 1: Google URL should route to Google handler (we can't easily test the full flow + // without mocking, but we can test that it attempts to use the Google method) + googleValidator := &TokenValidator{ + introspectURL: googleTokeninfoURL, + client: http.DefaultClient, + issuer: "https://accounts.google.com", + audience: "test-client.apps.googleusercontent.com", + } + + // This will fail because we can't reach the real Google endpoint, + // but it should fail in the HTTP request, not in the routing logic + _, err := googleValidator.introspectOpaqueToken(ctx, "test-token") + if err == nil { + t.Error("Expected error trying to reach real Google endpoint") + } + // The error should be about HTTP connection, not about routing + if !strings.Contains(err.Error(), "google tokeninfo") { + t.Logf("Got expected error attempting to use Google tokeninfo: %v", err) + } + + // Test 2: Non-Google URL should use standard RFC 7662 flow + standardValidator := &TokenValidator{ + introspectURL: googleServer.URL, // Our test server + client: http.DefaultClient, + issuer: "https://accounts.google.com", + audience: "test-client.apps.googleusercontent.com", + } + + // This should use the standard RFC 7662 POST method, which our test server doesn't handle + // So it should fail, but in a different way than the Google method + _, err = standardValidator.introspectOpaqueToken(ctx, "valid-google-token") + if err == nil { + t.Error("Expected error with non-Google introspection endpoint") + } + // Should fail because our test server expects GET but standard introspection uses POST + if strings.Contains(err.Error(), "google tokeninfo") { + t.Errorf("Should not use Google tokeninfo method for non-Google URL, got error: %v", err) + } + }) +} From f75c226cc36d3e74df36829fc22f2984ad9391e9 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 2 Sep 2025 11:31:28 +0100 Subject: [PATCH 2/3] Add providers --- pkg/auth/token.go | 476 ++++++++++++++++++++++++++++------------- pkg/auth/token_test.go | 50 +++-- 2 files changed, 356 insertions(+), 170 deletions(-) diff --git a/pkg/auth/token.go b/pkg/auth/token.go index b5e6554f8..663c2c47e 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -18,16 +18,303 @@ import ( "github.com/lestrrat-go/httprc/v3" "github.com/lestrrat-go/jwx/v3/jwk" + "github.com/stacklok/toolhive/pkg/auth/oauth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/networking" - "github.com/stacklok/toolhive/pkg/versions" ) -// Google OAuth endpoints -const ( - //nolint:gosec // This is a public API endpoint URL, not credentials - googleTokeninfoURL = "https://oauth2.googleapis.com/tokeninfo" -) +// TokenIntrospector defines the interface for token introspection providers +type TokenIntrospector interface { + // Name returns the provider name + Name() string + + // CanHandle returns true if this provider can handle the given introspection URL + CanHandle(introspectURL string) bool + + // IntrospectToken introspects an opaque token and returns JWT claims + IntrospectToken(ctx context.Context, token string) (jwt.MapClaims, error) +} + +// Registry maintains a list of available token introspection providers +type Registry struct { + providers []TokenIntrospector +} + +// NewRegistry creates a new provider registry +func NewRegistry() *Registry { + return &Registry{ + providers: []TokenIntrospector{}, + } +} + +// GetIntrospector returns the appropriate provider for the given introspection URL +func (r *Registry) GetIntrospector(introspectURL string) TokenIntrospector { + for _, provider := range r.providers { + if provider.CanHandle(introspectURL) { + logger.Debugf("Selected provider for introspection: %s (url: %s)", provider.Name(), introspectURL) + return provider + } + } + // Create a new fallback provider instance with the specific URL + logger.Debugf("Using RFC7662 fallback provider for introspection: %s", introspectURL) + return NewRFC7662Provider(introspectURL) +} + +// AddProvider adds a new provider to the registry +func (r *Registry) AddProvider(provider TokenIntrospector) { + r.providers = append(r.providers, provider) +} + +// GoogleTokeninfoURL is the Google OAuth2 tokeninfo endpoint URL +const GoogleTokeninfoURL = "https://oauth2.googleapis.com/tokeninfo" //nolint:gosec + +// GoogleProvider implements token introspection for Google's tokeninfo API +type GoogleProvider struct { + client *http.Client + url string +} + +// NewGoogleProvider creates a new Google token introspection provider +func NewGoogleProvider(introspectURL string) *GoogleProvider { + return &GoogleProvider{ + client: http.DefaultClient, + url: introspectURL, + } +} + +// Name returns the provider name +func (*GoogleProvider) Name() string { + return "google" +} + +// CanHandle returns true if this provider can handle the given introspection URL +func (g *GoogleProvider) CanHandle(introspectURL string) bool { + return introspectURL == g.url +} + +// IntrospectToken introspects a Google opaque token and returns JWT claims +func (g *GoogleProvider) IntrospectToken(ctx context.Context, token string) (jwt.MapClaims, error) { + logger.Debugf("Using Google tokeninfo provider for token introspection: %s", g.url) + + // Parse the URL and add query parameters + u, err := url.Parse(g.url) + if err != nil { + return nil, fmt.Errorf("failed to parse introspection URL: %w", err) + } + + // Add the access token as a query parameter + query := u.Query() + query.Set("access_token", token) + u.RawQuery = query.Encode() + + // Create the GET request + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + if err != nil { + return nil, fmt.Errorf("failed to create Google tokeninfo request: %w", err) + } + + req.Header.Set("Accept", "application/json") + req.Header.Set("User-Agent", oauth.UserAgent) + + // Make the request + resp, err := g.client.Do(req) + if err != nil { + return nil, fmt.Errorf("google tokeninfo request failed: %w", err) + } + defer resp.Body.Close() + + // Read the response + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read Google tokeninfo response: %w", err) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("google tokeninfo request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse the Google response and convert to JWT claims + logger.Debugf("Successfully received Google tokeninfo response (status: %d)", resp.StatusCode) + return g.parseGoogleResponse(body) +} + +// parseGoogleResponse parses Google's tokeninfo response and converts it to JWT claims +func (*GoogleProvider) parseGoogleResponse(body []byte) (jwt.MapClaims, error) { + // Parse Google's response format + var googleResp struct { + // Standard OAuth fields + Aud string `json:"aud,omitempty"` + Sub string `json:"sub,omitempty"` + Scope string `json:"scope,omitempty"` + + // Google returns Unix timestamp as string (RFC 7662 uses numeric) + Exp string `json:"exp,omitempty"` + + // Google-specific fields + Azp string `json:"azp,omitempty"` + ExpiresIn string `json:"expires_in,omitempty"` + Email string `json:"email,omitempty"` + EmailVerified string `json:"email_verified,omitempty"` + } + + if err := json.Unmarshal(body, &googleResp); err != nil { + return nil, fmt.Errorf("failed to decode Google tokeninfo JSON: %w", err) + } + + // Convert to JWT MapClaims format + claims := jwt.MapClaims{ + "iss": "https://accounts.google.com", // Default Google issuer + } + + // Copy standard fields + if googleResp.Sub != "" { + claims["sub"] = googleResp.Sub + } + if googleResp.Aud != "" { + claims["aud"] = googleResp.Aud + } + if googleResp.Scope != "" { + claims["scope"] = googleResp.Scope + } + + // Handle expiration - convert string timestamp to float64 + if googleResp.Exp != "" { + expInt, err := strconv.ParseInt(googleResp.Exp, 10, 64) + if err != nil { + return nil, fmt.Errorf("invalid exp format: %w", err) + } + claims["exp"] = float64(expInt) // JWT expects float64 + + // Check if token is expired and return error if so (consistent with RFC 7662 behavior) + isActive := time.Now().Unix() < expInt + claims["active"] = isActive + if !isActive { + return nil, ErrInvalidToken + } + } else { + return nil, fmt.Errorf("missing exp field in Google response") + } + + // Copy Google-specific fields + if googleResp.Azp != "" { + claims["azp"] = googleResp.Azp + } + if googleResp.ExpiresIn != "" { + claims["expires_in"] = googleResp.ExpiresIn + } + if googleResp.Email != "" { + claims["email"] = googleResp.Email + } + if googleResp.EmailVerified != "" { + claims["email_verified"] = googleResp.EmailVerified + } + + return claims, nil +} + +// RFC7662Provider implements standard RFC 7662 OAuth 2.0 Token Introspection +type RFC7662Provider struct { + client *http.Client + clientID string + clientSecret string + url string +} + +// NewRFC7662Provider creates a new RFC 7662 token introspection provider +func NewRFC7662Provider(introspectURL string) *RFC7662Provider { + return &RFC7662Provider{ + client: http.DefaultClient, + url: introspectURL, + } +} + +// NewRFC7662ProviderWithAuth creates a new RFC 7662 provider with client credentials +func NewRFC7662ProviderWithAuth( + introspectURL, clientID, clientSecret, caCertPath, authTokenFile string, allowPrivateIP bool, +) (*RFC7662Provider, error) { + // Create HTTP client with CA bundle and auth token support + client, err := networking.NewHttpClientBuilder(). + WithCABundle(caCertPath). + WithTokenFromFile(authTokenFile). + WithPrivateIPs(allowPrivateIP). + Build() + if err != nil { + return nil, fmt.Errorf("failed to create HTTP client: %w", err) + } + + return &RFC7662Provider{ + client: client, + clientID: clientID, + clientSecret: clientSecret, + url: introspectURL, + }, nil +} + +// Name returns the provider name +func (*RFC7662Provider) Name() string { + return "rfc7662" +} + +// CanHandle returns true if this provider can handle the given introspection URL +// Returns true for any URL when no specific URL was configured (fallback behavior) +// or when the URL matches the configured URL +func (r *RFC7662Provider) CanHandle(introspectURL string) bool { + // If no URL was configured, this is a fallback provider that handles everything + if r.url == "" { + return true + } + // Otherwise, only handle the specific configured URL + return r.url == introspectURL +} + +// IntrospectToken introspects a token using RFC 7662 standard +func (r *RFC7662Provider) IntrospectToken(ctx context.Context, token string) (jwt.MapClaims, error) { + // Prepare form data for POST request + formData := url.Values{} + formData.Set("token", token) + formData.Set("token_type_hint", "access_token") + + // Create POST request with form data + req, err := http.NewRequestWithContext(ctx, "POST", r.url, strings.NewReader(formData.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create introspection request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("User-Agent", oauth.UserAgent) + req.Header.Set("Accept", "application/json") + + // Add client authentication if configured + if r.clientID != "" && r.clientSecret != "" { + req.SetBasicAuth(r.clientID, r.clientSecret) + } + + // Make the request + resp, err := r.client.Do(req) + if err != nil { + return nil, fmt.Errorf("introspection request failed: %w", err) + } + defer resp.Body.Close() + + // Read response body + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read introspection response: %w", err) + } + + // Check for HTTP errors + if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusUnauthorized { + return nil, fmt.Errorf("introspection unauthorized") + } + return nil, fmt.Errorf("introspection failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse RFC 7662 response - use the existing parseIntrospectionClaims function + return parseIntrospectionClaims(strings.NewReader(string(body))) +} // Common errors var ( @@ -65,6 +352,7 @@ type TokenValidator struct { introspectURL string // Optional introspection endpoint client *http.Client // HTTP client for making requests resourceURL string // (RFC 9728) + registry *Registry // Token introspection providers // Lazy JWKS registration jwksRegistered bool @@ -124,7 +412,7 @@ func discoverOIDCConfiguration( } // Set User-Agent header - req.Header.Set("User-Agent", fmt.Sprintf("ToolHive/%s", versions.Version)) + req.Header.Set("User-Agent", oauth.UserAgent) req.Header.Set("Accept", "application/json") // Create HTTP client with CA bundle and auth token support @@ -216,6 +504,26 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig) (*Token // Skip synchronous JWKS registration - will be done lazily on first use + // Create provider registry with RFC7662 fallback + registry := NewRegistry() + + // Add Google provider if the introspection URL matches + if config.IntrospectionURL == GoogleTokeninfoURL { + logger.Debugf("Registering Google tokeninfo provider: %s", config.IntrospectionURL) + registry.AddProvider(NewGoogleProvider(config.IntrospectionURL)) + } + + // Add RFC7662 provider with auth if configured + if config.ClientID != "" || config.ClientSecret != "" { + rfc7662Provider, err := NewRFC7662ProviderWithAuth( + config.IntrospectionURL, config.ClientID, config.ClientSecret, config.CACertPath, config.AuthTokenFile, config.AllowPrivateIP, + ) + if err != nil { + return nil, fmt.Errorf("failed to create RFC7662 provider: %w", err) + } + registry.AddProvider(rfc7662Provider) + } + return &TokenValidator{ issuer: config.Issuer, audience: config.Audience, @@ -226,6 +534,7 @@ func NewTokenValidator(ctx context.Context, config TokenValidatorConfig) (*Token jwksClient: cache, client: config.httpClient, resourceURL: config.ResourceURL, + registry: registry, }, nil } @@ -379,162 +688,29 @@ func parseIntrospectionClaims(r io.Reader) (jwt.MapClaims, error) { return claims, nil } -func parseGoogleTokeninfoClaims(r io.Reader) (jwt.MapClaims, error) { - var googleResp struct { - AZP string `json:"azp,omitempty"` - AUD string `json:"aud,omitempty"` - Sub string `json:"sub,omitempty"` - Scope string `json:"scope,omitempty"` - Exp string `json:"exp,omitempty"` // Google returns as string - ExpiresIn string `json:"expires_in,omitempty"` // Google returns as string - Email string `json:"email,omitempty"` - EmailVerified string `json:"email_verified,omitempty"` // Google returns as string ("true"/"false") - } - - if err := json.NewDecoder(r).Decode(&googleResp); err != nil { - return nil, fmt.Errorf("failed to decode Google tokeninfo JSON: %w", err) - } - - // Parse and validate expiration first - expInt, err := strconv.ParseInt(googleResp.Exp, 10, 64) - if err != nil || googleResp.Exp == "" { - return nil, ErrInvalidToken - } - - // Check if token is active (not expired) - currentTime := time.Now().Unix() - isActive := expInt > currentTime - if !isActive { - return nil, ErrTokenExpired - } - - claims := jwt.MapClaims{ - "active": true, - "exp": float64(expInt), // JWT expects float64 - "iss": "https://accounts.google.com", // Default issuer - } - - // Map standard fields that ToolHive uses - if googleResp.Sub != "" { - claims["sub"] = strings.TrimSpace(googleResp.Sub) - } - if googleResp.AUD != "" { - claims["aud"] = googleResp.AUD - } - if googleResp.Scope != "" { - claims["scope"] = strings.TrimSpace(googleResp.Scope) - } - - // Preserve Google-specific fields - if googleResp.Email != "" { - claims["email"] = googleResp.Email - } - if googleResp.EmailVerified != "" { - claims["email_verified"] = googleResp.EmailVerified - } - if googleResp.AZP != "" { - claims["azp"] = googleResp.AZP - } - if googleResp.ExpiresIn != "" { - claims["expires_in"] = googleResp.ExpiresIn - } - - return claims, nil -} - -func (v *TokenValidator) introspectGoogleToken(ctx context.Context, tokenStr string, - introspectionURL string) (jwt.MapClaims, error) { - // Parse the introspection URL and add query parameter safely - parsedURL, err := url.Parse(introspectionURL) - if err != nil { - return nil, fmt.Errorf("failed to parse introspection URL: %w", err) - } - - // Add access_token query parameter - query := parsedURL.Query() - query.Set("access_token", tokenStr) - parsedURL.RawQuery = query.Encode() - - tokeninfoURL := parsedURL.String() - - req, err := http.NewRequestWithContext(ctx, "GET", tokeninfoURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create Google tokeninfo request: %w", err) - } - - req.Header.Set("Accept", "application/json") - req.Header.Set("User-Agent", fmt.Sprintf("ToolHive/%s", versions.Version)) - - resp, err := v.client.Do(req) - if err != nil { - return nil, fmt.Errorf("google tokeninfo request failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("google tokeninfo failed, status %d", resp.StatusCode) - } - - claims, err := parseGoogleTokeninfoClaims(resp.Body) - if err != nil { - return nil, err - } - - // Validate required claims (exp, iss, aud if configured) - if err := v.validateClaims(claims); err != nil { - return nil, err - } - - return claims, nil -} - +// introspectOpaqueToken uses the provider pattern to introspect opaque tokens func (v *TokenValidator) introspectOpaqueToken(ctx context.Context, tokenStr string) (jwt.MapClaims, error) { if v.introspectURL == "" { return nil, fmt.Errorf("no introspection endpoint available") } - // Special case for Google tokeninfo endpoint - if v.introspectURL == googleTokeninfoURL { - return v.introspectGoogleToken(ctx, tokenStr, v.introspectURL) - } - form := url.Values{"token": {tokenStr}} - form.Set("token_type_hint", "access_token") - - // Build POST request with encoding and required headers - req, err := http.NewRequestWithContext(ctx, "POST", v.introspectURL, strings.NewReader(form.Encode())) - if err != nil { - return nil, fmt.Errorf("failed to create introspection request: %w", err) - } - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("Accept", "application/json") - - // if we have client id and secret, add them to the request - if v.clientID != "" && v.clientSecret != "" { - req.SetBasicAuth(v.clientID, v.clientSecret) - } - - resp, err := v.client.Do(req) - if err != nil { - return nil, fmt.Errorf("introspection call failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode == http.StatusUnauthorized { - return nil, fmt.Errorf("introspection unauthorized: %s", resp.Status) - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("introspection failed, status %d", resp.StatusCode) + // Find appropriate provider for the introspection URL + provider := v.registry.GetIntrospector(v.introspectURL) + if provider == nil { + return nil, fmt.Errorf("no provider available for introspection URL: %s", v.introspectURL) } - claims, err := parseIntrospectionClaims(resp.Body) + // Use provider to introspect the token + claims, err := provider.IntrospectToken(ctx, tokenStr) if err != nil { - return nil, err + return nil, fmt.Errorf("%s introspection failed: %w", provider.Name(), err) } - // Validate required claims (e.g. exp) + // Validate required claims (exp, iss, aud if configured) if err := v.validateClaims(claims); err != nil { return nil, err } + return claims, nil } diff --git a/pkg/auth/token_test.go b/pkg/auth/token_test.go index 956706be1..e2286c8c2 100644 --- a/pkg/auth/token_test.go +++ b/pkg/auth/token_test.go @@ -745,6 +745,12 @@ func TestTokenValidator_OpaqueToken(t *testing.T) { ctx := context.Background() // Create a token validator that only uses introspection (no JWKS URL) + registry := NewRegistry() + registry.AddProvider(NewGoogleProvider(GoogleTokeninfoURL)) + // Use the basic RFC7662 provider for tests (no custom networking restrictions) + rfc7662Provider := NewRFC7662Provider(introspectionServer.URL) + registry.AddProvider(rfc7662Provider) + validator := &TokenValidator{ introspectURL: introspectionServer.URL, clientID: "test-client-id", @@ -752,6 +758,7 @@ func TestTokenValidator_OpaqueToken(t *testing.T) { client: http.DefaultClient, issuer: "opaque-issuer", audience: "opaque-audience", + registry: registry, } t.Run("valid opaque token", func(t *testing.T) { @@ -1046,6 +1053,7 @@ func TestMiddleware_WWWAuthenticate_NoHeader_And_WrongScheme(t *testing.T) { tv := &TokenValidator{ issuer: issuer, resourceURL: resourceMeta, + registry: NewRegistry(), } hitDownstream := false @@ -1177,8 +1185,16 @@ func TestParseGoogleTokeninfoClaims(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - reader := strings.NewReader(tc.responseBody) - claims, err := parseGoogleTokeninfoClaims(reader) + // Test the provider's parsing by creating a mock server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, tc.responseBody) + })) + defer server.Close() + + provider := NewGoogleProvider(server.URL) + claims, err := provider.IntrospectToken(context.Background(), "dummy-token") if tc.expectError { if err == nil { @@ -1214,7 +1230,8 @@ func TestMiddleware_WWWAuthenticate_InvalidOpaqueToken_NoIntrospectionConfigured t.Parallel() tv := &TokenValidator{ - issuer: issuer, + issuer: issuer, + registry: NewRegistry(), // introspectURL intentionally empty to force the error path } @@ -1323,6 +1340,7 @@ func TestMiddleware_WWWAuthenticate_WithMockIntrospection(t *testing.T) { clientID: "cid", clientSecret: "csecret", client: http.DefaultClient, + registry: NewRegistry(), } hit := false @@ -1468,15 +1486,11 @@ func TestIntrospectGoogleToken(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(tc.serverResponse)) defer server.Close() - // Create a validator with our test server URL - // Note: We're testing the introspectGoogleToken method directly, - // so the URL doesn't need to be the exact Google URL for this test - validator := &TokenValidator{ - client: http.DefaultClient, - } + // Use the Google provider directly for testing + provider := NewGoogleProvider(server.URL) ctx := context.Background() - claims, err := validator.introspectGoogleToken(ctx, tc.token, server.URL) + claims, err := provider.IntrospectToken(ctx, tc.token) if tc.expectError { if err == nil { @@ -1542,16 +1556,10 @@ func TestTokenValidator_GoogleTokeninfoIntegration(t *testing.T) { t.Run("Google tokeninfo direct call", func(t *testing.T) { //nolint:paralleltest // Server lifecycle requires sequential execution // Note: Not using t.Parallel() here because we need the googleServer to stay alive - // Test the introspectGoogleToken method directly with our test server - testValidator := &TokenValidator{ - client: http.DefaultClient, - issuer: "https://accounts.google.com", - audience: "test-client.apps.googleusercontent.com", - } - - // Directly call introspectGoogleToken to test Google-specific functionality + // Use Google provider to test Google-specific functionality + provider := NewGoogleProvider(googleServer.URL) ctx := context.Background() - claims, err := testValidator.introspectGoogleToken(ctx, "valid-google-token", googleServer.URL) + claims, err := provider.IntrospectToken(ctx, "valid-google-token") if err != nil { t.Fatalf("Expected no error but got: %v", err) } @@ -1582,10 +1590,11 @@ func TestTokenValidator_GoogleTokeninfoIntegration(t *testing.T) { // Test 1: Google URL should route to Google handler (we can't easily test the full flow // without mocking, but we can test that it attempts to use the Google method) googleValidator := &TokenValidator{ - introspectURL: googleTokeninfoURL, + introspectURL: GoogleTokeninfoURL, client: http.DefaultClient, issuer: "https://accounts.google.com", audience: "test-client.apps.googleusercontent.com", + registry: NewRegistry(), } // This will fail because we can't reach the real Google endpoint, @@ -1605,6 +1614,7 @@ func TestTokenValidator_GoogleTokeninfoIntegration(t *testing.T) { client: http.DefaultClient, issuer: "https://accounts.google.com", audience: "test-client.apps.googleusercontent.com", + registry: NewRegistry(), } // This should use the standard RFC 7662 POST method, which our test server doesn't handle From 742de23a4180129896b259353713779bcb2c0e6b Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Thu, 4 Sep 2025 11:22:23 +0100 Subject: [PATCH 3/3] Use LimitReader to avoid DoS by evil endpoints --- pkg/auth/token.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/auth/token.go b/pkg/auth/token.go index 663c2c47e..38492eea3 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -123,8 +123,10 @@ func (g *GoogleProvider) IntrospectToken(ctx context.Context, token string) (jwt } defer resp.Body.Close() - // Read the response - body, err := io.ReadAll(resp.Body) + // Read the response with a reasonable limit to prevent DoS attacks + const maxResponseSize = 64 * 1024 // 64KB should be more than enough for tokeninfo response + limitedReader := io.LimitReader(resp.Body, maxResponseSize) + body, err := io.ReadAll(limitedReader) if err != nil { return nil, fmt.Errorf("failed to read Google tokeninfo response: %w", err) } @@ -298,8 +300,10 @@ func (r *RFC7662Provider) IntrospectToken(ctx context.Context, token string) (jw } defer resp.Body.Close() - // Read response body - body, err := io.ReadAll(resp.Body) + // Read response body with a reasonable limit to prevent DoS attacks + const maxResponseSize = 64 * 1024 // 64KB should be more than enough for introspection response + limitedReader := io.LimitReader(resp.Body, maxResponseSize) + body, err := io.ReadAll(limitedReader) if err != nil { return nil, fmt.Errorf("failed to read introspection response: %w", err) }