Skip to content

Commit

Permalink
feat: add mfa verification postgres hook (supabase#1314)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Proof of concept hook for MFA Verification. With this hook, developers
can introduce additional conditions around when to accept/reject an MFA
verification (e.g. log a developer out after a certain number of
attempts).

We distinguish this from the existing Webhooks implementation via
introduction of `hooks` package which will contain future Hook related
structs, constants, and utility methods.

For the most part we leverage existing Postgres capabilities - as far as
possible we will return the PostgreSQL error codes for debugging and use
Postgres in-built timeouts to ensure hte hook doesn't overrun.

## Testing

The MFA Verification Hook test suite does not guarantee accurate status
codes - the test setup (to enroll factors and create a challenge after
signup) requires some setup. It is reliant on `signUpAndVerify` which
gets the dev to AAL2 and takes time to refactor.

As such, most of the cases were manually tested in addition to the
current loose check of checking for the absence of an access token.
Further edits will be made in GMT +8 morning to properly check for the
http status codes in the tests.

Also, since `supabase_auth_admin` cannot create functions on the
`public` schema we create the functions on the `auth` schema for
testing. We typically discourage this on the Supabase platform but in
theory there should be no issue when dealing with GoTrue (the OSS
project). Will spend a short amount of time looking into alternatives
tomorrow.


## Additional Notes

Response schema checks are left out of this PR as they don't seem to
serve as much benefit for this particular extensibility point and will
probably bloat the PR a little with the introduction of a new library

---------

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Stojan Dimitrovski <[email protected]>
  • Loading branch information
3 people authored and LashaJini committed Nov 13, 2024
1 parent 4dc306c commit d0f444c
Show file tree
Hide file tree
Showing 6 changed files with 369 additions and 14 deletions.
67 changes: 65 additions & 2 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ package api

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"net/http"

"net/url"

"github.com/aaronarduino/goqrsvg"
svg "github.com/ajstarks/svgo"
"github.com/boombuler/barcode/qr"
"github.com/gofrs/uuid"
"github.com/pquerna/otp/totp"
"github.com/supabase/gotrue/internal/hooks"
"github.com/supabase/gotrue/internal/metering"
"github.com/supabase/gotrue/internal/models"
"github.com/supabase/gotrue/internal/storage"
Expand Down Expand Up @@ -196,6 +199,42 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error {
})
}

func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput) error {
var response []byte
switch input.(type) {
case hooks.MFAVerificationAttemptInput:
payload, err := json.Marshal(&input)
if err != nil {
panic(err)
}

if err := a.db.Transaction(func(tx *storage.Connection) error {
// We rely on Postgres timeouts to ensure the function doesn't overrun
timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout))
if terr := timeoutQuery.Exec(); terr != nil {
return terr
}
query := tx.RawQuery(fmt.Sprintf("SELECT %s(?)", a.config.Hook.MFAVerificationAttempt.HookName), payload)
terr := query.First(&response)
if terr != nil {
return terr
}
return nil
}); err != nil {
return err
}
if err = json.Unmarshal(response, &output); err != nil {
return err
}
if output.IsError() {
return &output.(*hooks.MFAVerificationAttemptOutput).HookError
}

return nil
default:
panic("invalid extensibility point")
}
}
func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
var err error
ctx := r.Context()
Expand Down Expand Up @@ -245,7 +284,31 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error {
return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID)
}

if valid := totp.Validate(params.Code, factor.Secret); !valid {
valid := totp.Validate(params.Code, factor.Secret)

if config.Hook.MFAVerificationAttempt.Enabled {
input := hooks.MFAVerificationAttemptInput{
UserID: user.ID,
FactorID: factor.ID,
Valid: valid,
}
output := &hooks.MFAVerificationAttemptOutput{}
err := a.invokeHook(ctx, input, output)
if err != nil {
return errors.New(err.Error())
}

if output.Decision == hooks.MFAHookRejection {
if err := models.Logout(a.db, user.ID); err != nil {
return err
}
if output.Message == "" {
output.Message = hooks.DefaultMFAHookRejectionMessage
}
return forbiddenError(output.Message)
}
}
if !valid {
return badRequestError("Invalid TOTP code entered")
}

Expand Down
153 changes: 143 additions & 10 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {

// Integration Tests
func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword)
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

Expand Down Expand Up @@ -399,7 +399,7 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {

// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session
func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword)
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

Expand Down Expand Up @@ -447,10 +447,9 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes
return data
}

func performTestSignupAndVerify(ts *MFATestSuite, email, password string) *httptest.ResponseRecorder {

func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requireStatusOK bool) *httptest.ResponseRecorder {
signUpResp := signUp(ts, email, password)
resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token)
resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token, requireStatusOK)

return resp

Expand All @@ -468,9 +467,9 @@ func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), expectedCode, w.Code)
return w

}
func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int) *httptest.ResponseRecorder {

func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int, requireStatusOK bool) *httptest.ResponseRecorder {
var verifyBuffer bytes.Buffer
y := httptest.NewRecorder()

Expand All @@ -495,7 +494,9 @@ func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token
req.Header.Set("Content-Type", "application/json")

ts.API.handler.ServeHTTP(y, req)
require.Equal(ts.T(), http.StatusOK, y.Code)
if requireStatusOK {
require.Equal(ts.T(), http.StatusOK, y.Code)
}
return y
}

Expand All @@ -511,7 +512,7 @@ func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *h

}

