Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add mfa verification postgres hook #1314

Merged
merged 46 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
67ecb95
feat: initial commit
Nov 20, 2023
6078578
fix: add scaffolding
Nov 20, 2023
51e1987
feat: generate payload
Nov 20, 2023
1c15e18
feat: set up calling mechanism
Nov 20, 2023
c62ff8a
feat: add more surrounding logic
Nov 20, 2023
fa1122a
fix: reinstate relevant constants
Nov 21, 2023
da86cde
feat: add minor validation and cleanup
Nov 21, 2023
185b1d7
feat: add hook configuration to settings
Nov 21, 2023
6676552
feat: update naming conventions
Nov 22, 2023
9493543
feat:refactor to do w/o abstraction
Nov 24, 2023
4b08570
fix: remove now redundant methods
Nov 24, 2023
78e184c
fix: add some logging
Nov 24, 2023
f3ae14b
fix: remove unused code
Nov 24, 2023
8708df0
Merge branch 'master' into j0/mfa_verification_counter_hook
J0 Nov 24, 2023
dc9d6da
fix: reset unused files
Nov 24, 2023
1cd6eff
feat: add stubs, revert unneeded changes
Nov 24, 2023
3a93440
Merge branch 'j0/mfa_verification_counter_hook' of github.com:supabas…
Nov 24, 2023
8a3ec9c
refactor: rename hook ep
Nov 27, 2023
e21a806
refactor: rename HookErrorResponse->AuthHookErrorResponse
Nov 27, 2023
b7a8c23
test: add initial tests
Nov 27, 2023
5a4ce40
feat: properly fetch response
Nov 27, 2023
aa9d920
feat: update tests
Nov 27, 2023
176ad48
feat: add a few tests
Nov 27, 2023
32b1a8d
fix: update test structure
Nov 27, 2023
b9d6ba7
feat: add local timeout and more tests
Nov 27, 2023
b50b7ed
refactor: cut back on redundant code
Nov 27, 2023
651151e
fix: patch tests
Nov 27, 2023
f9fa25e
fix: partial conversion to use Auth Hook structs
Nov 27, 2023
47504b0
feat: add initial Error Message
Nov 27, 2023
fa1cbde
refactor: rename some vars
Nov 27, 2023
c4470d7
chore: small comments
Nov 27, 2023
2842b23
chore: add a default message
Nov 28, 2023
2260f96
refactor: remove excess code
Nov 28, 2023
ddae946
chore: convert fetchhookname into configuration load
Nov 28, 2023
25f95f4
fix: light refactor of tests
Nov 28, 2023
7693417
Merge branch 'master' of github.com:supabase/gotrue into j0/mfa_verif…
Nov 29, 2023
af2c255
fix: add status code check
Nov 29, 2023
7b874b8
refactor: use errors
Nov 29, 2023
447de5f
refactor: shove config back to configuration
Nov 29, 2023
18ccfc6
Update internal/conf/configuration.go
J0 Nov 30, 2023
be4fd32
Update internal/hooks/auth_hooks.go
J0 Nov 30, 2023
a15631b
refactor: use error interface instead
Nov 30, 2023
8641c92
Merge branch 'j0/mfa_verification_counter_hook' of github.com:supabas…
Nov 30, 2023
2d512cb
fix: make schema constant
Nov 30, 2023
3c99bde
test: reinstate test suite
Nov 30, 2023
3699c01
fix: remove unused structs
Nov 30, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ package api

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"

"net/url"

"github.com/aaronarduino/goqrsvg"
svg "github.com/ajstarks/svgo"
"github.com/boombuler/barcode/qr"
"github.com/gofrs/uuid"
"github.com/pquerna/otp/totp"
"github.com/supabase/gotrue/internal/hooks"
"github.com/supabase/gotrue/internal/metering"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/storage"
Expand Down Expand Up @@ -196,6 +199,42 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
})
}

