Skip to content

Commit

Permalink
feat: implement PAR for relying party
Browse files Browse the repository at this point in the history
Fixes #235

Co-authored-by: tronghn <[email protected]>
  • Loading branch information
sindrerh2 and tronghn committed Jan 22, 2025
1 parent 62ab4c1 commit 463f48a
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 15 deletions.
25 changes: 19 additions & 6 deletions pkg/mock/openid.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"strings"
"time"

"github.com/nais/wonderwall/pkg/openid"

"github.com/alicebob/miniredis/v2"
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
Expand Down Expand Up @@ -352,33 +354,33 @@ func (ip *IdentityProviderHandler) Jwks(w http.ResponseWriter, r *http.Request)
func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWriter, r *http.Request) {
if ip.Config.Provider().PushedAuthorizationRequestEndpoint() == "" {
w.WriteHeader(http.StatusNotFound)
w.Write([]byte("PAR endpoint not supported"))
oauthError(w, fmt.Errorf("PAR endpoint not supported"))
return
}

err := r.ParseForm()
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("malformed payload?"))
oauthError(w, fmt.Errorf("malformed payload"))
return
}

if r.PostForm.Get("request_uri") != "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("request_uri should not be provided to PAR endpoint"))
oauthError(w, fmt.Errorf("request_uri should not be provided to PAR endpoint"))
return
}

authorizeRequest, err := ip.parseAuthorizationRequest(r.PostForm)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
oauthError(w, err)
return
}

err = ip.validateClientAuthentication(w, r, r.PostForm.Get("client_id"))
if err != nil {
w.Write([]byte(err.Error()))
oauthError(w, err)
return
}

Expand All @@ -390,7 +392,18 @@ func (ip *IdentityProviderHandler) PushedAuthorizationRequest(w http.ResponseWri

w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]string{"request_uri": requestUri, "expires_in": "60"})
json.NewEncoder(w).Encode(openid.PushedAuthorizationResponse{
RequestUri: requestUri,
ExpiresIn: 60,
})
}

func oauthError(w http.ResponseWriter, err error) {
w.Header().Set("content-type", "application/json")
json.NewEncoder(w).Encode(openid.TokenErrorResponse{
Error: "invalid_request",
ErrorDescription: err.Error(),
})
}

func (ip *IdentityProviderHandler) Token(w http.ResponseWriter, r *http.Request) {
Expand Down
96 changes: 87 additions & 9 deletions pkg/openid/client/login.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
package client

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
urllib "net/url"
"slices"
stringslib "strings"

"golang.org/x/oauth2"

Expand Down Expand Up @@ -159,15 +163,89 @@ func (c *Client) authCodeURL(ctx context.Context, request *authorizationRequest)

authCodeURL = c.oauth2Config.AuthCodeURL(request.state, opts...)
} else {
// TODO: implement PAR
// generate PAR request
// set all request parameters from authorizationRequest
// set client authentication parameters

// perform POST to PAR endpoint
// extract request_uri from response
// generate auth code URL with request_uri and client_id
// set authCodeURL
params := map[string]string{
"client_id": c.oauth2Config.ClientID,
"code_challenge": oauth2.S256ChallengeFromVerifier(request.codeVerifier),
"code_challenge_method": "S256",
"nonce": request.nonce,
"redirect_uri": request.callbackURL,
"response_mode": "query",
"response_type": "code",
"scope": stringslib.Join(c.oauth2Config.Scopes, " "),
"state": request.state,
}

if resource := c.cfg.Client().ResourceIndicator(); resource != "" {
params["resource"] = resource
}

if len(request.acr) > 0 {
params[LoginParameterMapping[SecurityLevelURLParameter]] = request.acr
}

if len(request.locale) > 0 {
params[LoginParameterMapping[LocaleURLParameter]] = request.locale
}

if len(request.prompt) > 0 {
params[PromptURLParameter] = request.prompt
params[MaxAgeURLParameter] = "0"
}

authParams, err := c.AuthParams()
if err != nil {
return "", fmt.Errorf("generating client authentication parameters: %w", err)
}

urlValues := authParams.URLValues(params)

requestBody := stringslib.NewReader(urlValues.Encode())

r, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.Provider().PushedAuthorizationRequestEndpoint(), requestBody)
if err != nil {
return "", fmt.Errorf("creating request: %w", err)
}
r.Header.Set("Content-Type", "application/x-www-form-urlencoded")