func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string) *httptest.ResponseRecorder {
func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string, requireStatusOK bool) *httptest.ResponseRecorder {
w := performEnrollFlow(ts, token, "", models.TOTP, ts.TestDomain, http.StatusOK)
enrollResp := EnrollFactorResponse{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp))
Expand All @@ -525,7 +526,139 @@ func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string) *
challengeID := challengeResp.ID

// Verify
y := performVerifyFlow(ts, challengeID, factorID, token, http.StatusOK)
y := performVerifyFlow(ts, challengeID, factorID, token, http.StatusOK, requireStatusOK)

return y
}

func (ts *MFATestSuite) TestVerificationHooks() {
type verificationHookTestCase struct {
desc string
enabled bool
uri string
hookFunctionSQL string
emailSuffix string
expectToken bool
expectedCode int
cleanupHookFunction string
}
cases := []verificationHookTestCase{
{
desc: "Default Success",
enabled: true,
uri: "pg-functions://postgres/auth/verification_hook",
hookFunctionSQL: `
create or replace function verification_hook(input jsonb)
returns json as $$
begin
return json_build_object('decision', 'continue');
end; $$ language plpgsql;`,
emailSuffix: "success",
expectToken: true,
expectedCode: http.StatusOK,
cleanupHookFunction: "verification_hook(input jsonb)",
},
{
desc: "Error",
enabled: true,
uri: "pg-functions://postgres/auth/test_verification_hook_error",
hookFunctionSQL: `
create or replace function test_verification_hook_error(input jsonb)
returns json as $$
begin
RAISE EXCEPTION 'Intentional Error for Testing';
end; $$ language plpgsql;`,
emailSuffix: "error",
expectToken: false,
expectedCode: http.StatusInternalServerError,
cleanupHookFunction: "test_verification_hook_error(input jsonb)",
},
{
desc: "Reject - Enabled",
enabled: true,
uri: "pg-functions://postgres/auth/verification_hook_reject",
hookFunctionSQL: `
create or replace function verification_hook_reject(input jsonb)
returns json as $$
begin
return json_build_object(
'decision', 'reject',
'message', 'authentication attempt rejected'
);
end; $$ language plpgsql;`,
emailSuffix: "reject_enabled",
expectToken: false,
expectedCode: http.StatusForbidden,
cleanupHookFunction: "verification_hook_reject(input jsonb)",
},
{
desc: "Reject - Disabled",
enabled: false,
uri: "pg-functions://postgres/auth/verification_hook_reject",
hookFunctionSQL: `
create or replace function verification_hook_reject(input jsonb)
returns json as $$
begin
return json_build_object(
'decision', 'reject',
'message', 'authentication attempt rejected'
);
end; $$ language plpgsql;`,
emailSuffix: "reject_disabled",
expectToken: true,
expectedCode: http.StatusOK,
cleanupHookFunction: "verification_hook_reject(input jsonb)",
},
{
desc: "Timeout",
enabled: true,
uri: "pg-functions://postgres/auth/test_verification_hook_timeout",
hookFunctionSQL: `
create or replace function test_verification_hook_timeout(input jsonb)
returns json as $$
begin
PERFORM pg_sleep(3);
return json_build_object(
'decision', 'continue'
);
end; $$ language plpgsql;`,
emailSuffix: "timeout",
expectToken: false,
expectedCode: http.StatusInternalServerError,
cleanupHookFunction: "test_verification_hook_timeout(input jsonb)",
},
}

for _, c := range cases {
ts.T().Run(c.desc, func(t *testing.T) {
ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled
ts.Config.Hook.MFAVerificationAttempt.URI = c.uri
require.NoError(ts.T(), ts.Config.Hook.MFAVerificationAttempt.ValidateAndPopulateExtensibilityPoint())

err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec()
require.NoError(t, err)

email := fmt.Sprintf("testemail_%[email protected]", c.emailSuffix)
password := "testpassword"
resp := performTestSignupAndVerify(ts, email, password, c.expectToken)
require.Equal(ts.T(), c.expectedCode, resp.Code)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

if c.expectToken {
require.NotEqual(t, "", accessTokenResp.Token)
} else {
require.Equal(t, "", accessTokenResp.Token)
}

cleanupHook(ts, c.cleanupHookFunction)
})
}
}

func cleanupHook(ts *MFATestSuite, hookName string) {
cleanupHookSQL := fmt.Sprintf("drop function if exists %s", hookName)
err := ts.API.db.RawQuery(cleanupHookSQL).Exec()
require.NoError(ts.T(), err)

}
8 changes: 8 additions & 0 deletions internal/api/settings.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ type Settings struct {
SmsProvider string `json:"sms_provider"`
MFAEnabled bool `json:"mfa_enabled"`
SAMLEnabled bool `json:"saml_enabled"`
HookConfiguration HookSettings `json:"hook"`
}

type HookSettings struct {
MFAVerification bool `json:"mfa"`
}

func (a *API) Settings(w http.ResponseWriter, r *http.Request) error {
Expand Down Expand Up @@ -67,6 +72,9 @@ func (a *API) Settings(w http.ResponseWriter, r *http.Request) error {
Phone: config.External.Phone.Enabled,
Zoom: config.External.Zoom.Enabled,
},
HookConfiguration: HookSettings{
MFAVerification: config.Hook.MFAVerificationAttempt.Enabled,
},

DisableSignup: config.DisableSignup,
MailerAutoconfirm: config.Mailer.Autoconfirm,
Expand Down
Loading

0 comments on commit d0f444c

Please sign in to comment.