func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput) error {
var response []byte
switch input.(type) {
case hooks.MFAVerificationAttemptInput:
payload, err := json.Marshal(&input)
if err != nil {
panic(err)
}

if err := a.db.Transaction(func(tx *storage.Connection) error {
// We rely on Postgres timeouts to ensure the function doesn't overrun
timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout))
if terr := timeoutQuery.Exec(); terr != nil {
return terr
}
query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", a.config.Hook.MFAVerificationAttempt.HookName), payload)
J0 marked this conversation as resolved.
Show resolved Hide resolved
terr := query.First(&response)
if terr != nil {
return terr
}
return nil
}); err != nil {
return err
}
if err = json.Unmarshal(response, &output); err != nil {
return err
}
if output.IsError() {
return &output.(*hooks.MFAVerificationAttemptOutput).HookError
}

return nil
default:
panic("invalid extensibility point")
}
}
func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
var err error
ctx := r.Context()
Expand Down Expand Up @@ -245,7 +284,31 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID)
}

if valid := totp.Validate(params.Code, factor.Secret); !valid {
valid := totp.Validate(params.Code, factor.Secret)

if config.Hook.MFAVerificationAttempt.Enabled {
input := hooks.MFAVerificationAttemptInput{
UserID: user.ID,
FactorID: factor.ID,
Valid: valid,
}
output := &hooks.MFAVerificationAttemptOutput{}
err := a.invokeHook(ctx, input, output)
if err != nil {
return errors.New(err.Error())
}
J0 marked this conversation as resolved.
Show resolved Hide resolved

if output.Decision == hooks.MFAHookRejection {
if err := models.Logout(a.db, user.ID); err != nil {
return err
}
if output.Message == "" {
output.Message = hooks.DefaultMFAHookRejectionMessage
}
return forbiddenError(output.Message)
J0 marked this conversation as resolved.
Show resolved Hide resolved
}
}
if !valid {
return badRequestError("Invalid TOTP code entered")
}

Expand Down
153 changes: 143 additions & 10 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {

// Integration Tests
func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword)
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

Expand Down Expand Up @@ -399,7 +399,7 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {

// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session
func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword)
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

Expand Down Expand Up @@ -447,10 +447,9 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes
return data
}

func performTestSignupAndVerify(ts *MFATestSuite, email, password string) *httptest.ResponseRecorder {

func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requireStatusOK bool) *httptest.ResponseRecorder {
signUpResp := signUp(ts, email, password)
resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token)
resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token, requireStatusOK)

return resp

Expand All @@ -468,9 +467,9 @@ func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), expectedCode, w.Code)
return w

}
func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int) *httptest.ResponseRecorder {

func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int, requireStatusOK bool) *httptest.ResponseRecorder {
var verifyBuffer bytes.Buffer
y := httptest.NewRecorder()

Expand All @@ -495,7 +494,9 @@ func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(y, req)
require.Equal(ts.T(), http.StatusOK, y.Code)
if requireStatusOK {
require.Equal(ts.T(), http.StatusOK, y.Code)
}
return y
}

Expand All @@ -511,7 +512,7 @@ func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *h

}

func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string) *httptest.ResponseRecorder {
func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string, requireStatusOK bool) *httptest.ResponseRecorder {
w := performEnrollFlow(ts, token, "", models.TOTP, ts.TestDomain, http.StatusOK)
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
Expand All @@ -525,7 +526,139 @@ func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string) *
challengeID := challengeResp.ID

// Verify
y := performVerifyFlow(ts, challengeID, factorID, token, http.StatusOK)
y := performVerifyFlow(ts, challengeID, factorID, token, http.StatusOK, requireStatusOK)

return y
}

