Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add password verification hook #1328

Merged
merged 14 commits into from
Dec 4, 2023
29 changes: 27 additions & 2 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,6 @@ func (a *API) runHook(ctx context.Context, name string, input, output any) ([]by

func (a *API) invokeHook(ctx context.Context, input, output any) error {
config := a.config

switch input.(type) {
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
Expand All @@ -263,6 +262,32 @@ func (a *API) invokeHook(ctx context.Context, input, output any) error {
return httpError.WithInternalError(&hookOutput.HookError)
}

return nil
case *hooks.PasswordVerificationAttemptInput:
hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.PasswordVerificationAttemptOutput")
J0 marked this conversation as resolved.
Show resolved Hide resolved
}

if _, err := a.runHook(ctx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
Code: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}

return nil

default:
Expand Down Expand Up @@ -335,7 +360,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
return err
}

if output.Decision == hooks.MFAHookRejection {
if output.Decision == hooks.HookRejection {
if err := models.Logout(a.db, user.ID); err != nil {
return err
}
Expand Down
27 changes: 26 additions & 1 deletion internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/golang-jwt/jwt"

"github.com/supabase/gotrue/internal/conf"
"github.com/supabase/gotrue/internal/hooks"
"github.com/supabase/gotrue/internal/metering"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/storage"
Expand Down Expand Up @@ -140,7 +141,31 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri
return internalServerError("Database error querying schema").WithInternalError(err)
}

if user.IsBanned() || !user.Authenticate(ctx, params.Password) {
if user.IsBanned() {
return oauthError("invalid_grant", InvalidLoginMessage)

J0 marked this conversation as resolved.
Show resolved Hide resolved
}
isValidPassword := user.Authenticate(ctx, params.Password)
if config.Hook.PasswordVerificationAttempt.Enabled {

input := hooks.PasswordVerificationAttemptInput{
UserID: user.ID,
Valid: isValidPassword,
}
output := hooks.PasswordVerificationAttemptOutput{}
err := a.invokeHook(ctx, &input, &output)
if err != nil {
return err
}

if output.Decision == hooks.HookRejection {
if output.Message == "" {
J0 marked this conversation as resolved.
Show resolved Hide resolved
output.Message = hooks.DefaultPasswordHookRejectionMessage
}
return forbiddenError(output.Message)
}
}
if !isValidPassword {
return oauthError("invalid_grant", InvalidLoginMessage)
}

Expand Down
61 changes: 61 additions & 0 deletions internal/api/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -588,3 +588,64 @@ func (ts *TokenTestSuite) TestMagicLinkPKCESignIn() {
require.NotEmpty(ts.T(), verifyResp.Token)

}

func (ts *TokenTestSuite) TestPasswordVerificationHook() {
type verificationHookTestcase struct {
desc string
uri string
hookFunctionSQL string
expectedCode int
}
cases := []verificationHookTestcase{
{
desc: "Default success",
uri: "pg-functions://postgres/auth/password_verification_hook",
hookFunctionSQL: `
create or replace function password_verification_hook(input jsonb)
returns json as $$
begin
return json_build_object('decision', 'continue');
end; $$ language plpgsql;`,
expectedCode: http.StatusOK,
}, {
desc: "Reject- Enabled",
uri: "pg-functions://postgres/auth/password_verification_hook_reject",
hookFunctionSQL: `
create or replace function password_verification_hook_reject(input jsonb)
returns json as $$
begin
return json_build_object('decision', 'reject');
end; $$ language plpgsql;`,
expectedCode: http.StatusForbidden,
},
}
for _, c := range cases {
ts.T().Run(c.desc, func(t *testing.T) {
ts.Config.Hook.PasswordVerificationAttempt.Enabled = true
ts.Config.Hook.PasswordVerificationAttempt.URI = c.uri
require.NoError(ts.T(), ts.Config.Hook.PasswordVerificationAttempt.ValidateAndPopulateExtensibilityPoint())

err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec()
require.NoError(t, err)
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": "[email protected]",
"password": "password",
}))

req := httptest.NewRequest(http.MethodPost, "http://localhost/token?grant_type=password", &buffer)
req.Header.Set("Content-Type", "application/json")

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)

assert.Equal(ts.T(), c.expectedCode, w.Code)
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", ts.Config.Hook.PasswordVerificationAttempt.HookName)
require.NoError(ts.T(), ts.API.db.RawQuery(cleanupHookSQL).Exec())
// Reset so it doesn't affect other tests
ts.Config.Hook.PasswordVerificationAttempt.Enabled = false

})
}

}
4 changes: 3 additions & 1 deletion internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,8 @@ type WebhookConfig struct {

// Moving away from the existing HookConfig so we can get a fresh start.
type HookConfiguration struct {
MFAVerificationAttempt ExtensibilityPointConfiguration `json:"mfa_verification_attempt" split_words:"true"`
MFAVerificationAttempt ExtensibilityPointConfiguration `json:"mfa_verification_attempt" split_words:"true"`
PasswordVerificationAttempt ExtensibilityPointConfiguration `json:"password_verification_attempt" split_words:"true"`
}

type ExtensibilityPointConfiguration struct {
Expand All @@ -453,6 +454,7 @@ type ExtensibilityPointConfiguration struct {
func (h *HookConfiguration) Validate() error {
points := []ExtensibilityPointConfiguration{
h.MFAVerificationAttempt,
h.PasswordVerificationAttempt,
}
for _, point := range points {
if err := point.ValidateAndPopulateExtensibilityPoint(); err != nil {
Expand Down
30 changes: 24 additions & 6 deletions internal/hooks/auth_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ const (

// Hook Names
const (
MFAHookRejection = "reject"
MFAHookContinue = "continue"
HookRejection = "reject"
)

type HookOutput interface {
Expand All @@ -33,10 +32,20 @@ type MFAVerificationAttemptInput struct {
}

type MFAVerificationAttemptOutput struct {
Decision string `json:"decision,omitempty"`
Message string `json:"message,omitempty"`
Decision string `json:"decision"`
Message string `json:"message"`
HookError AuthHookError `json:"error"`
}

type PasswordVerificationAttemptInput struct {
UserID uuid.UUID `json:"user_id"`
Valid bool `json:"valid"`
}

HookError AuthHookError `json:"error,omitempty"`
type PasswordVerificationAttemptOutput struct {
Decision string `json:"decision"`
Message string `json:"message"`
HookError AuthHookError `json:"error"`
}

func (mf *MFAVerificationAttemptOutput) IsError() bool {
Expand All @@ -47,6 +56,14 @@ func (mf *MFAVerificationAttemptOutput) Error() string {
return mf.HookError.Message
}

func (p *PasswordVerificationAttemptOutput) IsError() bool {
return p.HookError.Message != ""
}

func (p *PasswordVerificationAttemptOutput) Error() string {
return p.HookError.Message
}

type AuthHookError struct {
HTTPCode int `json:"http_code,omitempty"`
Message string `json:"message,omitempty"`
Expand All @@ -57,5 +74,6 @@ func (a *AuthHookError) Error() string {
}

const (
DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected."
DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected."
DefaultPasswordHookRejectionMessage = "Further password verification attempts will be rejected."
)
Loading