resp, err := c.httpClient.Do(r)
if err != nil {
return "", fmt.Errorf("performing request: %w", err)
}
defer resp.Body.Close()

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("reading server response: %w", err)
}

if resp.StatusCode >= 400 && resp.StatusCode < 500 {
var errorResponse openid.TokenErrorResponse
if err := json.Unmarshal(body, &errorResponse); err != nil {
return "", fmt.Errorf("%w: HTTP %d: unmarshalling error response: %+v", ErrOpenIDClient, resp.StatusCode, err)
}
return "", fmt.Errorf("%w: HTTP %d: %s: %s", ErrOpenIDClient, resp.StatusCode, errorResponse.Error, errorResponse.ErrorDescription)
} else if resp.StatusCode >= 500 {
return "", fmt.Errorf("%w: HTTP %d: %s", ErrOpenIDServer, resp.StatusCode, body)
}

var pushedAuthorizationResponse openid.PushedAuthorizationResponse
if err := json.Unmarshal(body, &pushedAuthorizationResponse); err != nil {
return "", fmt.Errorf("unmarshalling token response: %w", err)
}

v := urllib.Values{
"client_id": {c.oauth2Config.ClientID},
"request_uri": {pushedAuthorizationResponse.RequestUri},
}
var buf bytes.Buffer
buf.WriteString(c.oauth2Config.Endpoint.AuthURL)
if stringslib.Contains(c.oauth2Config.Endpoint.AuthURL, "?") {
buf.WriteByte('&')
} else {
buf.WriteByte('?')
}
buf.WriteString(v.Encode())
authCodeURL = buf.String()
}

return authCodeURL, nil
Expand Down
22 changes: 22 additions & 0 deletions pkg/openid/client/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@ import (
urlpkg "github.com/nais/wonderwall/pkg/url"
)

func TestLogin_PushAuthorizationURL(t *testing.T) {
cfg := mock.Config()
idp := mock.NewIdentityProvider(cfg)
idp.OpenIDConfig.TestProvider.SetPushedAuthorizationRequestEndpoint(idp.ProviderServer.URL + "/par")
defer idp.Close()
req := idp.GetRequest(mock.Ingress + "/oauth2/login")

result, err := idp.RelyingPartyHandler.Client.Login(req)
require.NoError(t, err)

parsed, err := url.Parse(result.AuthCodeURL)
assert.NoError(t, err)

query := parsed.Query()
assert.Contains(t, query, "request_uri")
assert.Contains(t, query, "client_id")

assert.NotEmpty(t, query["request_uri"])
assert.Contains(t, query["request_uri"][0], "urn:ietf:params:oauth:request_uri")
assert.ElementsMatch(t, query["client_id"], []string{idp.OpenIDConfig.Client().ClientID()})
}

func TestLogin_URL(t *testing.T) {
type loginURLTest struct {
name string
Expand Down
6 changes: 6 additions & 0 deletions pkg/openid/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ type TokenResponse struct {
TokenType string `json:"token_type"`
}

// PushedAuthorizationResponse is the struct representing the HTTP response from authorization servers as defined in RFC 9126, section 2.2.
type PushedAuthorizationResponse struct {
RequestUri string `json:"request_uri"`
ExpiresIn int64 `json:"expires_in"`
}

// TokenErrorResponse is the struct representing the HTTP error response returned from authorization servers as defined in RFC 6749, section 5.2.
type TokenErrorResponse struct {
Error string `json:"error"`
Expand Down

0 comments on commit 463f48a

Please sign in to comment.