func (ts *MFATestSuite) TestVerificationHooks() {
type verificationHookTestCase struct {
desc string
enabled bool
uri string
hookFunctionSQL string
emailSuffix string
expectToken bool
expectedCode int
cleanupHookFunction string
}
cases := []verificationHookTestCase{
{
desc: "Default Success",
enabled: true,
uri: "pg-functions://postgres/auth/verification_hook",
hookFunctionSQL: `
create or replace function verification_hook(input jsonb)
returns json as $$
begin
return json_build_object('decision', 'continue');
end; $$ language plpgsql;`,
emailSuffix: "success",
expectToken: true,
expectedCode: http.StatusOK,
cleanupHookFunction: "verification_hook(input jsonb)",
},
{
desc: "Error",
enabled: true,
uri: "pg-functions://postgres/auth/test_verification_hook_error",
hookFunctionSQL: `
create or replace function test_verification_hook_error(input jsonb)
returns json as $$
begin
RAISE EXCEPTION 'Intentional Error for Testing';
end; $$ language plpgsql;`,
emailSuffix: "error",
expectToken: false,
expectedCode: http.StatusInternalServerError,
cleanupHookFunction: "test_verification_hook_error(input jsonb)",
},
{
desc: "Reject - Enabled",
enabled: true,
uri: "pg-functions://postgres/auth/verification_hook_reject",
hookFunctionSQL: `
create or replace function verification_hook_reject(input jsonb)
returns json as $$
begin
return json_build_object(
'decision', 'reject',
'message', 'authentication attempt rejected'
);
end; $$ language plpgsql;`,
emailSuffix: "reject_enabled",
expectToken: false,
expectedCode: http.StatusForbidden,
cleanupHookFunction: "verification_hook_reject(input jsonb)",
},
{
desc: "Reject - Disabled",
enabled: false,
uri: "pg-functions://postgres/auth/verification_hook_reject",
hookFunctionSQL: `
create or replace function verification_hook_reject(input jsonb)
returns json as $$
begin
return json_build_object(
'decision', 'reject',
'message', 'authentication attempt rejected'
);
end; $$ language plpgsql;`,
emailSuffix: "reject_disabled",
expectToken: true,
expectedCode: http.StatusOK,
cleanupHookFunction: "verification_hook_reject(input jsonb)",
},
{
desc: "Timeout",
enabled: true,
uri: "pg-functions://postgres/auth/test_verification_hook_timeout",
hookFunctionSQL: `
create or replace function test_verification_hook_timeout(input jsonb)
returns json as $$
begin
PERFORM pg_sleep(3);
return json_build_object(
'decision', 'continue'
);
end; $$ language plpgsql;`,
emailSuffix: "timeout",
expectToken: false,
expectedCode: http.StatusInternalServerError,
cleanupHookFunction: "test_verification_hook_timeout(input jsonb)",
},
}

for _, c := range cases {
ts.T().Run(c.desc, func(t *testing.T) {
ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled
ts.Config.Hook.MFAVerificationAttempt.URI = c.uri
require.NoError(ts.T(), ts.Config.Hook.MFAVerificationAttempt.ValidateAndPopulateExtensibilityPoint())

err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec()
require.NoError(t, err)

email := fmt.Sprintf("testemail_%[email protected]", c.emailSuffix)
password := "testpassword"
resp := performTestSignupAndVerify(ts, email, password, c.expectToken)
require.Equal(ts.T(), c.expectedCode, resp.Code)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

if c.expectToken {
require.NotEqual(t, "", accessTokenResp.Token)
} else {
require.Equal(t, "", accessTokenResp.Token)
}

cleanupHook(ts, c.cleanupHookFunction)
})
}
}

func cleanupHook(ts *MFATestSuite, hookName string) {
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", hookName)
err := ts.API.db.RawQuery(cleanupHookSQL).Exec()
require.NoError(ts.T(), err)

}
8 changes: 8 additions & 0 deletions internal/api/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ type Settings struct {
SmsProvider string `json:"sms_provider"`
MFAEnabled bool `json:"mfa_enabled"`
SAMLEnabled bool `json:"saml_enabled"`
HookConfiguration HookSettings `json:"hook"`
}

type HookSettings struct {
MFAVerification bool `json:"mfa"`
}

func (a *API) Settings(w http.ResponseWriter, r *http.Request) error {
Expand Down Expand Up @@ -67,6 +72,9 @@ func (a *API) Settings(w http.ResponseWriter, r *http.Request) error {
Phone: config.External.Phone.Enabled,
Zoom: config.External.Zoom.Enabled,
},
HookConfiguration: HookSettings{
MFAVerification: config.Hook.MFAVerificationAttempt.Enabled,
},

DisableSignup: config.DisableSignup,
MailerAutoconfirm: config.Mailer.Autoconfirm,
Expand Down
Loading
Loading