diff --git a/pkg/mock/openid.go b/pkg/mock/openid.go index c8da196..39d04ba 100644 --- a/pkg/mock/openid.go +++ b/pkg/mock/openid.go @@ -10,6 +10,8 @@ import ( "strings" "time" + "github.com/nais/wonderwall/pkg/openid" + "github.com/alicebob/miniredis/v2" "github.com/go-chi/chi/v5" "github.com/google/uuid" @@ -352,33 +354,33 @@ func (ip *IdentityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request) func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWriter, r *http.Request) { if ip.Config.Provider().PushedAuthorizationRequestEndpoint() == "" { w.WriteHeader(http.StatusNotFound) - w.Write([]byte("PAR endpoint not supported")) + oauthError(w, fmt.Errorf("PAR endpoint not supported")) return } err := r.ParseForm() if err != nil { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("malformed payload?")) + oauthError(w, fmt.Errorf("malformed payload")) return } if r.PostForm.Get("request_uri") != "" { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("request_uri should not be provided to PAR endpoint")) + oauthError(w, fmt.Errorf("request_uri should not be provided to PAR endpoint")) return } authorizeRequest, err := ip.parseAuthorizationRequest(r.PostForm) if err != nil { w.WriteHeader(http.StatusBadRequest) - w.Write([]byte(err.Error())) + oauthError(w, err) return } err = ip.validateClientAuthentication(w, r, r.PostForm.Get("client_id")) if err != nil { - w.Write([]byte(err.Error())) + oauthError(w, err) return } @@ -390,7 +392,18 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri w.Header().Set("content-type", "application/json") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(map[string]string{"request_uri": requestUri, "expires_in": "60"}) + json.NewEncoder(w).Encode(openid.PushedAuthorizationResponse{ + RequestUri: requestUri, + ExpiresIn: 60, + }) +} + +func oauthError(w http.ResponseWriter, err error) { + w.Header().Set("content-type", "application/json") + json.NewEncoder(w).Encode(openid.TokenErrorResponse{ + Error: "invalid_request", + ErrorDescription: err.Error(), + }) } func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/openid/client/login.go b/pkg/openid/client/login.go index 4dcc516..3a61517 100644 --- a/pkg/openid/client/login.go +++ b/pkg/openid/client/login.go @@ -1,12 +1,16 @@ package client import ( + "bytes" "context" "encoding/json" "errors" "fmt" + "io" "net/http" + urllib "net/url" "slices" + stringslib "strings" "golang.org/x/oauth2" @@ -159,15 +163,89 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest) authCodeURL = c.oauth2Config.AuthCodeURL(request.state, opts...) } else { - // TODO: implement PAR - // generate PAR request - // set all request parameters from authorizationRequest - // set client authentication parameters - - // perform POST to PAR endpoint - // extract request_uri from response - // generate auth code URL with request_uri and client_id - // set authCodeURL + params := map[string]string{ + "client_id": c.oauth2Config.ClientID, + "code_challenge": oauth2.S256ChallengeFromVerifier(request.codeVerifier), + "code_challenge_method": "S256", + "nonce": request.nonce, + "redirect_uri": request.callbackURL, + "response_mode": "query", + "response_type": "code", + "scope": stringslib.Join(c.oauth2Config.Scopes, " "), + "state": request.state, + } + + if resource := c.cfg.Client().ResourceIndicator(); resource != "" { + params["resource"] = resource + } + + if len(request.acr) > 0 { + params[LoginParameterMapping[SecurityLevelURLParameter]] = request.acr + } + + if len(request.locale) > 0 { + params[LoginParameterMapping[LocaleURLParameter]] = request.locale + } + + if len(request.prompt) > 0 { + params[PromptURLParameter] = request.prompt + params[MaxAgeURLParameter] = "0" + } + + authParams, err := c.AuthParams() + if err != nil { + return "", fmt.Errorf("generating client authentication parameters: %w", err) + } + + urlValues := authParams.URLValues(params) + + requestBody := stringslib.NewReader(urlValues.Encode()) + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().PushedAuthorizationRequestEndpoint(), requestBody) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := c.httpClient.Do(r) + if err != nil { + return "", fmt.Errorf("performing request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading server response: %w", err) + } + + if resp.StatusCode >= 400 && resp.StatusCode < 500 { + var errorResponse openid.TokenErrorResponse + if err := json.Unmarshal(body, &errorResponse); err != nil { + return "", fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err) + } + return "", fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription) + } else if resp.StatusCode >= 500 { + return "", fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body) + } + + var pushedAuthorizationResponse openid.PushedAuthorizationResponse + if err := json.Unmarshal(body, &pushedAuthorizationResponse); err != nil { + return "", fmt.Errorf("unmarshalling token response: %w", err) + } + + v := urllib.Values{ + "client_id": {c.oauth2Config.ClientID}, + "request_uri": {pushedAuthorizationResponse.RequestUri}, + } + var buf bytes.Buffer + buf.WriteString(c.oauth2Config.Endpoint.AuthURL) + if stringslib.Contains(c.oauth2Config.Endpoint.AuthURL, "?") { + buf.WriteByte('&') + } else { + buf.WriteByte('?') + } + buf.WriteString(v.Encode()) + authCodeURL = buf.String() } return authCodeURL, nil diff --git a/pkg/openid/client/login_test.go b/pkg/openid/client/login_test.go index 9abf9d6..dc39b06 100644 --- a/pkg/openid/client/login_test.go +++ b/pkg/openid/client/login_test.go @@ -14,6 +14,28 @@ import ( urlpkg "github.com/nais/wonderwall/pkg/url" ) +func TestLogin_PushAuthorizationURL(t *testing.T) { + cfg := mock.Config() + idp := mock.NewIdentityProvider(cfg) + idp.OpenIDConfig.TestProvider.SetPushedAuthorizationRequestEndpoint(idp.ProviderServer.URL + "/par") + defer idp.Close() + req := idp.GetRequest(mock.Ingress + "/oauth2/login") + + result, err := idp.RelyingPartyHandler.Client.Login(req) + require.NoError(t, err) + + parsed, err := url.Parse(result.AuthCodeURL) + assert.NoError(t, err) + + query := parsed.Query() + assert.Contains(t, query, "request_uri") + assert.Contains(t, query, "client_id") + + assert.NotEmpty(t, query["request_uri"]) + assert.Contains(t, query["request_uri"][0], "urn:ietf:params:oauth:request_uri") + assert.ElementsMatch(t, query["client_id"], []string{idp.OpenIDConfig.Client().ClientID()}) +} + func TestLogin_URL(t *testing.T) { type loginURLTest struct { name string diff --git a/pkg/openid/oauth2.go b/pkg/openid/oauth2.go index 2fe5be5..ef05828 100644 --- a/pkg/openid/oauth2.go +++ b/pkg/openid/oauth2.go @@ -15,6 +15,12 @@ type TokenResponse struct { TokenType string `json:"token_type"` } +// PushedAuthorizationResponse is the struct representing the HTTP response from authorization servers as defined in RFC 9126, section 2.2. +type PushedAuthorizationResponse struct { + RequestUri string `json:"request_uri"` + ExpiresIn int64 `json:"expires_in"` +} + // TokenErrorResponse is the struct representing the HTTP error response returned from authorization servers as defined in RFC 6749, section 5.2. type TokenErrorResponse struct { Error string `json:"error"`