From 67ecb95b9647354ccc662e2a2c20a209b378dd8a Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 20 Nov 2023 21:18:45 +0800 Subject: [PATCH 01/42] feat: initial commit --- internal/api/settings.go | 5 +++++ internal/conf/configuration.go | 22 +++++++++++++++++++++- internal/conf/configuration_test.go | 4 +++- 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/internal/api/settings.go b/internal/api/settings.go index c2ab0df39a..2d4826d158 100644 --- a/internal/api/settings.go +++ b/internal/api/settings.go @@ -36,6 +36,9 @@ type Settings struct { SmsProvider string `json:"sms_provider"` MFAEnabled bool `json:"mfa_enabled"` SAMLEnabled bool `json:"saml_enabled"` + + // TODO: Remove this later. For debugging + MFAHook string `json:"mfa_hook"` } func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { @@ -67,6 +70,8 @@ func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { Phone: config.External.Phone.Enabled, Zoom: config.External.Zoom.Enabled, }, + // TODO: Remove this too. For debugging + MFAHook: config.Hook.MFA.URI, DisableSignup: config.DisableSignup, MailerAutoconfirm: config.Mailer.Autoconfirm, diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index cc5b7505df..9ec162cb6b 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -158,6 +158,7 @@ type GlobalConfiguration struct { Sms SmsProviderConfiguration `json:"sms"` DisableSignup bool `json:"disable_signup" split_words:"true"` Webhook WebhookConfig `json:"webhook" split_words:"true"` + Hook HookConfiguration `json:"hook" split_words:"true"` Security SecurityConfiguration `json:"security"` Sessions SessionsConfiguration `json:"sessions"` MFA MFAConfiguration `json:"MFA"` @@ -379,6 +380,26 @@ type WebhookConfig struct { Events []string `json:"events"` } +// Moving away from the existing HookConfig so we can get a fresh start. +type HookConfiguration struct { + // TODO (Joel): Fix the naming later + MFA ExtensibilityPointConfiguration `json:"mfa"` +} + +type ExtensibilityPointConfiguration struct { + URI string `json:"uri"` +} + +func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { + if e.URI != "" { + _, err := url.Parse(e.URI) + if err != nil { + return errors.New("hook entry should be a valid URI") + } + } + return nil +} + func (w *WebhookConfig) HasEvent(event string) bool { for _, name := range w.Events { if event == name { @@ -424,7 +445,6 @@ func LoadGlobal(filename string) (*GlobalConfiguration, error) { } config.Sms.SMSTemplate = template } - return config, nil } diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index 6931753749..5c68e39f1a 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -15,14 +15,16 @@ func TestMain(m *testing.M) { func TestGlobal(t *testing.T) { os.Setenv("GOTRUE_SITE_URL", "http://localhost:8080") - os.Setenv("GOTRUE_DB_DRIVER", "mysql") + os.Setenv("GOTRUE_DB_DRIVER", "postgres") os.Setenv("GOTRUE_DB_DATABASE_URL", "fake") os.Setenv("GOTRUE_OPERATOR_TOKEN", "token") os.Setenv("GOTRUE_API_REQUEST_ID_HEADER", "X-Request-ID") os.Setenv("GOTRUE_JWT_SECRET", "secret") os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + os.Setenv("GOTRUE_HOOK_MFA_URI", "postgres://postgres/auth/count_failed_attempts") gc, err := LoadGlobal("") require.NoError(t, err) require.NotNil(t, gc) assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader) + assert.Equal(t, "postgres://postgres/auth/count_failed_attempts", gc.Hook.MFA.URI) } From 6078578286302d560a797f508b0f5c7c060a2894 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 20 Nov 2023 23:19:53 +0800 Subject: [PATCH 02/42] fix: add scaffolding --- internal/api/auth_hooks.go | 34 ++++++++++++++++++++++++++++++++++ internal/api/hooks.go | 3 --- internal/api/mfa.go | 16 +++++++++++++++- internal/conf/configuration.go | 3 +++ 4 files changed, 52 insertions(+), 4 deletions(-) create mode 100644 internal/api/auth_hooks.go diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go new file mode 100644 index 0000000000..4f8dea798b --- /dev/null +++ b/internal/api/auth_hooks.go @@ -0,0 +1,34 @@ +package api + +import ( + "github.com/supabase/gotrue/internal/conf" + "time" +) + +type AuthHook struct { + *conf.HookConfiguration + payload []byte +} + +const ( + defaultTimeout = time.Second * 2 + defaultHookRetries = 3 +) + +type HookType string + +const ( + PostgresHook HookType = "postgres" +) + +func (a *AuthHook) trigger() error { + // Parse URI object + + // switch between Postgres Hook and HTTP Hook, pass in URI + return nil +} + +func (a *AuthHook) triggerPostgresHook() error { + + return nil +} diff --git a/internal/api/hooks.go b/internal/api/hooks.go index c5378af54c..e85966e471 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -28,7 +28,6 @@ type HookEvent string const ( headerHookSignature = "x-webhook-signature" - defaultHookRetries = 3 gotrueIssuer = "gotrue" ValidateEvent = "validate" SignupEvent = "signup" @@ -36,8 +35,6 @@ const ( LoginEvent = "login" ) -var defaultTimeout = time.Second * 5 - type webhookClaims struct { jwt.StandardClaims SHA256 string `json:"sha256"` diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 57130438c0..89892e6b20 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -245,7 +245,21 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) } - if valid := totp.Validate(params.Code, factor.Secret); !valid { + valid := totp.Validate(params.Code, factor.Secret) + if config.Hook.MFA.IsEnabled() { + // payload := CreateMFAVerificationHookPayload(user_id, factor_id, valid) + // h := Hook { + // extensibilityPoint: extensibilityPoint + // event: "auth.mfa_verification", + // payload: payload + // } + // + // if resp, err := h.Trigger(); err != nil + // return errors.New("error executing hook").withError(err) + // } + return badRequestError("hook is enabled") + } + if !valid { return badRequestError("Invalid TOTP code entered") } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 9ec162cb6b..fa676eea36 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -399,6 +399,9 @@ func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { } return nil } +func (e *ExtensibilityPointConfiguration) IsEnabled() bool { + return e.URI != "" +} func (w *WebhookConfig) HasEvent(event string) bool { for _, name := range w.Events { From 51e1987d3019c89b006ae33bf9d430601250a5b2 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 20 Nov 2023 23:36:58 +0800 Subject: [PATCH 03/42] feat: generate payload --- internal/api/auth_hooks.go | 47 ++++++++++++++++++++++++++++++-------- internal/api/mfa.go | 28 +++++++++++++++-------- 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go index 4f8dea798b..4a6a619672 100644 --- a/internal/api/auth_hooks.go +++ b/internal/api/auth_hooks.go @@ -1,31 +1,60 @@ package api import ( - "github.com/supabase/gotrue/internal/conf" + "encoding/json" "time" + + "github.com/gofrs/uuid" + "github.com/supabase/gotrue/internal/conf" +) + +type HookType string + +const ( + PostgresHook HookType = "postgres" ) type AuthHook struct { *conf.HookConfiguration - payload []byte + payload []byte + hookType HookType + event string } +// Hook Events const ( - defaultTimeout = time.Second * 2 - defaultHookRetries = 3 + MFAVerificationEvent = "auth.mfa_verfication" ) -type HookType string - const ( - PostgresHook HookType = "postgres" + defaultTimeout = time.Second * 2 + defaultHookRetries = 3 ) -func (a *AuthHook) trigger() error { +// Functions for encoding and decoding payload +func CreateMFAVerificationHookInput(user_id uuid.UUID, factor_id uuid.UUID, valid bool) ([]byte, error) { + // TODO: find a better way of encdoing so we can support HTTP hooks + payload := struct { + UserID uuid.UUID `json:"user_id"` + FactorID uuid.UUID `json:"factor_id"` + Valid bool `json:"valid"` + }{ + UserID: user_id, + FactorID: factor_id, + Valid: valid, + } + data, err := json.Marshal(&payload) + if err != nil { + return nil, err + } + return data, nil +} + +func (a *AuthHook) Trigger() ([]byte, error) { // Parse URI object // switch between Postgres Hook and HTTP Hook, pass in URI - return nil + return nil, nil } func (a *AuthHook) triggerPostgresHook() error { diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 89892e6b20..0215576151 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -247,16 +247,24 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { valid := totp.Validate(params.Code, factor.Secret) if config.Hook.MFA.IsEnabled() { - // payload := CreateMFAVerificationHookPayload(user_id, factor_id, valid) - // h := Hook { - // extensibilityPoint: extensibilityPoint - // event: "auth.mfa_verification", - // payload: payload - // } - // - // if resp, err := h.Trigger(); err != nil - // return errors.New("error executing hook").withError(err) - // } + // To allow for future cases where we don't know that the payload is going to be passed in + + payload, err := CreateMFAVerificationHookInput(user.ID, factor.ID, valid) + if err != nil { + return err + } + + h := AuthHook{ + event: MFAVerificationEvent, + payload: payload, + hookType: PostgresHook, + } + + // TODO: revert to resp and use the resp. In this MFA Verification case the resp is not used + _, err = h.Trigger() + if err != nil { + return err + } return badRequestError("hook is enabled") } if !valid { From 1c15e18ec9e3b00bd369d4aadbc7f0db755153a2 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 21 Nov 2023 01:24:25 +0800 Subject: [PATCH 04/42] feat: set up calling mechanism --- internal/api/auth_hooks.go | 70 +++++++++++++++++++++++++++++++++----- internal/api/hooks.go | 1 + internal/api/mfa.go | 2 ++ 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go index 4a6a619672..deaadc0d27 100644 --- a/internal/api/auth_hooks.go +++ b/internal/api/auth_hooks.go @@ -2,23 +2,30 @@ package api import ( "encoding/json" + "errors" + "net/url" "time" + "fmt" "github.com/gofrs/uuid" "github.com/supabase/gotrue/internal/conf" + "github.com/supabase/gotrue/internal/storage" + "strings" ) type HookType string const ( PostgresHook HookType = "postgres" + HTTPHook HookType = "http" ) type AuthHook struct { - *conf.HookConfiguration + *conf.ExtensibilityPointConfiguration payload []byte hookType HookType event string + db *storage.Connection } // Hook Events @@ -27,13 +34,11 @@ const ( ) const ( - defaultTimeout = time.Second * 2 - defaultHookRetries = 3 + defaultTimeout = time.Second * 2 ) // Functions for encoding and decoding payload func CreateMFAVerificationHookInput(user_id uuid.UUID, factor_id uuid.UUID, valid bool) ([]byte, error) { - // TODO: find a better way of encdoing so we can support HTTP hooks payload := struct { UserID uuid.UUID `json:"user_id"` FactorID uuid.UUID `json:"factor_id"` @@ -50,14 +55,63 @@ func CreateMFAVerificationHookInput(user_id uuid.UUID, factor_id uuid.UUID, vali return data, nil } -func (a *AuthHook) Trigger() ([]byte, error) { +func (ah *AuthHook) Trigger() ([]byte, error) { // Parse URI object + url, err := url.Parse(ah.ExtensibilityPointConfiguration.URI) + if err != nil { + return nil, err + } + // trigger appropriate type of hook + switch url.Scheme { + case string(PostgresHook): + return ah.triggerPostgresHook() + case string(HTTPHook): + return ah.triggerHTTPHook() + default: + return nil, errors.New("unsupported hook type") + } - // switch between Postgres Hook and HTTP Hook, pass in URI return nil, nil } -func (a *AuthHook) triggerPostgresHook() error { +func (ah *AuthHook) fetchHookName() (string, error) { + u, err := url.Parse(ah.ExtensibilityPointConfiguration.URI) + if err != nil { + return "", err + } + pathParts := strings.Split(u.Path, "/") + if len(pathParts) < 3 { + return "", fmt.Errorf("URI path does not contain enough parts") + } + schema := pathParts[1] + table := pathParts[2] + // TODO: maybe enforce checks on this name? + + return schema + "." + table, nil +} + +func (ah *AuthHook) triggerPostgresHook() ([]byte, error) { + // Determine Result payload and request payload + var result []byte + hookName, err := ah.fetchHookName() + if err != nil { + return nil, err + } + if err := ah.db.Transaction(func(tx *storage.Connection) error { + resp := tx.RawQuery(fmt.Sprintf("SELECT %s('%s')", hookName, ah.payload)) + terr := resp.First(result) + if terr != nil { + return terr + } + return nil + }); err != nil { + return nil, err + } + + return result, nil + +} - return nil +func (a *AuthHook) triggerHTTPHook() ([]byte, error) { + return nil, errors.New("not implemented error") } diff --git a/internal/api/hooks.go b/internal/api/hooks.go index e85966e471..80af123fba 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -33,6 +33,7 @@ const ( SignupEvent = "signup" EmailChangeEvent = "email_change" LoginEvent = "login" + defaultHookRetries = 3 ) type webhookClaims struct { diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 0215576151..1c8959ace7 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -258,6 +258,8 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { event: MFAVerificationEvent, payload: payload, hookType: PostgresHook, + // TODO: find a better way to relay this + db: a.db, } // TODO: revert to resp and use the resp. In this MFA Verification case the resp is not used From c62ff8a30eda8ec5152553b323b56edfe030c520 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 21 Nov 2023 01:58:59 +0800 Subject: [PATCH 05/42] feat: add more surrounding logic --- internal/api/auth_hooks.go | 32 ++++++++++++++++++++++++++++++++ internal/api/mfa.go | 21 ++++++++++++++++++--- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go index deaadc0d27..5907bbbb7c 100644 --- a/internal/api/auth_hooks.go +++ b/internal/api/auth_hooks.go @@ -37,6 +37,38 @@ const ( defaultTimeout = time.Second * 2 ) +type HookErrorResponse struct { + ErrorMessage string `json:"error_message"` + ErrorCode string `json:"error_code"` + RetryAfter bool `json:"retry_after"` +} + +type MFAVerificationHookResponse struct { + Decision string `json:"decision"` +} + +func parseErrorResponse(response []byte) (*HookErrorResponse, error) { + var errResp HookErrorResponse + err := json.Unmarshal(response, &errResp) + if err != nil { + return nil, err + } + if errResp.ErrorMessage != "" { + return &errResp, nil + } + return nil, err +} + +func parseMFAVerificationResponse(response []byte) (*MFAVerificationHookResponse, error) { + var MFAVerificationResponse MFAVerificationHookResponse + err := json.Unmarshal(response, &MFAVerificationResponse) + if err != nil { + return nil, err + } + + return &MFAVerificationResponse, err +} + // Functions for encoding and decoding payload func CreateMFAVerificationHookInput(user_id uuid.UUID, factor_id uuid.UUID, valid bool) ([]byte, error) { payload := struct { diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 1c8959ace7..bdb9cdeef4 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" + "errors" "net/url" "github.com/aaronarduino/goqrsvg" @@ -262,12 +263,26 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { db: a.db, } - // TODO: revert to resp and use the resp. In this MFA Verification case the resp is not used - _, err = h.Trigger() + resp, err := h.Trigger() if err != nil { return err } - return badRequestError("hook is enabled") + parsedErrorResponse, err := parseErrorResponse(resp) + if err != nil { + return err + } + if parsedErrorResponse != nil { + return errors.New(parsedErrorResponse.ErrorMessage) + } + // TODO: Decide what to do here + response, err := parseMFAVerificationResponse(resp) + if err != nil { + return err + } + // TODO: don't hard code this and also change to Enum + handle success case + if response.Decision == "reject" { + return errors.New("has made 5 unsuccssful verification attempts") + } } if !valid { return badRequestError("Invalid TOTP code entered") From fa1122a0068b11a6331d87f6cf4f4d9535480947 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 21 Nov 2023 11:45:50 +0800 Subject: [PATCH 06/42] fix: reinstate relevant constants --- internal/api/auth_hooks.go | 5 ----- internal/api/hooks.go | 2 ++ internal/api/mfa.go | 2 +- internal/conf/configuration.go | 6 ++---- internal/conf/configuration_test.go | 4 ++-- internal/models/linking_test.go | 8 ++++---- 6 files changed, 11 insertions(+), 16 deletions(-) diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go index 5907bbbb7c..93a392d79c 100644 --- a/internal/api/auth_hooks.go +++ b/internal/api/auth_hooks.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "net/url" - "time" "fmt" "github.com/gofrs/uuid" @@ -33,10 +32,6 @@ const ( MFAVerificationEvent = "auth.mfa_verfication" ) -const ( - defaultTimeout = time.Second * 2 -) - type HookErrorResponse struct { ErrorMessage string `json:"error_message"` ErrorCode string `json:"error_code"` diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 80af123fba..a46343b39c 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -41,6 +41,8 @@ type webhookClaims struct { SHA256 string `json:"sha256"` } +var defaultTimeout = time.Second * 5 + type Webhook struct { *conf.WebhookConfig diff --git a/internal/api/mfa.go b/internal/api/mfa.go index bdb9cdeef4..b4a282183b 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -247,7 +247,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } valid := totp.Validate(params.Code, factor.Secret) - if config.Hook.MFA.IsEnabled() { + if config.Hook.MFA.Enabled { // To allow for future cases where we don't know that the payload is going to be passed in payload, err := CreateMFAVerificationHookInput(user.ID, factor.ID, valid) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index fa676eea36..04763f95f8 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -387,7 +387,8 @@ type HookConfiguration struct { } type ExtensibilityPointConfiguration struct { - URI string `json:"uri"` + URI string `json:"uri"` + Enabled bool `json:"true"` } func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { @@ -399,9 +400,6 @@ func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { } return nil } -func (e *ExtensibilityPointConfiguration) IsEnabled() bool { - return e.URI != "" -} func (w *WebhookConfig) HasEvent(event string) bool { for _, name := range w.Events { diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index 5c68e39f1a..bd8d4c3676 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -21,10 +21,10 @@ func TestGlobal(t *testing.T) { os.Setenv("GOTRUE_API_REQUEST_ID_HEADER", "X-Request-ID") os.Setenv("GOTRUE_JWT_SECRET", "secret") os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") - os.Setenv("GOTRUE_HOOK_MFA_URI", "postgres://postgres/auth/count_failed_attempts") + os.Setenv("GOTRUE_HOOK_MFA_URI", "pg-functions://postgres/auth/count_failed_attempts") gc, err := LoadGlobal("") require.NoError(t, err) require.NotNil(t, gc) assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader) - assert.Equal(t, "postgres://postgres/auth/count_failed_attempts", gc.Hook.MFA.URI) + assert.Equal(t, "pg-functions://postgres/auth/count_failed_attempts", gc.Hook.MFA.URI) } diff --git a/internal/models/linking_test.go b/internal/models/linking_test.go index d2e46944c6..5c68ac4327 100644 --- a/internal/models/linking_test.go +++ b/internal/models/linking_test.go @@ -81,7 +81,7 @@ func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { // when the email doesn't exist in the system -- conventional provider decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "other@example.com", Verified: true, Primary: true, @@ -93,7 +93,7 @@ func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { // when looking for an email that doesn't exist in the SSO linking domain decision, err = DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "other@samltest.id", Verified: true, Primary: true, @@ -116,7 +116,7 @@ func (ts *AccountLinkingTestSuite) TestAccountExists() { require.NoError(ts.T(), ts.db.Create(identityA)) decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "test@example.com", Verified: true, Primary: true, @@ -299,7 +299,7 @@ func (ts *AccountLinkingTestSuite) TestMultipleAccounts() { // identities in the same "default" linking domain with the same email // address pointing to two different user accounts decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "test@example.com", Verified: true, Primary: true, From da86cdeeb20668e225651d460247a4709394d5ba Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 21 Nov 2023 14:15:56 +0800 Subject: [PATCH 07/42] feat: add minor validation and cleanup --- internal/api/auth_hooks.go | 35 ++++++++++++++++++++++++++-------- internal/api/mfa.go | 15 ++++++--------- internal/conf/configuration.go | 5 ++++- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go index 93a392d79c..576fbc7411 100644 --- a/internal/api/auth_hooks.go +++ b/internal/api/auth_hooks.go @@ -9,6 +9,7 @@ import ( "github.com/gofrs/uuid" "github.com/supabase/gotrue/internal/conf" "github.com/supabase/gotrue/internal/storage" + "regexp" "strings" ) @@ -19,6 +20,11 @@ const ( HTTPHook HookType = "http" ) +const ( + MFAHookRejection = "reject" + MFAHookContinue = "continue" +) + type AuthHook struct { *conf.ExtensibilityPointConfiguration payload []byte @@ -77,7 +83,7 @@ func CreateMFAVerificationHookInput(user_id uuid.UUID, factor_id uuid.UUID, vali } data, err := json.Marshal(&payload) if err != nil { - return nil, err + panic(err) } return data, nil } @@ -92,8 +98,6 @@ func (ah *AuthHook) Trigger() ([]byte, error) { switch url.Scheme { case string(PostgresHook): return ah.triggerPostgresHook() - case string(HTTPHook): - return ah.triggerHTTPHook() default: return nil, errors.New("unsupported hook type") } @@ -102,6 +106,13 @@ func (ah *AuthHook) Trigger() ([]byte, error) { } func (ah *AuthHook) fetchHookName() (string, error) { + // specification for Postgres names + regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` + re, err := regexp.Compile(regExp) + if err != nil { + return "", err + } + u, err := url.Parse(ah.ExtensibilityPointConfiguration.URI) if err != nil { return "", err @@ -112,7 +123,13 @@ func (ah *AuthHook) fetchHookName() (string, error) { } schema := pathParts[1] table := pathParts[2] - // TODO: maybe enforce checks on this name? + // Validate schema and table names + if !re.MatchString(schema) { + return "", fmt.Errorf("invalid schema name: %s", schema) + } + if !re.MatchString(table) { + return "", fmt.Errorf("invalid table name: %s", table) + } return schema + "." + table, nil } @@ -134,11 +151,13 @@ func (ah *AuthHook) triggerPostgresHook() ([]byte, error) { }); err != nil { return nil, err } + if parsedErrorResponse, err := parseErrorResponse(result); err != nil { + if parsedErrorResponse != nil { + return nil, errors.New(parsedErrorResponse.ErrorMessage) + } + return nil, err + } return result, nil } - -func (a *AuthHook) triggerHTTPHook() ([]byte, error) { - return nil, errors.New("not implemented error") -} diff --git a/internal/api/mfa.go b/internal/api/mfa.go index b4a282183b..fd47a9aad7 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -267,20 +267,17 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - parsedErrorResponse, err := parseErrorResponse(resp) - if err != nil { - return err - } - if parsedErrorResponse != nil { - return errors.New(parsedErrorResponse.ErrorMessage) - } + // TODO: Decide what to do here response, err := parseMFAVerificationResponse(resp) if err != nil { return err } - // TODO: don't hard code this and also change to Enum + handle success case - if response.Decision == "reject" { + + if response.Decision == MFAHookRejection { + if err := models.Logout(a.db, user.ID); err != nil { + return err + } return errors.New("has made 5 unsuccssful verification attempts") } } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 04763f95f8..e5a75ca8c3 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -393,10 +393,13 @@ type ExtensibilityPointConfiguration struct { func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { if e.URI != "" { - _, err := url.Parse(e.URI) + u, err := url.Parse(e.URI) if err != nil { return errors.New("hook entry should be a valid URI") } + if pathParts := strings.Split(u.Path, "/"); len(pathParts) < 3 { + return fmt.Errorf("URI path does not contain enough parts") + } } return nil } From 185b1d7512bce5c8d3c0725cd8bd93c413221533 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 21 Nov 2023 14:37:49 +0800 Subject: [PATCH 08/42] feat: add hook configuration to settings --- internal/api/auth_hooks.go | 2 -- internal/api/mfa.go | 2 -- internal/api/settings.go | 11 +++++++---- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go index 576fbc7411..35dc6a0880 100644 --- a/internal/api/auth_hooks.go +++ b/internal/api/auth_hooks.go @@ -101,8 +101,6 @@ func (ah *AuthHook) Trigger() ([]byte, error) { default: return nil, errors.New("unsupported hook type") } - - return nil, nil } func (ah *AuthHook) fetchHookName() (string, error) { diff --git a/internal/api/mfa.go b/internal/api/mfa.go index fd47a9aad7..678cc8f09d 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -248,8 +248,6 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { valid := totp.Validate(params.Code, factor.Secret) if config.Hook.MFA.Enabled { - // To allow for future cases where we don't know that the payload is going to be passed in - payload, err := CreateMFAVerificationHookInput(user.ID, factor.ID, valid) if err != nil { return err diff --git a/internal/api/settings.go b/internal/api/settings.go index 2d4826d158..9ccb58fc91 100644 --- a/internal/api/settings.go +++ b/internal/api/settings.go @@ -36,9 +36,11 @@ type Settings struct { SmsProvider string `json:"sms_provider"` MFAEnabled bool `json:"mfa_enabled"` SAMLEnabled bool `json:"saml_enabled"` + HookConfiguration HookSettings `json:"hook"` +} - // TODO: Remove this later. For debugging - MFAHook string `json:"mfa_hook"` +type HookSettings struct { + MFAVerification bool `json:"mfa"` } func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { @@ -70,8 +72,9 @@ func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { Phone: config.External.Phone.Enabled, Zoom: config.External.Zoom.Enabled, }, - // TODO: Remove this too. For debugging - MFAHook: config.Hook.MFA.URI, + HookConfiguration: HookSettings{ + MFAVerification: config.Hook.MFA.Enabled, + }, DisableSignup: config.DisableSignup, MailerAutoconfirm: config.Mailer.Autoconfirm, From 6676552aebd1ef7d90442e5f49b57e1661fe887f Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Wed, 22 Nov 2023 16:24:25 +0800 Subject: [PATCH 09/42] feat: update naming conventions --- internal/api/auth_hooks.go | 59 ++++++++++++++++++++------------------ internal/api/mfa.go | 11 +++++-- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/internal/api/auth_hooks.go b/internal/api/auth_hooks.go index 35dc6a0880..3b82b970b5 100644 --- a/internal/api/auth_hooks.go +++ b/internal/api/auth_hooks.go @@ -16,7 +16,7 @@ import ( type HookType string const ( - PostgresHook HookType = "postgres" + PostgresHook HookType = "pg-functions" HTTPHook HookType = "http" ) @@ -26,11 +26,11 @@ const ( ) type AuthHook struct { - *conf.ExtensibilityPointConfiguration - payload []byte - hookType HookType - event string - db *storage.Connection + ExtensibilityPointConfiguration conf.ExtensibilityPointConfiguration + payload []byte + hookType HookType + event string + db *storage.Connection } // Hook Events @@ -44,10 +44,6 @@ type HookErrorResponse struct { RetryAfter bool `json:"retry_after"` } -type MFAVerificationHookResponse struct { - Decision string `json:"decision"` -} - func parseErrorResponse(response []byte) (*HookErrorResponse, error) { var errResp HookErrorResponse err := json.Unmarshal(response, &errResp) @@ -60,8 +56,8 @@ func parseErrorResponse(response []byte) (*HookErrorResponse, error) { return nil, err } -func parseMFAVerificationResponse(response []byte) (*MFAVerificationHookResponse, error) { - var MFAVerificationResponse MFAVerificationHookResponse +func parseMFAVerificationResponse(response []byte) (*MFAVerificationHookOutput, error) { + var MFAVerificationResponse MFAVerificationHookOutput err := json.Unmarshal(response, &MFAVerificationResponse) if err != nil { return nil, err @@ -70,13 +66,18 @@ func parseMFAVerificationResponse(response []byte) (*MFAVerificationHookResponse return &MFAVerificationResponse, err } +type MFAVerificationHookInput struct { + UserID uuid.UUID `json:"user_id"` + FactorID uuid.UUID `json:"factor_id"` + Valid bool `json:"valid"` +} +type MFAVerificationHookOutput struct { + Decision string `json:"decision"` +} + // Functions for encoding and decoding payload func CreateMFAVerificationHookInput(user_id uuid.UUID, factor_id uuid.UUID, valid bool) ([]byte, error) { - payload := struct { - UserID uuid.UUID `json:"user_id"` - FactorID uuid.UUID `json:"factor_id"` - Valid bool `json:"valid"` - }{ + payload := MFAVerificationHookInput{ UserID: user_id, FactorID: factor_id, Valid: valid, @@ -134,14 +135,15 @@ func (ah *AuthHook) fetchHookName() (string, error) { func (ah *AuthHook) triggerPostgresHook() ([]byte, error) { // Determine Result payload and request payload - var result []byte + var hookResponse []byte hookName, err := ah.fetchHookName() if err != nil { return nil, err } if err := ah.db.Transaction(func(tx *storage.Connection) error { - resp := tx.RawQuery(fmt.Sprintf("SELECT %s('%s')", hookName, ah.payload)) - terr := resp.First(result) + // TODO: add some sort of logging here so that we track that the function is called + query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), string(ah.payload)) + terr := query.First(&hookResponse) if terr != nil { return terr } @@ -149,13 +151,14 @@ func (ah *AuthHook) triggerPostgresHook() ([]byte, error) { }); err != nil { return nil, err } - if parsedErrorResponse, err := parseErrorResponse(result); err != nil { - if parsedErrorResponse != nil { - return nil, errors.New(parsedErrorResponse.ErrorMessage) - } - return nil, err - } - - return result, nil + // TODO: Check if it's an error response + // if errorResponse, err := parseErrorResponse(hookResponse); err != nil { + // if errorResponse != nil { + // return nil, errors.New(errorResponse.ErrorMessage) + // } + // return nil, err + // } + + return hookResponse, nil } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 678cc8f09d..5f6cea32f3 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -245,6 +245,9 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) } + type HookResponse struct { + Message string + } valid := totp.Validate(params.Code, factor.Secret) if config.Hook.MFA.Enabled { @@ -254,12 +257,14 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } h := AuthHook{ - event: MFAVerificationEvent, - payload: payload, - hookType: PostgresHook, + ExtensibilityPointConfiguration: config.Hook.MFA, + event: MFAVerificationEvent, + payload: payload, + hookType: PostgresHook, // TODO: find a better way to relay this db: a.db, } + // Log that we're calling the function resp, err := h.Trigger() if err != nil { From 949354330784d5c8d80c74aa27c7085cc3ac58b4 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sat, 25 Nov 2023 00:30:14 +0800 Subject: [PATCH 10/42] feat:refactor to do w/o abstraction --- internal/api/errors.go | 1 + internal/api/mfa.go | 60 ++++++++++-------- internal/api/mfa_test.go | 4 ++ internal/{api => hooks}/auth_hooks.go | 91 ++++++++++----------------- 4 files changed, 73 insertions(+), 83 deletions(-) rename internal/{api => hooks}/auth_hooks.go (58%) diff --git a/internal/api/errors.go b/internal/api/errors.go index 8efdbe811e..b223ada9c8 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -210,6 +210,7 @@ func otpError(err string, description string) *OTPError { return &OTPError{Err: err, Description: description} } + // Recoverer is a middleware that recovers from panics, logs the panic (and a // backtrace), and returns a HTTP 500 (Internal Server Error) status if // possible. Recoverer prints a request ID if one is provided. diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 5f6cea32f3..19934803ef 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -4,10 +4,11 @@ import ( "bytes" "encoding/json" "net/http" + "context" - "errors" "net/url" + "github.com/pkg/errors" "github.com/aaronarduino/goqrsvg" svg "github.com/ajstarks/svgo" "github.com/boombuler/barcode/qr" @@ -17,6 +18,7 @@ import ( "github.com/supabase/gotrue/internal/models" "github.com/supabase/gotrue/internal/storage" "github.com/supabase/gotrue/internal/utilities" + "github.com/supabase/gotrue/internal/hooks" ) const DefaultQRSize = 3 @@ -197,6 +199,17 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } +func (a *API) invokeHook(ctx context.Context, input any, output any) (error) { + switch input.(type) { + case hooks.MFAVerificationAttemptInput: + // Check for hook type (e.g. postgres/http) here + return nil + default: + return errors.New("invalid hook extensibility point") + // trigger the MFA verification hook + + } +} func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { var err error ctx := r.Context() @@ -245,43 +258,38 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) } - type HookResponse struct { - Message string - } valid := totp.Validate(params.Code, factor.Secret) if config.Hook.MFA.Enabled { - payload, err := CreateMFAVerificationHookInput(user.ID, factor.ID, valid) - if err != nil { - return err - } - - h := AuthHook{ - ExtensibilityPointConfiguration: config.Hook.MFA, - event: MFAVerificationEvent, - payload: payload, - hookType: PostgresHook, - // TODO: find a better way to relay this - db: a.db, + input := hooks.MFAVerificationAttemptInput{ + UserID: user.ID, + FactorID: factor.ID, + Valid: valid, } - // Log that we're calling the function - - resp, err := h.Trigger() + output := hooks.MFAVerificationAttemptOutput{} + payload, err := json.Marshal(&input) if err != nil { - return err + panic(err) } + // Log that we're calling the function - // TODO: Decide what to do here - response, err := parseMFAVerificationResponse(resp) - if err != nil { + if err := a.invokeHook(ctx, payload, output); err != nil { return err } - - if response.Decision == MFAHookRejection { + // + // TODO: Move this into invokeHook + // response, err := hooks.ParseMFAVerificationResponse(output) + // if err != nil { + // return err + // } + + if output.Decision == hooks.MFAHookRejection { if err := models.Logout(a.db, user.ID); err != nil { return err } - return errors.New("has made 5 unsuccssful verification attempts") + // TODO: remove this and reinstate line below + return forbiddenError("invalid") + // return forbiddenError(response.Message) } } if !valid { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index e3686c9a46..a254fa93d5 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -521,3 +521,7 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR require.NoError(ts.T(), json.NewDecoder(y.Body).Decode(&verifyResp)) return verifyResp } + + +func (ts *MFATestSuite) TestHookVerification() { +} diff --git a/internal/api/auth_hooks.go b/internal/hooks/auth_hooks.go similarity index 58% rename from internal/api/auth_hooks.go rename to internal/hooks/auth_hooks.go index 3b82b970b5..ff7269370c 100644 --- a/internal/api/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -1,8 +1,7 @@ -package api +package hooks import ( "encoding/json" - "errors" "net/url" "fmt" @@ -27,37 +26,53 @@ const ( type AuthHook struct { ExtensibilityPointConfiguration conf.ExtensibilityPointConfiguration - payload []byte - hookType HookType - event string - db *storage.Connection + Payload []byte + HookType HookType + Event string + DB *storage.Connection +} + +type MFAVerificationAttemptInput struct { + UserID uuid.UUID `json:"user_id"` + FactorID uuid.UUID `json:"factor_id"` + Valid bool `json:"valid"` +} + +type MFAVerificationAttemptOutput struct { + Decision string `json:"decision"` + Message string `json:"message"` +} + +// AuthHookError is an error with a message and an HTTP status code. +type AuthHookError struct { + Code int `json:"code"` + Message string `json:"msg"` + ErrorID string `json:"error_id,omitempty"` } // Hook Events const ( - MFAVerificationEvent = "auth.mfa_verfication" + MFAVerificationAttempt = "auth.mfa_verfication" ) type HookErrorResponse struct { - ErrorMessage string `json:"error_message"` - ErrorCode string `json:"error_code"` - RetryAfter bool `json:"retry_after"` + AuthHookError } -func parseErrorResponse(response []byte) (*HookErrorResponse, error) { +func ParseErrorResponse(response []byte) (*HookErrorResponse, error) { var errResp HookErrorResponse err := json.Unmarshal(response, &errResp) if err != nil { return nil, err } - if errResp.ErrorMessage != "" { + if errResp.Message != "" { return &errResp, nil } return nil, err } -func parseMFAVerificationResponse(response []byte) (*MFAVerificationHookOutput, error) { - var MFAVerificationResponse MFAVerificationHookOutput +func ParseMFAVerificationResponse(response []byte) (*MFAVerificationAttemptOutput, error) { + var MFAVerificationResponse MFAVerificationAttemptOutput err := json.Unmarshal(response, &MFAVerificationResponse) if err != nil { return nil, err @@ -66,45 +81,7 @@ func parseMFAVerificationResponse(response []byte) (*MFAVerificationHookOutput, return &MFAVerificationResponse, err } -type MFAVerificationHookInput struct { - UserID uuid.UUID `json:"user_id"` - FactorID uuid.UUID `json:"factor_id"` - Valid bool `json:"valid"` -} -type MFAVerificationHookOutput struct { - Decision string `json:"decision"` -} - -// Functions for encoding and decoding payload -func CreateMFAVerificationHookInput(user_id uuid.UUID, factor_id uuid.UUID, valid bool) ([]byte, error) { - payload := MFAVerificationHookInput{ - UserID: user_id, - FactorID: factor_id, - Valid: valid, - } - data, err := json.Marshal(&payload) - if err != nil { - panic(err) - } - return data, nil -} - -func (ah *AuthHook) Trigger() ([]byte, error) { - // Parse URI object - url, err := url.Parse(ah.ExtensibilityPointConfiguration.URI) - if err != nil { - return nil, err - } - // trigger appropriate type of hook - switch url.Scheme { - case string(PostgresHook): - return ah.triggerPostgresHook() - default: - return nil, errors.New("unsupported hook type") - } -} - -func (ah *AuthHook) fetchHookName() (string, error) { +func (ah *AuthHook) FetchHookName() (string, error) { // specification for Postgres names regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` re, err := regexp.Compile(regExp) @@ -133,16 +110,16 @@ func (ah *AuthHook) fetchHookName() (string, error) { return schema + "." + table, nil } -func (ah *AuthHook) triggerPostgresHook() ([]byte, error) { +func (ah *AuthHook) TriggerPostgresHook() ([]byte, error) { // Determine Result payload and request payload var hookResponse []byte - hookName, err := ah.fetchHookName() + hookName, err := ah.FetchHookName() if err != nil { return nil, err } - if err := ah.db.Transaction(func(tx *storage.Connection) error { + if err := ah.DB.Transaction(func(tx *storage.Connection) error { // TODO: add some sort of logging here so that we track that the function is called - query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), string(ah.payload)) + query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), string(ah.Payload)) terr := query.First(&hookResponse) if terr != nil { return terr From 4b0857084adebef4002e9abd48d4f5bb903ac706 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sat, 25 Nov 2023 00:42:38 +0800 Subject: [PATCH 11/42] fix: remove now redundant methods --- internal/api/errors.go | 1 - internal/api/mfa.go | 42 ++++++++++++++++++++-------- internal/api/mfa_test.go | 1 - internal/hooks/auth_hooks.go | 53 ++---------------------------------- 4 files changed, 32 insertions(+), 65 deletions(-) diff --git a/internal/api/errors.go b/internal/api/errors.go index b223ada9c8..8efdbe811e 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -210,7 +210,6 @@ func otpError(err string, description string) *OTPError { return &OTPError{Err: err, Description: description} } - // Recoverer is a middleware that recovers from panics, logs the panic (and a // backtrace), and returns a HTTP 500 (Internal Server Error) status if // possible. Recoverer prints a request ID if one is provided. diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 19934803ef..6e1301cc90 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -2,23 +2,24 @@ package api import ( "bytes" + "context" "encoding/json" + "fmt" "net/http" - "context" "net/url" - "github.com/pkg/errors" "github.com/aaronarduino/goqrsvg" svg "github.com/ajstarks/svgo" "github.com/boombuler/barcode/qr" "github.com/gofrs/uuid" + "github.com/pkg/errors" "github.com/pquerna/otp/totp" + "github.com/supabase/gotrue/internal/hooks" "github.com/supabase/gotrue/internal/metering" "github.com/supabase/gotrue/internal/models" "github.com/supabase/gotrue/internal/storage" "github.com/supabase/gotrue/internal/utilities" - "github.com/supabase/gotrue/internal/hooks" ) const DefaultQRSize = 3 @@ -199,16 +200,33 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } -func (a *API) invokeHook(ctx context.Context, input any, output any) (error) { - switch input.(type) { - case hooks.MFAVerificationAttemptInput: - // Check for hook type (e.g. postgres/http) here - return nil - default: - return errors.New("invalid hook extensibility point") - // trigger the MFA verification hook +func (a *API) invokeHook(ctx context.Context, input any, output any) error { + switch input.(type) { + case hooks.MFAVerificationAttemptInput: + var hookResponse []byte + hookName, err := hooks.FetchHookName(a.config.Hook.MFA) + if err != nil { + return err + } + if err := a.db.Transaction(func(tx *storage.Connection) error { + // TODO: add some sort of logging here so that we track that the function is called + query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), string(input.([]byte))) + terr := query.First(&hookResponse) + if terr != nil { + return terr + } + return nil + }); err != nil { + return err + } + + // Check for hook type (e.g. postgres/http) here + return nil + default: + return errors.New("invalid hook extensibility point") + // trigger the MFA verification hook - } + } } func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { var err error diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index a254fa93d5..346498fa49 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -522,6 +522,5 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR return verifyResp } - func (ts *MFATestSuite) TestHookVerification() { } diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index ff7269370c..30ac06aa88 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -7,7 +7,6 @@ import ( "fmt" "github.com/gofrs/uuid" "github.com/supabase/gotrue/internal/conf" - "github.com/supabase/gotrue/internal/storage" "regexp" "strings" ) @@ -24,14 +23,6 @@ const ( MFAHookContinue = "continue" ) -type AuthHook struct { - ExtensibilityPointConfiguration conf.ExtensibilityPointConfiguration - Payload []byte - HookType HookType - Event string - DB *storage.Connection -} - type MFAVerificationAttemptInput struct { UserID uuid.UUID `json:"user_id"` FactorID uuid.UUID `json:"factor_id"` @@ -71,17 +62,7 @@ func ParseErrorResponse(response []byte) (*HookErrorResponse, error) { return nil, err } -func ParseMFAVerificationResponse(response []byte) (*MFAVerificationAttemptOutput, error) { - var MFAVerificationResponse MFAVerificationAttemptOutput - err := json.Unmarshal(response, &MFAVerificationResponse) - if err != nil { - return nil, err - } - - return &MFAVerificationResponse, err -} - -func (ah *AuthHook) FetchHookName() (string, error) { +func FetchHookName(ep conf.ExtensibilityPointConfiguration) (string, error) { // specification for Postgres names regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` re, err := regexp.Compile(regExp) @@ -89,7 +70,7 @@ func (ah *AuthHook) FetchHookName() (string, error) { return "", err } - u, err := url.Parse(ah.ExtensibilityPointConfiguration.URI) + u, err := url.Parse(ep.URI) if err != nil { return "", err } @@ -109,33 +90,3 @@ func (ah *AuthHook) FetchHookName() (string, error) { return schema + "." + table, nil } - -func (ah *AuthHook) TriggerPostgresHook() ([]byte, error) { - // Determine Result payload and request payload - var hookResponse []byte - hookName, err := ah.FetchHookName() - if err != nil { - return nil, err - } - if err := ah.DB.Transaction(func(tx *storage.Connection) error { - // TODO: add some sort of logging here so that we track that the function is called - query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), string(ah.Payload)) - terr := query.First(&hookResponse) - if terr != nil { - return terr - } - return nil - }); err != nil { - return nil, err - } - // TODO: Check if it's an error response - // if errorResponse, err := parseErrorResponse(hookResponse); err != nil { - // if errorResponse != nil { - // return nil, errors.New(errorResponse.ErrorMessage) - // } - // return nil, err - // } - - return hookResponse, nil - -} From 78e184c9bcb24e13cda49fb50f1a3cd8736642a1 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sat, 25 Nov 2023 01:51:41 +0800 Subject: [PATCH 12/42] fix: add some logging --- internal/api/mfa.go | 33 ++++++++++++++---------------- internal/hooks/auth_hooks.go | 5 +++++ internal/models/audit_log_entry.go | 3 +++ 3 files changed, 23 insertions(+), 18 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 6e1301cc90..d8503064db 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -203,15 +203,14 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { func (a *API) invokeHook(ctx context.Context, input any, output any) error { switch input.(type) { case hooks.MFAVerificationAttemptInput: - var hookResponse []byte + var response []byte hookName, err := hooks.FetchHookName(a.config.Hook.MFA) if err != nil { return err } if err := a.db.Transaction(func(tx *storage.Connection) error { - // TODO: add some sort of logging here so that we track that the function is called query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), string(input.([]byte))) - terr := query.First(&hookResponse) + terr := query.First(&response) if terr != nil { return terr } @@ -219,13 +218,14 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { }); err != nil { return err } - - // Check for hook type (e.g. postgres/http) here + hookResponseOrError := hooks.HookErrorResponse{} + err = json.Unmarshal(response, &hookResponseOrError) + if err == nil && hookResponseOrError.IsError() { + return err + } return nil default: return errors.New("invalid hook extensibility point") - // trigger the MFA verification hook - } } func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { @@ -289,25 +289,22 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if err != nil { panic(err) } - // Log that we're calling the function if err := a.invokeHook(ctx, payload, output); err != nil { return err } - // - // TODO: Move this into invokeHook - // response, err := hooks.ParseMFAVerificationResponse(output) - // if err != nil { - // return err - // } - + if terr := models.NewAuditLogEntry(r, a.db, user, models.InvokeAuthHookAction, r.RemoteAddr, map[string]interface{}{ + // TODO: include extensibility point name + "factor_id": factor.ID, + "URI": config.Hook.MFA.URI, + }); terr != nil { + return terr + } if output.Decision == hooks.MFAHookRejection { if err := models.Logout(a.db, user.ID); err != nil { return err } - // TODO: remove this and reinstate line below - return forbiddenError("invalid") - // return forbiddenError(response.Message) + return forbiddenError(output.Message) } } if !valid { diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 30ac06aa88..ff76e28b5a 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -46,10 +46,15 @@ const ( MFAVerificationAttempt = "auth.mfa_verfication" ) +// TODO: Give this a more proper name type HookErrorResponse struct { AuthHookError } +func (hookError *HookErrorResponse) IsError() bool { + return hookError.Message != "" +} + func ParseErrorResponse(response []byte) (*HookErrorResponse, error) { var errResp HookErrorResponse err := json.Unmarshal(response, &errResp) diff --git a/internal/models/audit_log_entry.go b/internal/models/audit_log_entry.go index 98223038b8..a192200ae2 100644 --- a/internal/models/audit_log_entry.go +++ b/internal/models/audit_log_entry.go @@ -40,6 +40,7 @@ const ( DeleteRecoveryCodesAction AuditAction = "recovery_codes_deleted" UpdateFactorAction AuditAction = "factor_updated" MFACodeLoginAction AuditAction = "mfa_code_login" + InvokeAuthHookAction AuditAction = "auth_hook_invoked" account auditLogType = "account" team auditLogType = "team" @@ -47,6 +48,7 @@ const ( user auditLogType = "user" factor auditLogType = "factor" recoveryCodes auditLogType = "recovery_codes" + authHook auditLogType = "auth_hook" ) var ActionLogTypeMap = map[AuditAction]auditLogType{ @@ -72,6 +74,7 @@ var ActionLogTypeMap = map[AuditAction]auditLogType{ UpdateFactorAction: factor, MFACodeLoginAction: factor, DeleteRecoveryCodesAction: recoveryCodes, + InvokeAuthHookAction: authHook, } // AuditLogEntry is the database model for audit log entries. From f3ae14b0f4973a15959d728feae97d63dd4f70a0 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sat, 25 Nov 2023 01:52:56 +0800 Subject: [PATCH 13/42] fix: remove unused code --- internal/hooks/auth_hooks.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index ff76e28b5a..cc7270c6df 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -1,7 +1,6 @@ package hooks import ( - "encoding/json" "net/url" "fmt" @@ -55,18 +54,6 @@ func (hookError *HookErrorResponse) IsError() bool { return hookError.Message != "" } -func ParseErrorResponse(response []byte) (*HookErrorResponse, error) { - var errResp HookErrorResponse - err := json.Unmarshal(response, &errResp) - if err != nil { - return nil, err - } - if errResp.Message != "" { - return &errResp, nil - } - return nil, err -} - func FetchHookName(ep conf.ExtensibilityPointConfiguration) (string, error) { // specification for Postgres names regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` From dc9d6daaf539db9592f0d951531e6d89afff8fd2 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sat, 25 Nov 2023 01:55:06 +0800 Subject: [PATCH 14/42] fix: reset unused files --- internal/api/hooks.go | 6 +++--- internal/models/linking_test.go | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index a46343b39c..c5378af54c 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -28,21 +28,21 @@ type HookEvent string const ( headerHookSignature = "x-webhook-signature" + defaultHookRetries = 3 gotrueIssuer = "gotrue" ValidateEvent = "validate" SignupEvent = "signup" EmailChangeEvent = "email_change" LoginEvent = "login" - defaultHookRetries = 3 ) +var defaultTimeout = time.Second * 5 + type webhookClaims struct { jwt.StandardClaims SHA256 string `json:"sha256"` } -var defaultTimeout = time.Second * 5 - type Webhook struct { *conf.WebhookConfig diff --git a/internal/models/linking_test.go b/internal/models/linking_test.go index 5c68ac4327..d2e46944c6 100644 --- a/internal/models/linking_test.go +++ b/internal/models/linking_test.go @@ -81,7 +81,7 @@ func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { // when the email doesn't exist in the system -- conventional provider decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - { + provider.Email{ Email: "other@example.com", Verified: true, Primary: true, @@ -93,7 +93,7 @@ func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { // when looking for an email that doesn't exist in the SSO linking domain decision, err = DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - { + provider.Email{ Email: "other@samltest.id", Verified: true, Primary: true, @@ -116,7 +116,7 @@ func (ts *AccountLinkingTestSuite) TestAccountExists() { require.NoError(ts.T(), ts.db.Create(identityA)) decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - { + provider.Email{ Email: "test@example.com", Verified: true, Primary: true, @@ -299,7 +299,7 @@ func (ts *AccountLinkingTestSuite) TestMultipleAccounts() { // identities in the same "default" linking domain with the same email // address pointing to two different user accounts decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - { + provider.Email{ Email: "test@example.com", Verified: true, Primary: true, From 1cd6effdde52615281a579b1a428dce3b84e0fa5 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Sat, 25 Nov 2023 02:04:18 +0800 Subject: [PATCH 15/42] feat: add stubs, revert unneeded changes --- internal/api/mfa_test.go | 21 ++++++++++++++++++++- internal/hooks/auth_hooks_test.go | 23 +++++++++++++++++++++++ internal/models/linking_test.go | 8 ++++---- 3 files changed, 47 insertions(+), 5 deletions(-) create mode 100644 internal/hooks/auth_hooks_test.go diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 346498fa49..0d67287c32 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -522,5 +522,24 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR return verifyResp } -func (ts *MFATestSuite) TestHookVerification() { +func (ts *MFATestSuite) TestVerificationHookSuccess() { + // TODO +} + +func (ts *MFATestSuite) TestVerificationHookReject() { + // TODO + +} +func (ts *MFATestSuite) TestVerificationHookError() { + // TODO + +} + +func (ts *MFATestSuite) TestVerificationHookTimeout() { + // TODO +} + +func (ts *MFATestSuite) TestVerificationHookDisabled() { + // TODO + } diff --git a/internal/hooks/auth_hooks_test.go b/internal/hooks/auth_hooks_test.go new file mode 100644 index 0000000000..b797eea0e6 --- /dev/null +++ b/internal/hooks/auth_hooks_test.go @@ -0,0 +1,23 @@ +package hooks + +import ( + "github.com/stretchr/testify/suite" + "testing" +) + +type HookTestSuite struct { + suite.Suite +} + +func TestHooks(t *testing.T) { + ts := &HookTestSuite{} + suite.Run(t, ts) +} +func (ts *HookTestSuite) SetupTest() { + // TODO + +} + +func (ts *HookTestSuite) TestFetchHookName() { + // TODO +} diff --git a/internal/models/linking_test.go b/internal/models/linking_test.go index d2e46944c6..5c68ac4327 100644 --- a/internal/models/linking_test.go +++ b/internal/models/linking_test.go @@ -81,7 +81,7 @@ func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { // when the email doesn't exist in the system -- conventional provider decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "other@example.com", Verified: true, Primary: true, @@ -93,7 +93,7 @@ func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { // when looking for an email that doesn't exist in the SSO linking domain decision, err = DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "other@samltest.id", Verified: true, Primary: true, @@ -116,7 +116,7 @@ func (ts *AccountLinkingTestSuite) TestAccountExists() { require.NoError(ts.T(), ts.db.Create(identityA)) decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "test@example.com", Verified: true, Primary: true, @@ -299,7 +299,7 @@ func (ts *AccountLinkingTestSuite) TestMultipleAccounts() { // identities in the same "default" linking domain with the same email // address pointing to two different user accounts decision, err := DetermineAccountLinking(ts.db, ts.config, []provider.Email{ - provider.Email{ + { Email: "test@example.com", Verified: true, Primary: true, From 8a3ec9c60458973ac9a55cb8e728a1f0cd1ea3fb Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 12:04:49 +0800 Subject: [PATCH 16/42] refactor: rename hook ep --- internal/api/mfa.go | 6 +++--- internal/api/settings.go | 2 +- internal/conf/configuration.go | 3 +-- internal/conf/configuration_test.go | 4 ++-- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index d8503064db..1d44c32416 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -204,7 +204,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { switch input.(type) { case hooks.MFAVerificationAttemptInput: var response []byte - hookName, err := hooks.FetchHookName(a.config.Hook.MFA) + hookName, err := hooks.FetchHookName(a.config.Hook.MFAVerificationAttempt) if err != nil { return err } @@ -278,7 +278,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } valid := totp.Validate(params.Code, factor.Secret) - if config.Hook.MFA.Enabled { + if config.Hook.MFAVerificationAttempt.Enabled { input := hooks.MFAVerificationAttemptInput{ UserID: user.ID, FactorID: factor.ID, @@ -296,7 +296,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if terr := models.NewAuditLogEntry(r, a.db, user, models.InvokeAuthHookAction, r.RemoteAddr, map[string]interface{}{ // TODO: include extensibility point name "factor_id": factor.ID, - "URI": config.Hook.MFA.URI, + "URI": config.Hook.MFAVerificationAttempt.URI, }); terr != nil { return terr } diff --git a/internal/api/settings.go b/internal/api/settings.go index 9ccb58fc91..38c8a82413 100644 --- a/internal/api/settings.go +++ b/internal/api/settings.go @@ -73,7 +73,7 @@ func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { Zoom: config.External.Zoom.Enabled, }, HookConfiguration: HookSettings{ - MFAVerification: config.Hook.MFA.Enabled, + MFAVerification: config.Hook.MFAVerificationAttempt.Enabled, }, DisableSignup: config.DisableSignup, diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 644ac8798f..2c3c52cbbe 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -382,8 +382,7 @@ type WebhookConfig struct { // Moving away from the existing HookConfig so we can get a fresh start. type HookConfiguration struct { - // TODO (Joel): Fix the naming later - MFA ExtensibilityPointConfiguration `json:"mfa"` + MFAVerificationAttempt ExtensibilityPointConfiguration `json:"mfa_verification_attempt" split_words:"true"` } type ExtensibilityPointConfiguration struct { diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index bd8d4c3676..bda7e6a3fd 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -21,10 +21,10 @@ func TestGlobal(t *testing.T) { os.Setenv("GOTRUE_API_REQUEST_ID_HEADER", "X-Request-ID") os.Setenv("GOTRUE_JWT_SECRET", "secret") os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") - os.Setenv("GOTRUE_HOOK_MFA_URI", "pg-functions://postgres/auth/count_failed_attempts") + os.Setenv("GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_URI", "pg-functions://postgres/auth/count_failed_attempts") gc, err := LoadGlobal("") require.NoError(t, err) require.NotNil(t, gc) assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader) - assert.Equal(t, "pg-functions://postgres/auth/count_failed_attempts", gc.Hook.MFA.URI) + assert.Equal(t, "pg-functions://postgres/auth/count_failed_attempts", gc.Hook.MFAVerificationAttempt.URI) } From e21a8060f04ea368d8d36779ce5cfb749d22eff4 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 12:06:15 +0800 Subject: [PATCH 17/42] refactor: rename HookErrorResponse->AuthHookErrorResponse --- internal/api/mfa.go | 2 +- internal/hooks/auth_hooks.go | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 1d44c32416..ca7691f9e3 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -218,7 +218,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { }); err != nil { return err } - hookResponseOrError := hooks.HookErrorResponse{} + hookResponseOrError := hooks.AuthHookErrorResponse{} err = json.Unmarshal(response, &hookResponseOrError) if err == nil && hookResponseOrError.IsError() { return err diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index cc7270c6df..5a18e8c130 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -45,12 +45,11 @@ const ( MFAVerificationAttempt = "auth.mfa_verfication" ) -// TODO: Give this a more proper name -type HookErrorResponse struct { +type AuthHookErrorResponse struct { AuthHookError } -func (hookError *HookErrorResponse) IsError() bool { +func (hookError *AuthHookErrorResponse) IsError() bool { return hookError.Message != "" } From b7a8c23de23bf8a2d8b6b7c9bcbd03055bef10ce Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 15:18:44 +0800 Subject: [PATCH 18/42] test: add initial tests --- internal/api/debug.test2619120917 | 0 internal/api/mfa.go | 23 ++++++----- internal/api/mfa_test.go | 65 +++++++++++++++++++++++++++++-- internal/hooks/auth_hooks.go | 2 + 4 files changed, 74 insertions(+), 16 deletions(-) create mode 100644 internal/api/debug.test2619120917 diff --git a/internal/api/debug.test2619120917 b/internal/api/debug.test2619120917 new file mode 100644 index 0000000000..e69de29bb2 diff --git a/internal/api/mfa.go b/internal/api/mfa.go index ca7691f9e3..8a69b70945 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -203,14 +203,17 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { func (a *API) invokeHook(ctx context.Context, input any, output any) error { switch input.(type) { case hooks.MFAVerificationAttemptInput: - var response []byte + payload, err := json.Marshal(&input) + if err != nil { + panic(err) + } hookName, err := hooks.FetchHookName(a.config.Hook.MFAVerificationAttempt) if err != nil { return err } if err := a.db.Transaction(func(tx *storage.Connection) error { - query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), string(input.([]byte))) - terr := query.First(&response) + query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), payload) + terr := query.First(&output) if terr != nil { return terr } @@ -219,7 +222,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { return err } hookResponseOrError := hooks.AuthHookErrorResponse{} - err = json.Unmarshal(response, &hookResponseOrError) + err = json.Unmarshal(output.([]byte), &hookResponseOrError) if err == nil && hookResponseOrError.IsError() { return err } @@ -285,18 +288,14 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { Valid: valid, } output := hooks.MFAVerificationAttemptOutput{} - payload, err := json.Marshal(&input) - if err != nil { - panic(err) - } - if err := a.invokeHook(ctx, payload, output); err != nil { + if err := a.invokeHook(ctx, input, &output); err != nil { return err } if terr := models.NewAuditLogEntry(r, a.db, user, models.InvokeAuthHookAction, r.RemoteAddr, map[string]interface{}{ - // TODO: include extensibility point name - "factor_id": factor.ID, - "URI": config.Hook.MFAVerificationAttempt.URI, + "extensibility_point_event": hooks.MFAVerificationAttempt, + "factor_id": factor.ID, + "URI": config.Hook.MFAVerificationAttempt.URI, }); terr != nil { return terr } diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 0d67287c32..020f89607e 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -522,24 +522,81 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR return verifyResp } +// TODO: refactor 4 cases into one long function func (ts *MFATestSuite) TestVerificationHookSuccess() { - // TODO + ts.Config.Hook.MFAVerificationAttempt.Enabled = true + // Pop executes as supabase_auth_admin and only has access to auth + ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook" + verificationHookSQL := ` + create or replace function verification_hook(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'continue' + ); + end; + $$ language plpgsql; + ` + email := "testemail@gmail.com" + password := "testpassword" + // 3. Execute the SQL to create the function + err := ts.API.db.RawQuery(verificationHookSQL).Exec() + require.NoError(ts.T(), err) + token := signUpAndVerify(ts, email, password) + require.NotNil(ts.T(), token) + cleanupHookSQL := ` + drop function verification_hook(input jsonb) + ` + err = ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(ts.T(), err) } func (ts *MFATestSuite) TestVerificationHookReject() { - // TODO + ts.Config.Hook.MFAVerificationAttempt.Enabled = true + // Pop executes as supabase_auth_admin and only has access to auth + // ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook" + // verificationHookSQL := ` + // create or replace function verification_hook(input jsonb) + // returns json as $$ + // begin + // return json_build_object( + // 'decision', 'reject' + // ); + // end; + // $$ language plpgsql; + // ` + // email := "testemail@gmail.com" + // password := "testpassword" + // // 3. Execute the SQL to create the function + // err := ts.API.db.RawQuery(verificationHookSQL).Exec() + // require.NoError(ts.T(), err) + // token := signUpAndVerify(ts, email, password) + // require.Nil(ts.T(), token) + // cleanupHookSQL := ` + // drop function verification_hook(input jsonb) + // ` + // err = ts.API.db.RawQuery(cleanupHookSQL).Exec() + // require.NoError(ts.T(), err) } func (ts *MFATestSuite) TestVerificationHookError() { + ts.Config.Hook.MFAVerificationAttempt.Enabled = true + ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/public/test_verification_hook_error" // TODO } func (ts *MFATestSuite) TestVerificationHookTimeout() { - // TODO + ts.Config.Hook.MFAVerificationAttempt.Enabled = true + ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/public/test_verification_hook_timeout" + // Call pg_sleep(10) + // TODO: expect an rror } func (ts *MFATestSuite) TestVerificationHookDisabled() { - // TODO + // The suite should default to false, but for illustration sake + ts.Config.Hook.MFAVerificationAttempt.Enabled = false + // resp := signUpAndVerify() + // Response should indicate failture } diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 5a18e8c130..d672c5308d 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -17,6 +17,8 @@ const ( HTTPHook HookType = "http" ) +// Hook Names + const ( MFAHookRejection = "reject" MFAHookContinue = "continue" From 5a4ce40bb35894cdf158a8afc5f113d104aa5a7a Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 16:35:25 +0800 Subject: [PATCH 19/42] feat: properly fetch response --- internal/api/mfa.go | 13 +++++--- internal/api/mfa_test.go | 69 +++++++++++++++++++++------------------- 2 files changed, 46 insertions(+), 36 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 8a69b70945..3be0b54b9a 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -201,6 +201,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { } func (a *API) invokeHook(ctx context.Context, input any, output any) error { + var response []byte switch input.(type) { case hooks.MFAVerificationAttemptInput: payload, err := json.Marshal(&input) @@ -213,7 +214,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { } if err := a.db.Transaction(func(tx *storage.Connection) error { query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), payload) - terr := query.First(&output) + terr := query.First(&response) if terr != nil { return terr } @@ -222,10 +223,14 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { return err } hookResponseOrError := hooks.AuthHookErrorResponse{} - err = json.Unmarshal(output.([]byte), &hookResponseOrError) + err = json.Unmarshal(response, &hookResponseOrError) if err == nil && hookResponseOrError.IsError() { return err } + if err = json.Unmarshal(response, output); err != nil { + return err + } + return nil default: return errors.New("invalid hook extensibility point") @@ -287,9 +292,9 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { FactorID: factor.ID, Valid: valid, } - output := hooks.MFAVerificationAttemptOutput{} + output := &hooks.MFAVerificationAttemptOutput{} - if err := a.invokeHook(ctx, input, &output); err != nil { + if err := a.invokeHook(ctx, input, output); err != nil { return err } if terr := models.NewAuditLogEntry(r, a.db, user, models.InvokeAuthHookAction, r.RemoteAddr, map[string]interface{}{ diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 020f89607e..184b546991 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -384,7 +384,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() { func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { email := "test1@example.com" password := "test123" - token := signUpAndVerify(ts, email, password) + token := signUpAndVerify(ts, email, password, true /* Guarantee success */) ts.Config.Security.RefreshTokenRotationEnabled = true var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ @@ -410,7 +410,7 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() { email := "test1@example.com" password := "test123" - token := signUpAndVerify(ts, email, password) + token := signUpAndVerify(ts, email, password, true /* Guarantee Success */) ts.Config.Security.RefreshTokenRotationEnabled = true var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ @@ -455,16 +455,16 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes return data } -func signUpAndVerify(ts *MFATestSuite, email, password string) (verifyResp *AccessTokenResponse) { +func signUpAndVerify(ts *MFATestSuite, email, password string, guaranteeSuccess bool) (verifyResp *AccessTokenResponse) { signUpResp := signUp(ts, email, password) - verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token) + verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token, guaranteeSuccess) return verifyResp } -func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyResp *AccessTokenResponse) { +func enrollAndVerify(ts *MFATestSuite, user *models.User, token string, guaranteeSuccess bool) (verifyResp *AccessTokenResponse) { var buffer bytes.Buffer w := httptest.NewRecorder() require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]string{"friendly_name": "john", "factor_type": models.TOTP, "issuer": ts.TestDomain})) @@ -491,7 +491,6 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR require.NoError(ts.T(), json.NewDecoder(x.Body).Decode(&challengeResp)) challengeID := challengeResp.ID - // Verify var verifyBuffer bytes.Buffer y := httptest.NewRecorder() @@ -516,13 +515,15 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string) (verifyR req.Header.Set("Content-Type", "application/json") ts.API.handler.ServeHTTP(y, req) - require.Equal(ts.T(), http.StatusOK, y.Code) + if guaranteeSuccess { + require.Equal(ts.T(), http.StatusOK, y.Code) + } verifyResp = &AccessTokenResponse{} require.NoError(ts.T(), json.NewDecoder(y.Body).Decode(&verifyResp)) return verifyResp } -// TODO: refactor 4 cases into one long function +// TODO: refactor 4 cases into one long function. Also consider how to cleanup if any fails func (ts *MFATestSuite) TestVerificationHookSuccess() { ts.Config.Hook.MFAVerificationAttempt.Enabled = true // Pop executes as supabase_auth_admin and only has access to auth @@ -542,7 +543,7 @@ func (ts *MFATestSuite) TestVerificationHookSuccess() { // 3. Execute the SQL to create the function err := ts.API.db.RawQuery(verificationHookSQL).Exec() require.NoError(ts.T(), err) - token := signUpAndVerify(ts, email, password) + token := signUpAndVerify(ts, email, password, true /* Guarantee Success */) require.NotNil(ts.T(), token) cleanupHookSQL := ` drop function verification_hook(input jsonb) @@ -552,33 +553,37 @@ func (ts *MFATestSuite) TestVerificationHookSuccess() { } func (ts *MFATestSuite) TestVerificationHookReject() { + verificationError := "authentication attempt rejected" ts.Config.Hook.MFAVerificationAttempt.Enabled = true // Pop executes as supabase_auth_admin and only has access to auth - // ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook" - // verificationHookSQL := ` - // create or replace function verification_hook(input jsonb) - // returns json as $$ - // begin - // return json_build_object( - // 'decision', 'reject' - // ); - // end; - // $$ language plpgsql; - // ` - // email := "testemail@gmail.com" - // password := "testpassword" - // // 3. Execute the SQL to create the function - // err := ts.API.db.RawQuery(verificationHookSQL).Exec() - // require.NoError(ts.T(), err) - // token := signUpAndVerify(ts, email, password) - // require.Nil(ts.T(), token) - // cleanupHookSQL := ` - // drop function verification_hook(input jsonb) - // ` - // err = ts.API.db.RawQuery(cleanupHookSQL).Exec() - // require.NoError(ts.T(), err) + ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook_reject" + verificationHookSQL := fmt.Sprintf(` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', %q + ); + end; + $$ language plpgsql; + `, verificationError) + email := "testemail@gmail.com" + password := "testpassword" + + err := ts.API.db.RawQuery(verificationHookSQL).Exec() + require.NoError(ts.T(), err) + resp := signUpAndVerify(ts, email, password, false /* Guarantee Success */) + // TODO: Figure out how to properly require a 403 and a proper error message here + require.Equal(ts.T(), "", resp.Token) + cleanupHookSQL := ` + drop function verification_hook_reject(input jsonb) + ` + err = ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(ts.T(), err) } + func (ts *MFATestSuite) TestVerificationHookError() { ts.Config.Hook.MFAVerificationAttempt.Enabled = true ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/public/test_verification_hook_error" From aa9d920cb32a4e6c71198929f477af5eee023156 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 17:17:27 +0800 Subject: [PATCH 20/42] feat: update tests --- internal/api/debug.test2619120917 | 0 internal/api/mfa_test.go | 89 +++++++++++++++++++------------ 2 files changed, 55 insertions(+), 34 deletions(-) delete mode 100644 internal/api/debug.test2619120917 diff --git a/internal/api/debug.test2619120917 b/internal/api/debug.test2619120917 deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 184b546991..c75703415c 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -553,35 +553,64 @@ func (ts *MFATestSuite) TestVerificationHookSuccess() { } func (ts *MFATestSuite) TestVerificationHookReject() { + cases := []struct { + desc string + enabled bool + expectedToken bool + }{ + { + desc: "No token returned when Hook is configured to reject", + enabled: true, + expectedToken: false, + }, + { + desc: "Token returned when Hook is disabled", + enabled: false, + expectedToken: true, + }, + } + verificationError := "authentication attempt rejected" - ts.Config.Hook.MFAVerificationAttempt.Enabled = true - // Pop executes as supabase_auth_admin and only has access to auth - ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook_reject" - verificationHookSQL := fmt.Sprintf(` - create or replace function verification_hook_reject(input jsonb) - returns json as $$ - begin - return json_build_object( - 'decision', 'reject', - 'message', %q - ); - end; - $$ language plpgsql; - `, verificationError) - email := "testemail@gmail.com" - password := "testpassword" - err := ts.API.db.RawQuery(verificationHookSQL).Exec() - require.NoError(ts.T(), err) - resp := signUpAndVerify(ts, email, password, false /* Guarantee Success */) - // TODO: Figure out how to properly require a 403 and a proper error message here - require.Equal(ts.T(), "", resp.Token) - cleanupHookSQL := ` - drop function verification_hook_reject(input jsonb) - ` - err = ts.API.db.RawQuery(cleanupHookSQL).Exec() - require.NoError(ts.T(), err) + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled + // To ensure distinct emails + email := fmt.Sprintf("testemail%s@gmail.com", strings.ReplaceAll(c.desc, " ", "")) + password := "testpassword" + + if c.enabled { + ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook_reject" + verificationHookSQL := fmt.Sprintf(` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', '%s' + ); + end; + $$ language plpgsql; + `, verificationError) + + err := ts.API.db.RawQuery(verificationHookSQL).Exec() + require.NoError(t, err) + } + + resp := signUpAndVerify(ts, email, password, false /* Guarantee Success */) + if c.expectedToken { + require.NotEqual(t, "", resp.Token) + } else { + require.Equal(t, "", resp.Token) + } + if c.enabled { + cleanupHookSQL := `drop function verification_hook_reject(input jsonb)` + err := ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(t, err) + } + }) + } } func (ts *MFATestSuite) TestVerificationHookError() { @@ -597,11 +626,3 @@ func (ts *MFATestSuite) TestVerificationHookTimeout() { // Call pg_sleep(10) // TODO: expect an rror } - -func (ts *MFATestSuite) TestVerificationHookDisabled() { - // The suite should default to false, but for illustration sake - ts.Config.Hook.MFAVerificationAttempt.Enabled = false - // resp := signUpAndVerify() - // Response should indicate failture - -} From 176ad4882c3c48d61ee928f962588e1774720216 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 17:37:56 +0800 Subject: [PATCH 21/42] feat: add a few tests --- internal/hooks/auth_hooks_test.go | 35 ++++++++++++++++++++++++++----- 1 file changed, 30 insertions(+), 5 deletions(-) diff --git a/internal/hooks/auth_hooks_test.go b/internal/hooks/auth_hooks_test.go index b797eea0e6..68d6337eec 100644 --- a/internal/hooks/auth_hooks_test.go +++ b/internal/hooks/auth_hooks_test.go @@ -2,6 +2,7 @@ package hooks import ( "github.com/stretchr/testify/suite" + "github.com/supabase/gotrue/internal/conf" "testing" ) @@ -13,11 +14,35 @@ func TestHooks(t *testing.T) { ts := &HookTestSuite{} suite.Run(t, ts) } -func (ts *HookTestSuite) SetupTest() { - // TODO -} +func TestFetchHookName(ts *testing.T) { + cases := []struct { + desc string + uri string + expectedResult string + expectedError string // Using string to represent the error for simplicity + }{ + // Positive test cases + {desc: "Valid URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectedResult: "auth.verification_hook_reject", expectedError: ""}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectedResult: "user_management.add_user", expectedError: ""}, + + // Negative test cases + {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectedResult: "", expectedError: "invalid schema name: 123auth"}, + {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectedResult: "", expectedError: "invalid table name: 123verification_hook_reject"}, + {desc: "Insufficient Path Parts", uri: "pg-functions://postgres/auth", expectedResult: "", expectedError: "URI path does not contain enough parts"}, + // {desc: "Incorrect URI Format", uri: "http://postgres/auth/verification_hook_reject", expectedResult: "", expectedError: "invalid URI format"}, + } -func (ts *HookTestSuite) TestFetchHookName() { - // TODO + for _, tc := range cases { + ts.Run(tc.desc, func(t *testing.T) { + ep := conf.ExtensibilityPointConfiguration{URI: tc.uri} + result, err := FetchHookName(ep) + if err != nil && err.Error() != tc.expectedError { + t.Errorf("Test %s failed: expected error %v, got %v", tc.desc, tc.expectedError, err) + } + if result != tc.expectedResult { + t.Errorf("Test %s failed: expected result %v, got %v", tc.desc, tc.expectedResult, result) + } + }) + } } From 32b1a8d4ba598b22f8206346272038443dc17b5c Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 18:17:58 +0800 Subject: [PATCH 22/42] fix: update test structure --- internal/hooks/auth_hooks_test.go | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/internal/hooks/auth_hooks_test.go b/internal/hooks/auth_hooks_test.go index 68d6337eec..4baf457b28 100644 --- a/internal/hooks/auth_hooks_test.go +++ b/internal/hooks/auth_hooks_test.go @@ -1,6 +1,7 @@ package hooks import ( + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/supabase/gotrue/internal/conf" "testing" @@ -15,12 +16,12 @@ func TestHooks(t *testing.T) { suite.Run(t, ts) } -func TestFetchHookName(ts *testing.T) { +func (ts *HookTestSuite) TestFetchHookName() { cases := []struct { desc string uri string expectedResult string - expectedError string // Using string to represent the error for simplicity + expectedError string }{ // Positive test cases {desc: "Valid URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectedResult: "auth.verification_hook_reject", expectedError: ""}, @@ -30,18 +31,19 @@ func TestFetchHookName(ts *testing.T) { {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectedResult: "", expectedError: "invalid schema name: 123auth"}, {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectedResult: "", expectedError: "invalid table name: 123verification_hook_reject"}, {desc: "Insufficient Path Parts", uri: "pg-functions://postgres/auth", expectedResult: "", expectedError: "URI path does not contain enough parts"}, - // {desc: "Incorrect URI Format", uri: "http://postgres/auth/verification_hook_reject", expectedResult: "", expectedError: "invalid URI format"}, } for _, tc := range cases { - ts.Run(tc.desc, func(t *testing.T) { + ts.T().Run(tc.desc, func(t *testing.T) { ep := conf.ExtensibilityPointConfiguration{URI: tc.uri} result, err := FetchHookName(ep) - if err != nil && err.Error() != tc.expectedError { - t.Errorf("Test %s failed: expected error %v, got %v", tc.desc, tc.expectedError, err) - } - if result != tc.expectedResult { - t.Errorf("Test %s failed: expected result %v, got %v", tc.desc, tc.expectedResult, result) + if tc.expectedError == "" { + require.NoError(t, err) + require.Equal(t, tc.expectedResult, result) + } else { + require.Error(t, err) + require.EqualError(t, err, tc.expectedError) + require.Empty(t, result) } }) } From b9d6ba7c2df5792bca956251d7495a29ec6acbac Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Mon, 27 Nov 2023 18:59:30 +0800 Subject: [PATCH 23/42] feat: add local timeout and more tests --- internal/api/mfa.go | 5 ++++ internal/api/mfa_test.go | 45 +++++++++++++++++++++++++++++++++--- internal/hooks/auth_hooks.go | 5 ++++ 3 files changed, 52 insertions(+), 3 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 3be0b54b9a..754ef5ef77 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -213,6 +213,11 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { return err } if err := a.db.Transaction(func(tx *storage.Connection) error { + + timeoutQuery := tx.RawQuery(fmt.Sprintf("SET LOCAL statement_timeout TO '%d';", hooks.DefaultTimeout)) + if terr := timeoutQuery.Exec(); terr != nil { + return terr + } query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), payload) terr := query.First(&response) if terr != nil { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index c75703415c..647cea3fc8 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -616,13 +616,52 @@ func (ts *MFATestSuite) TestVerificationHookReject() { func (ts *MFATestSuite) TestVerificationHookError() { ts.Config.Hook.MFAVerificationAttempt.Enabled = true ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/public/test_verification_hook_error" - // TODO + + errorHookSQL := ` + create or replace function test_verification_hook_error(input jsonb) + returns json as $$ + begin + RAISE EXCEPTION 'Intentional Error for Testing'; + end; + $$ language plpgsql;` + err := ts.API.db.RawQuery(errorHookSQL).Exec() + require.NoError(ts.T(), err) + email := "testemail_error@gmail.com" + password := "testpassword" + resp := signUpAndVerify(ts, email, password, false /* No guarantee of success */) + // TODO: Convert into proper assetions here instead of nilcheck + require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error + // TODO: Convert this into generic function + cleanupHookSQL := `drop function test_verification_hook_error(input jsonb)` + err = ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(ts.T(), err) } func (ts *MFATestSuite) TestVerificationHookTimeout() { ts.Config.Hook.MFAVerificationAttempt.Enabled = true ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/public/test_verification_hook_timeout" - // Call pg_sleep(10) - // TODO: expect an rror + + timeoutHookSQL := ` + create or replace function test_verification_hook_timeout(input jsonb) + returns json as $$ + begin + PERFORM pg_sleep(10); + return json_build_object( + 'decision', 'continue' + ); + end; + $$ language plpgsql;` + err := ts.API.db.RawQuery(timeoutHookSQL).Exec() + require.NoError(ts.T(), err) + + email := "testemail_error@gmail.com" + password := "testpassword" + resp := signUpAndVerify(ts, email, password, false /* No guarantee of success */) + require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error + + cleanupHookSQL := `drop function test_verification_hook_timeout(input jsonb)` + err = ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(ts.T(), err) + } diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index d672c5308d..c2a5e2fb17 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -8,6 +8,7 @@ import ( "github.com/supabase/gotrue/internal/conf" "regexp" "strings" + "time" ) type HookType string @@ -17,6 +18,10 @@ const ( HTTPHook HookType = "http" ) +const ( + DefaultTimeout = 2 * time.Second +) + // Hook Names const ( From b50b7ed4cbb754ef8751b28b81e4c11a104722eb Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 00:30:42 +0800 Subject: [PATCH 24/42] refactor: cut back on redundant code --- internal/api/mfa.go | 2 +- internal/api/mfa_test.go | 58 +++++++++++++++++++--------------------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 754ef5ef77..574785b56b 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -214,7 +214,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { } if err := a.db.Transaction(func(tx *storage.Connection) error { - timeoutQuery := tx.RawQuery(fmt.Sprintf("SET LOCAL statement_timeout TO '%d';", hooks.DefaultTimeout)) + timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)) if terr := timeoutQuery.Exec(); terr != nil { return terr } diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 647cea3fc8..9b1faa045c 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -384,7 +384,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() { func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { email := "test1@example.com" password := "test123" - token := signUpAndVerify(ts, email, password, true /* Guarantee success */) + token := signUpAndVerify(ts, email, password, true /* <- isSuccessGuaranteed */) ts.Config.Security.RefreshTokenRotationEnabled = true var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ @@ -410,7 +410,7 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() { email := "test1@example.com" password := "test123" - token := signUpAndVerify(ts, email, password, true /* Guarantee Success */) + token := signUpAndVerify(ts, email, password, true /* <- isSuccessGuaranteed */) ts.Config.Security.RefreshTokenRotationEnabled = true var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ @@ -455,10 +455,10 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes return data } -func signUpAndVerify(ts *MFATestSuite, email, password string, guaranteeSuccess bool) (verifyResp *AccessTokenResponse) { +func signUpAndVerify(ts *MFATestSuite, email, password string, isSuccessGuaranteed bool) (verifyResp *AccessTokenResponse) { signUpResp := signUp(ts, email, password) - verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token, guaranteeSuccess) + verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token, isSuccessGuaranteed) return verifyResp @@ -523,10 +523,9 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string, guarante return verifyResp } -// TODO: refactor 4 cases into one long function. Also consider how to cleanup if any fails -func (ts *MFATestSuite) TestVerificationHookSuccess() { +func (ts *MFATestSuite) TestVerificationHookDefaultSuccess() { ts.Config.Hook.MFAVerificationAttempt.Enabled = true - // Pop executes as supabase_auth_admin and only has access to auth + // GoTrue executes as supabase_auth_admin and only has access to the Auth Schema ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook" verificationHookSQL := ` create or replace function verification_hook(input jsonb) @@ -540,10 +539,10 @@ func (ts *MFATestSuite) TestVerificationHookSuccess() { ` email := "testemail@gmail.com" password := "testpassword" - // 3. Execute the SQL to create the function + err := ts.API.db.RawQuery(verificationHookSQL).Exec() require.NoError(ts.T(), err) - token := signUpAndVerify(ts, email, password, true /* Guarantee Success */) + token := signUpAndVerify(ts, email, password, true /* <- isSuccessGuaranteed */) require.NotNil(ts.T(), token) cleanupHookSQL := ` drop function verification_hook(input jsonb) @@ -552,25 +551,25 @@ func (ts *MFATestSuite) TestVerificationHookSuccess() { require.NoError(ts.T(), err) } -func (ts *MFATestSuite) TestVerificationHookReject() { +func (ts *MFATestSuite) TestVerificationHookDefaultReject() { cases := []struct { desc string enabled bool expectedToken bool }{ { - desc: "No token returned when Hook is configured to reject", + desc: "Rejection hook works when enabled", enabled: true, expectedToken: false, }, { - desc: "Token returned when Hook is disabled", + desc: "Rejection hook has no effect when disabled", enabled: false, expectedToken: true, }, } - verificationError := "authentication attempt rejected" + defaultVerificationErrorMessage := "authentication attempt rejected" for _, c := range cases { ts.T().Run(c.desc, func(t *testing.T) { @@ -591,31 +590,27 @@ func (ts *MFATestSuite) TestVerificationHookReject() { ); end; $$ language plpgsql; - `, verificationError) + `, defaultVerificationErrorMessage) err := ts.API.db.RawQuery(verificationHookSQL).Exec() require.NoError(t, err) } - resp := signUpAndVerify(ts, email, password, false /* Guarantee Success */) + resp := signUpAndVerify(ts, email, password, false /* <- isSuccessGuaranteed */) if c.expectedToken { require.NotEqual(t, "", resp.Token) } else { require.Equal(t, "", resp.Token) } - if c.enabled { - cleanupHookSQL := `drop function verification_hook_reject(input jsonb)` - err := ts.API.db.RawQuery(cleanupHookSQL).Exec() - require.NoError(t, err) - } + cleanupHook(ts, "verification_hook_reject(input jsonb)") }) } } func (ts *MFATestSuite) TestVerificationHookError() { ts.Config.Hook.MFAVerificationAttempt.Enabled = true - ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/public/test_verification_hook_error" + ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/test_verification_hook_error" errorHookSQL := ` create or replace function test_verification_hook_error(input jsonb) @@ -628,19 +623,15 @@ func (ts *MFATestSuite) TestVerificationHookError() { require.NoError(ts.T(), err) email := "testemail_error@gmail.com" password := "testpassword" - resp := signUpAndVerify(ts, email, password, false /* No guarantee of success */) + resp := signUpAndVerify(ts, email, password, false /* <- isSuccessGuaranteed */) // TODO: Convert into proper assetions here instead of nilcheck require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error - // TODO: Convert this into generic function - cleanupHookSQL := `drop function test_verification_hook_error(input jsonb)` - err = ts.API.db.RawQuery(cleanupHookSQL).Exec() - require.NoError(ts.T(), err) - + cleanupHook(ts, "test_verification_hook_error(input jsonb)") } func (ts *MFATestSuite) TestVerificationHookTimeout() { ts.Config.Hook.MFAVerificationAttempt.Enabled = true - ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/public/test_verification_hook_timeout" + ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/test_verification_hook_timeout" timeoutHookSQL := ` create or replace function test_verification_hook_timeout(input jsonb) @@ -657,11 +648,16 @@ func (ts *MFATestSuite) TestVerificationHookTimeout() { email := "testemail_error@gmail.com" password := "testpassword" - resp := signUpAndVerify(ts, email, password, false /* No guarantee of success */) + resp := signUpAndVerify(ts, email, password, false /* <- isSuccessGuaranteed */) require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error - cleanupHookSQL := `drop function test_verification_hook_timeout(input jsonb)` - err = ts.API.db.RawQuery(cleanupHookSQL).Exec() + cleanupHook(ts, "test_verification_hook_timeout(input jsonb)") + +} + +func cleanupHook(ts *MFATestSuite, hookName string) { + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", hookName) + err := ts.API.db.RawQuery(cleanupHookSQL).Exec() require.NoError(ts.T(), err) } From 651151efc76298058123845b0072d6a9eb776cae Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 00:56:56 +0800 Subject: [PATCH 25/42] fix: patch tests --- internal/api/mfa_test.go | 2 +- internal/hooks/auth_hooks.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 9b1faa045c..6fb2bc1055 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -637,7 +637,7 @@ func (ts *MFATestSuite) TestVerificationHookTimeout() { create or replace function test_verification_hook_timeout(input jsonb) returns json as $$ begin - PERFORM pg_sleep(10); + PERFORM pg_sleep(3); return json_build_object( 'decision', 'continue' ); diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index c2a5e2fb17..ec08bcfa68 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -19,7 +19,8 @@ const ( ) const ( - DefaultTimeout = 2 * time.Second + // In Miliseconds + DefaultTimeout = 2000 ) // Hook Names From f9fa25e7a0ed8b473f625747d9f347f70764474d Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 01:19:34 +0800 Subject: [PATCH 26/42] fix: partial conversion to use Auth Hook structs --- internal/api/mfa.go | 22 ++++++++++++++-------- internal/hooks/auth_hooks.go | 9 ++++++--- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 574785b56b..e3a758fef6 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -13,7 +13,6 @@ import ( svg "github.com/ajstarks/svgo" "github.com/boombuler/barcode/qr" "github.com/gofrs/uuid" - "github.com/pkg/errors" "github.com/pquerna/otp/totp" "github.com/supabase/gotrue/internal/hooks" "github.com/supabase/gotrue/internal/metering" @@ -200,7 +199,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } -func (a *API) invokeHook(ctx context.Context, input any, output any) error { +func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.AuthHookError { var response []byte switch input.(type) { case hooks.MFAVerificationAttemptInput: @@ -210,10 +209,11 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { } hookName, err := hooks.FetchHookName(a.config.Hook.MFAVerificationAttempt) if err != nil { - return err + return &hooks.AuthHookError{ + Message: "invalid hook name", + } } if err := a.db.Transaction(func(tx *storage.Connection) error { - timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)) if terr := timeoutQuery.Exec(); terr != nil { return terr @@ -225,20 +225,26 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) error { } return nil }); err != nil { - return err + return &hooks.AuthHookError{ + Message: err.Error(), + } } hookResponseOrError := hooks.AuthHookErrorResponse{} err = json.Unmarshal(response, &hookResponseOrError) if err == nil && hookResponseOrError.IsError() { - return err + return &hookResponseOrError.AuthHookError } if err = json.Unmarshal(response, output); err != nil { - return err + return &hooks.AuthHookError{ + Message: "error unmarshalling response", + } } return nil default: - return errors.New("invalid hook extensibility point") + return &hooks.AuthHookError{ + Message: "invalid extensibility point", + } } } func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index ec08bcfa68..37e7b571ae 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -8,7 +8,6 @@ import ( "github.com/supabase/gotrue/internal/conf" "regexp" "strings" - "time" ) type HookType string @@ -41,13 +40,17 @@ type MFAVerificationAttemptOutput struct { Message string `json:"message"` } -// AuthHookError is an error with a message and an HTTP status code. +// AuthHookError is an error with a message and an error code. type AuthHookError struct { - Code int `json:"code"` + Code string `json:"code"` Message string `json:"msg"` ErrorID string `json:"error_id,omitempty"` } +func (a *AuthHookError) Error() string { + return fmt.Sprintf("%s: %s", a.Code, a.Message) +} + // Hook Events const ( MFAVerificationAttempt = "auth.mfa_verfication" From 47504b09b9a98f5735045d22428c395bb1eb7260 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 01:23:38 +0800 Subject: [PATCH 27/42] feat: add initial Error Message --- internal/api/mfa.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index e3a758fef6..1a3b114562 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "net/http" @@ -306,7 +307,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output := &hooks.MFAVerificationAttemptOutput{} if err := a.invokeHook(ctx, input, output); err != nil { - return err + return errors.New(err.Error()) } if terr := models.NewAuditLogEntry(r, a.db, user, models.InvokeAuthHookAction, r.RemoteAddr, map[string]interface{}{ "extensibility_point_event": hooks.MFAVerificationAttempt, From fa1cbdee1bcdc8a2d1326994038cc2e6846c47dd Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 01:39:41 +0800 Subject: [PATCH 28/42] refactor: rename some vars --- internal/api/mfa_test.go | 16 ++++++++-------- internal/hooks/auth_hooks.go | 1 - 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 6fb2bc1055..f06ed41255 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -384,7 +384,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() { func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { email := "test1@example.com" password := "test123" - token := signUpAndVerify(ts, email, password, true /* <- isSuccessGuaranteed */) + token := signUpAndVerify(ts, email, password, true /* <- requireStatusOK */) ts.Config.Security.RefreshTokenRotationEnabled = true var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ @@ -410,7 +410,7 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() { email := "test1@example.com" password := "test123" - token := signUpAndVerify(ts, email, password, true /* <- isSuccessGuaranteed */) + token := signUpAndVerify(ts, email, password, true /* <- requireStatusOK */) ts.Config.Security.RefreshTokenRotationEnabled = true var buffer bytes.Buffer require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ @@ -455,10 +455,10 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes return data } -func signUpAndVerify(ts *MFATestSuite, email, password string, isSuccessGuaranteed bool) (verifyResp *AccessTokenResponse) { +func signUpAndVerify(ts *MFATestSuite, email, password string, requireStatusOK bool) (verifyResp *AccessTokenResponse) { signUpResp := signUp(ts, email, password) - verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token, isSuccessGuaranteed) + verifyResp = enrollAndVerify(ts, signUpResp.User, signUpResp.Token, requireStatusOK) return verifyResp @@ -542,7 +542,7 @@ func (ts *MFATestSuite) TestVerificationHookDefaultSuccess() { err := ts.API.db.RawQuery(verificationHookSQL).Exec() require.NoError(ts.T(), err) - token := signUpAndVerify(ts, email, password, true /* <- isSuccessGuaranteed */) + token := signUpAndVerify(ts, email, password, true /* <- requireStatusOK */) require.NotNil(ts.T(), token) cleanupHookSQL := ` drop function verification_hook(input jsonb) @@ -596,7 +596,7 @@ func (ts *MFATestSuite) TestVerificationHookDefaultReject() { require.NoError(t, err) } - resp := signUpAndVerify(ts, email, password, false /* <- isSuccessGuaranteed */) + resp := signUpAndVerify(ts, email, password, false /* <- requireStatusOK */) if c.expectedToken { require.NotEqual(t, "", resp.Token) } else { @@ -623,7 +623,7 @@ func (ts *MFATestSuite) TestVerificationHookError() { require.NoError(ts.T(), err) email := "testemail_error@gmail.com" password := "testpassword" - resp := signUpAndVerify(ts, email, password, false /* <- isSuccessGuaranteed */) + resp := signUpAndVerify(ts, email, password, false /* <- requireStatusOK */) // TODO: Convert into proper assetions here instead of nilcheck require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error cleanupHook(ts, "test_verification_hook_error(input jsonb)") @@ -648,7 +648,7 @@ func (ts *MFATestSuite) TestVerificationHookTimeout() { email := "testemail_error@gmail.com" password := "testpassword" - resp := signUpAndVerify(ts, email, password, false /* <- isSuccessGuaranteed */) + resp := signUpAndVerify(ts, email, password, false /* <- requireStatusOK */) require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error cleanupHook(ts, "test_verification_hook_timeout(input jsonb)") diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 37e7b571ae..0c397c0a0c 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -40,7 +40,6 @@ type MFAVerificationAttemptOutput struct { Message string `json:"message"` } -// AuthHookError is an error with a message and an error code. type AuthHookError struct { Code string `json:"code"` Message string `json:"msg"` From c4470d7678821b5cd5f78606808d6ef30a9c9454 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 01:43:41 +0800 Subject: [PATCH 29/42] chore: small comments --- internal/api/mfa.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 1a3b114562..6544e1926e 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -215,6 +215,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth } } if err := a.db.Transaction(func(tx *storage.Connection) error { + // We rely on Postgres timeouts to ensure the function doesn't overrun timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)) if terr := timeoutQuery.Exec(); terr != nil { return terr @@ -230,6 +231,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth Message: err.Error(), } } + // As we the response fields aren't known to us we try to check if it's an error first. hookResponseOrError := hooks.AuthHookErrorResponse{} err = json.Unmarshal(response, &hookResponseOrError) if err == nil && hookResponseOrError.IsError() { From 2842b235878916614bf2e7f5bc4a26495d7f7f18 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 17:26:53 +0800 Subject: [PATCH 30/42] chore: add a default message --- internal/api/mfa.go | 19 +++++++------------ internal/hooks/auth_hooks.go | 11 +++++++++++ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 6544e1926e..e01f1631b7 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -210,9 +210,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth } hookName, err := hooks.FetchHookName(a.config.Hook.MFAVerificationAttempt) if err != nil { - return &hooks.AuthHookError{ - Message: "invalid hook name", - } + return hooks.HookError("invalid hook name") } if err := a.db.Transaction(func(tx *storage.Connection) error { // We rely on Postgres timeouts to ensure the function doesn't overrun @@ -227,9 +225,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth } return nil }); err != nil { - return &hooks.AuthHookError{ - Message: err.Error(), - } + return hooks.HookError(err.Error()) } // As we the response fields aren't known to us we try to check if it's an error first. hookResponseOrError := hooks.AuthHookErrorResponse{} @@ -238,16 +234,12 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth return &hookResponseOrError.AuthHookError } if err = json.Unmarshal(response, output); err != nil { - return &hooks.AuthHookError{ - Message: "error unmarshalling response", - } + return hooks.HookError("error unmarshalling response") } return nil default: - return &hooks.AuthHookError{ - Message: "invalid extensibility point", - } + return hooks.HookError("invalid extensibility point") } } func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { @@ -322,6 +314,9 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if err := models.Logout(a.db, user.ID); err != nil { return err } + if output.Message == "" { + output.Message = hooks.DefaultMFAHookRejectionMessage + } return forbiddenError(output.Message) } } diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 0c397c0a0c..6b4ba7fb66 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -55,10 +55,21 @@ const ( MFAVerificationAttempt = "auth.mfa_verfication" ) +const ( + DefaultMFAHookRejectionMessage = "mfa attempt rejected" +) + type AuthHookErrorResponse struct { AuthHookError } +func HookError(message string, args ...interface{}) *AuthHookError { + return &AuthHookError{ + Message: fmt.Sprintf(message, args...), + } + +} + func (hookError *AuthHookErrorResponse) IsError() bool { return hookError.Message != "" } From 2260f963f89cb48cfaa83dae45992c4c333e7171 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 23:11:15 +0800 Subject: [PATCH 31/42] refactor: remove excess code --- internal/api/mfa.go | 11 ++----- internal/hooks/auth_hooks_test.go | 50 ------------------------------ internal/models/audit_log_entry.go | 3 -- 3 files changed, 2 insertions(+), 62 deletions(-) delete mode 100644 internal/hooks/auth_hooks_test.go diff --git a/internal/api/mfa.go b/internal/api/mfa.go index e01f1631b7..5bc8261ba6 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -239,7 +239,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth return nil default: - return hooks.HookError("invalid extensibility point") + panic("invalid extensibility point") } } func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { @@ -292,6 +292,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } valid := totp.Validate(params.Code, factor.Secret) + if config.Hook.MFAVerificationAttempt.Enabled { input := hooks.MFAVerificationAttemptInput{ UserID: user.ID, @@ -299,17 +300,9 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { Valid: valid, } output := &hooks.MFAVerificationAttemptOutput{} - if err := a.invokeHook(ctx, input, output); err != nil { return errors.New(err.Error()) } - if terr := models.NewAuditLogEntry(r, a.db, user, models.InvokeAuthHookAction, r.RemoteAddr, map[string]interface{}{ - "extensibility_point_event": hooks.MFAVerificationAttempt, - "factor_id": factor.ID, - "URI": config.Hook.MFAVerificationAttempt.URI, - }); terr != nil { - return terr - } if output.Decision == hooks.MFAHookRejection { if err := models.Logout(a.db, user.ID); err != nil { return err diff --git a/internal/hooks/auth_hooks_test.go b/internal/hooks/auth_hooks_test.go deleted file mode 100644 index 4baf457b28..0000000000 --- a/internal/hooks/auth_hooks_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package hooks - -import ( - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" - "github.com/supabase/gotrue/internal/conf" - "testing" -) - -type HookTestSuite struct { - suite.Suite -} - -func TestHooks(t *testing.T) { - ts := &HookTestSuite{} - suite.Run(t, ts) -} - -func (ts *HookTestSuite) TestFetchHookName() { - cases := []struct { - desc string - uri string - expectedResult string - expectedError string - }{ - // Positive test cases - {desc: "Valid URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectedResult: "auth.verification_hook_reject", expectedError: ""}, - {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectedResult: "user_management.add_user", expectedError: ""}, - - // Negative test cases - {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectedResult: "", expectedError: "invalid schema name: 123auth"}, - {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectedResult: "", expectedError: "invalid table name: 123verification_hook_reject"}, - {desc: "Insufficient Path Parts", uri: "pg-functions://postgres/auth", expectedResult: "", expectedError: "URI path does not contain enough parts"}, - } - - for _, tc := range cases { - ts.T().Run(tc.desc, func(t *testing.T) { - ep := conf.ExtensibilityPointConfiguration{URI: tc.uri} - result, err := FetchHookName(ep) - if tc.expectedError == "" { - require.NoError(t, err) - require.Equal(t, tc.expectedResult, result) - } else { - require.Error(t, err) - require.EqualError(t, err, tc.expectedError) - require.Empty(t, result) - } - }) - } -} diff --git a/internal/models/audit_log_entry.go b/internal/models/audit_log_entry.go index 8772baf757..74ad3d6668 100644 --- a/internal/models/audit_log_entry.go +++ b/internal/models/audit_log_entry.go @@ -40,7 +40,6 @@ const ( DeleteRecoveryCodesAction AuditAction = "recovery_codes_deleted" UpdateFactorAction AuditAction = "factor_updated" MFACodeLoginAction AuditAction = "mfa_code_login" - InvokeAuthHookAction AuditAction = "auth_hook_invoked" IdentityUnlinkAction AuditAction = "identity_unlinked" account auditLogType = "account" @@ -49,7 +48,6 @@ const ( user auditLogType = "user" factor auditLogType = "factor" recoveryCodes auditLogType = "recovery_codes" - authHook auditLogType = "auth_hook" ) var ActionLogTypeMap = map[AuditAction]auditLogType{ @@ -75,7 +73,6 @@ var ActionLogTypeMap = map[AuditAction]auditLogType{ UpdateFactorAction: factor, MFACodeLoginAction: factor, DeleteRecoveryCodesAction: recoveryCodes, - InvokeAuthHookAction: authHook, } // AuditLogEntry is the database model for audit log entries. From ddae946061e7b8b49727270c52fce43f59ad073c Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Tue, 28 Nov 2023 23:58:33 +0800 Subject: [PATCH 32/42] chore: convert fetchhookname into configuration load --- internal/api/mfa.go | 13 +++++++++--- internal/conf/configuration.go | 34 ++++++++++++++++++++++++++++++-- internal/hooks/auth_hooks.go | 36 ---------------------------------- 3 files changed, 42 insertions(+), 41 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 5bc8261ba6..f9f91173e6 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -7,8 +7,8 @@ import ( "errors" "fmt" "net/http" - "net/url" + "strings" "github.com/aaronarduino/goqrsvg" svg "github.com/ajstarks/svgo" @@ -208,10 +208,17 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth if err != nil { panic(err) } - hookName, err := hooks.FetchHookName(a.config.Hook.MFAVerificationAttempt) + + // TODO: maybe populate this on Config load instead + u, err := url.Parse(a.config.Hook.MFAVerificationAttempt.URI) if err != nil { - return hooks.HookError("invalid hook name") + return hooks.HookError(err.Error()) } + pathParts := strings.Split(u.Path, "/") + schema := pathParts[1] + table := pathParts[2] + hookName := fmt.Sprintf("%s.%s", schema, table) + if err := a.db.Transaction(func(tx *storage.Connection) error { // We rely on Postgres timeouts to ensure the function doesn't overrun timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 2c3c52cbbe..187b1c652a 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "regexp" "strings" "text/template" "time" @@ -390,15 +391,43 @@ type ExtensibilityPointConfiguration struct { Enabled bool `json:"true"` } +func (h *HookConfiguration) Validate() error { + points := []ExtensibilityPointConfiguration{ + h.MFAVerificationAttempt, + } + for _, point := range points { + if err := point.ValidateExtensibilityPoint(); err != nil { + return err + } + } + return nil +} + func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { if e.URI != "" { + regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` + re, err := regexp.Compile(regExp) + if err != nil { + return err + } + u, err := url.Parse(e.URI) if err != nil { - return errors.New("hook entry should be a valid URI") + return err } - if pathParts := strings.Split(u.Path, "/"); len(pathParts) < 3 { + pathParts := strings.Split(u.Path, "/") + if len(pathParts) < 3 { return fmt.Errorf("URI path does not contain enough parts") } + schema := pathParts[1] + table := pathParts[2] + // Validate schema and table names + if !re.MatchString(schema) { + return fmt.Errorf("invalid schema name: %s", schema) + } + if !re.MatchString(table) { + return fmt.Errorf("invalid table name: %s", table) + } } return nil } @@ -569,6 +598,7 @@ func (c *GlobalConfiguration) Validate() error { &c.SAML, &c.Security, &c.Sessions, + &c.Hook, } for _, validatable := range validatables { diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 6b4ba7fb66..22d4f3d084 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -1,20 +1,14 @@ package hooks import ( - "net/url" - "fmt" "github.com/gofrs/uuid" - "github.com/supabase/gotrue/internal/conf" - "regexp" - "strings" ) type HookType string const ( PostgresHook HookType = "pg-functions" - HTTPHook HookType = "http" ) const ( @@ -23,7 +17,6 @@ const ( ) // Hook Names - const ( MFAHookRejection = "reject" MFAHookContinue = "continue" @@ -73,32 +66,3 @@ func HookError(message string, args ...interface{}) *AuthHookError { func (hookError *AuthHookErrorResponse) IsError() bool { return hookError.Message != "" } - -func FetchHookName(ep conf.ExtensibilityPointConfiguration) (string, error) { - // specification for Postgres names - regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` - re, err := regexp.Compile(regExp) - if err != nil { - return "", err - } - - u, err := url.Parse(ep.URI) - if err != nil { - return "", err - } - pathParts := strings.Split(u.Path, "/") - if len(pathParts) < 3 { - return "", fmt.Errorf("URI path does not contain enough parts") - } - schema := pathParts[1] - table := pathParts[2] - // Validate schema and table names - if !re.MatchString(schema) { - return "", fmt.Errorf("invalid schema name: %s", schema) - } - if !re.MatchString(table) { - return "", fmt.Errorf("invalid table name: %s", table) - } - - return schema + "." + table, nil -} From 25f95f45ae66c05e69d979936cbe9ebff77e123d Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Wed, 29 Nov 2023 00:10:16 +0800 Subject: [PATCH 33/42] fix: light refactor of tests --- internal/api/mfa_test.go | 207 ++++++++++++++++++--------------------- 1 file changed, 95 insertions(+), 112 deletions(-) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index f06ed41255..b301cccfaf 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -523,138 +523,121 @@ func enrollAndVerify(ts *MFATestSuite, user *models.User, token string, guarante return verifyResp } -func (ts *MFATestSuite) TestVerificationHookDefaultSuccess() { - ts.Config.Hook.MFAVerificationAttempt.Enabled = true - // GoTrue executes as supabase_auth_admin and only has access to the Auth Schema - ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook" - verificationHookSQL := ` - create or replace function verification_hook(input jsonb) - returns json as $$ - begin - return json_build_object( - 'decision', 'continue' - ); - end; - $$ language plpgsql; - ` - email := "testemail@gmail.com" - password := "testpassword" - - err := ts.API.db.RawQuery(verificationHookSQL).Exec() - require.NoError(ts.T(), err) - token := signUpAndVerify(ts, email, password, true /* <- requireStatusOK */) - require.NotNil(ts.T(), token) - cleanupHookSQL := ` - drop function verification_hook(input jsonb) - ` - err = ts.API.db.RawQuery(cleanupHookSQL).Exec() - require.NoError(ts.T(), err) -} - -func (ts *MFATestSuite) TestVerificationHookDefaultReject() { - cases := []struct { - desc string - enabled bool - expectedToken bool - }{ +func (ts *MFATestSuite) TestVerificationHooks() { + type verificationHookTestCase struct { + desc string + enabled bool + uri string + hookFunctionSQL string + emailSuffix string + expectToken bool + cleanupHookFunction string + } + cases := []verificationHookTestCase{ + { + desc: "Default Success", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook", + hookFunctionSQL: ` + create or replace function verification_hook(input jsonb) + returns json as $$ + begin + return json_build_object('decision', 'continue'); + end; $$ language plpgsql;`, + emailSuffix: "success", + expectToken: true, + cleanupHookFunction: "verification_hook(input jsonb)", + }, + { + desc: "Error", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_error", + hookFunctionSQL: ` + create or replace function test_verification_hook_error(input jsonb) + returns json as $$ + begin + RAISE EXCEPTION 'Intentional Error for Testing'; + end; $$ language plpgsql;`, + emailSuffix: "error", + expectToken: false, + cleanupHookFunction: "test_verification_hook_error(input jsonb)", + }, + { + desc: "Reject - Enabled", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_enabled", + expectToken: false, + cleanupHookFunction: "verification_hook_reject(input jsonb)", + }, { - desc: "Rejection hook works when enabled", - enabled: true, - expectedToken: false, + desc: "Reject - Disabled", + enabled: false, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_disabled", + expectToken: true, + cleanupHookFunction: "verification_hook_reject(input jsonb)", }, { - desc: "Rejection hook has no effect when disabled", - enabled: false, - expectedToken: true, + desc: "Timeout", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_timeout", + hookFunctionSQL: ` + create or replace function test_verification_hook_timeout(input jsonb) + returns json as $$ + begin + PERFORM pg_sleep(3); + return json_build_object( + 'decision', 'continue' + ); + end; $$ language plpgsql;`, + emailSuffix: "timeout", + expectToken: false, // Assuming the test expects no token due to timeout + cleanupHookFunction: "test_verification_hook_timeout(input jsonb)", }, } - defaultVerificationErrorMessage := "authentication attempt rejected" - for _, c := range cases { ts.T().Run(c.desc, func(t *testing.T) { ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled - // To ensure distinct emails - email := fmt.Sprintf("testemail%s@gmail.com", strings.ReplaceAll(c.desc, " ", "")) - password := "testpassword" + ts.Config.Hook.MFAVerificationAttempt.URI = c.uri - if c.enabled { - ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/verification_hook_reject" - verificationHookSQL := fmt.Sprintf(` - create or replace function verification_hook_reject(input jsonb) - returns json as $$ - begin - return json_build_object( - 'decision', 'reject', - 'message', '%s' - ); - end; - $$ language plpgsql; - `, defaultVerificationErrorMessage) - - err := ts.API.db.RawQuery(verificationHookSQL).Exec() - require.NoError(t, err) - } + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + + email := fmt.Sprintf("testemail_%s@gmail.com", c.emailSuffix) + password := "testpassword" + resp := signUpAndVerify(ts, email, password, c.expectToken) - resp := signUpAndVerify(ts, email, password, false /* <- requireStatusOK */) - if c.expectedToken { + if c.expectToken { require.NotEqual(t, "", resp.Token) } else { require.Equal(t, "", resp.Token) } - cleanupHook(ts, "verification_hook_reject(input jsonb)") + cleanupHook(ts, c.cleanupHookFunction) }) } } -func (ts *MFATestSuite) TestVerificationHookError() { - ts.Config.Hook.MFAVerificationAttempt.Enabled = true - ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/test_verification_hook_error" - - errorHookSQL := ` - create or replace function test_verification_hook_error(input jsonb) - returns json as $$ - begin - RAISE EXCEPTION 'Intentional Error for Testing'; - end; - $$ language plpgsql;` - err := ts.API.db.RawQuery(errorHookSQL).Exec() - require.NoError(ts.T(), err) - email := "testemail_error@gmail.com" - password := "testpassword" - resp := signUpAndVerify(ts, email, password, false /* <- requireStatusOK */) - // TODO: Convert into proper assetions here instead of nilcheck - require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error - cleanupHook(ts, "test_verification_hook_error(input jsonb)") -} - -func (ts *MFATestSuite) TestVerificationHookTimeout() { - ts.Config.Hook.MFAVerificationAttempt.Enabled = true - ts.Config.Hook.MFAVerificationAttempt.URI = "pg-functions://postgres/auth/test_verification_hook_timeout" - - timeoutHookSQL := ` - create or replace function test_verification_hook_timeout(input jsonb) - returns json as $$ - begin - PERFORM pg_sleep(3); - return json_build_object( - 'decision', 'continue' - ); - end; - $$ language plpgsql;` - err := ts.API.db.RawQuery(timeoutHookSQL).Exec() - require.NoError(ts.T(), err) - - email := "testemail_error@gmail.com" - password := "testpassword" - resp := signUpAndVerify(ts, email, password, false /* <- requireStatusOK */) - require.Equal(ts.T(), "", resp.Token) // Assuming that the token is nil on error - - cleanupHook(ts, "test_verification_hook_timeout(input jsonb)") - -} - func cleanupHook(ts *MFATestSuite, hookName string) { cleanupHookSQL := fmt.Sprintf("drop function if exists %s", hookName) err := ts.API.db.RawQuery(cleanupHookSQL).Exec() From af2c255a08da1e4e42b6634119ed3a22ad8973ec Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Wed, 29 Nov 2023 16:19:07 +0800 Subject: [PATCH 34/42] fix: add status code check --- internal/api/mfa_test.go | 13 ++++++++----- internal/conf/configuration.go | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 054d1f3d4a..4079eb5a2f 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -448,7 +448,6 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes } func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requireStatusOK bool) *httptest.ResponseRecorder { - signUpResp := signUp(ts, email, password) resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token, requireStatusOK) @@ -456,7 +455,6 @@ func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requir } - func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer string, expectedCode int) *httptest.ResponseRecorder { var buffer bytes.Buffer w := httptest.NewRecorder() @@ -541,6 +539,7 @@ func (ts *MFATestSuite) TestVerificationHooks() { hookFunctionSQL string emailSuffix string expectToken bool + expectedCode int cleanupHookFunction string } cases := []verificationHookTestCase{ @@ -556,6 +555,7 @@ func (ts *MFATestSuite) TestVerificationHooks() { end; $$ language plpgsql;`, emailSuffix: "success", expectToken: true, + expectedCode: http.StatusOK, cleanupHookFunction: "verification_hook(input jsonb)", }, { @@ -570,6 +570,7 @@ func (ts *MFATestSuite) TestVerificationHooks() { end; $$ language plpgsql;`, emailSuffix: "error", expectToken: false, + expectedCode: http.StatusInternalServerError, cleanupHookFunction: "test_verification_hook_error(input jsonb)", }, { @@ -587,6 +588,7 @@ func (ts *MFATestSuite) TestVerificationHooks() { end; $$ language plpgsql;`, emailSuffix: "reject_enabled", expectToken: false, + expectedCode: http.StatusForbidden, cleanupHookFunction: "verification_hook_reject(input jsonb)", }, { @@ -604,6 +606,7 @@ func (ts *MFATestSuite) TestVerificationHooks() { end; $$ language plpgsql;`, emailSuffix: "reject_disabled", expectToken: true, + expectedCode: http.StatusOK, cleanupHookFunction: "verification_hook_reject(input jsonb)", }, { @@ -620,7 +623,8 @@ func (ts *MFATestSuite) TestVerificationHooks() { ); end; $$ language plpgsql;`, emailSuffix: "timeout", - expectToken: false, // Assuming the test expects no token due to timeout + expectToken: false, + expectedCode: http.StatusInternalServerError, cleanupHookFunction: "test_verification_hook_timeout(input jsonb)", }, } @@ -636,12 +640,11 @@ func (ts *MFATestSuite) TestVerificationHooks() { email := fmt.Sprintf("testemail_%s@gmail.com", c.emailSuffix) password := "testpassword" resp := performTestSignupAndVerify(ts, email, password, c.expectToken) + require.Equal(ts.T(), c.expectedCode, resp.Code) accessTokenResp := &AccessTokenResponse{} require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) - if c.expectToken { - require.NotEqual(t, "", accessTokenResp.Token) } else { require.Equal(t, "", accessTokenResp.Token) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 635413571d..3426f13bcb 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -194,7 +194,7 @@ type GlobalConfiguration struct { Sms SmsProviderConfiguration `json:"sms"` DisableSignup bool `json:"disable_signup" split_words:"true"` Webhook WebhookConfig `json:"webhook" split_words:"true"` - Hook HookConfiguration `json:"hook" split_words:"true"` + Hook HookConfiguration `json:"hook" split_words:"true"` Security SecurityConfiguration `json:"security"` Sessions SessionsConfiguration `json:"sessions"` MFA MFAConfiguration `json:"MFA"` From 7b874b85243190cc5234f6fdc79bf40444982f19 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Thu, 30 Nov 2023 01:33:42 +0800 Subject: [PATCH 35/42] refactor: use errors --- internal/api/mfa.go | 30 ++++++++++++++++-------------- internal/hooks/auth_hooks.go | 17 +++++++++++++---- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index f9f91173e6..93f35ed44b 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -200,7 +200,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } -func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.AuthHookError { +func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput) (hooks.HookOutput, *hooks.AuthHookError) { var response []byte switch input.(type) { case hooks.MFAVerificationAttemptInput: @@ -212,7 +212,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth // TODO: maybe populate this on Config load instead u, err := url.Parse(a.config.Hook.MFAVerificationAttempt.URI) if err != nil { - return hooks.HookError(err.Error()) + return nil, hooks.HookError(err.Error()) } pathParts := strings.Split(u.Path, "/") schema := pathParts[1] @@ -232,19 +232,16 @@ func (a *API) invokeHook(ctx context.Context, input any, output any) *hooks.Auth } return nil }); err != nil { - return hooks.HookError(err.Error()) + return nil, hooks.HookError(err.Error()) } - // As we the response fields aren't known to us we try to check if it's an error first. - hookResponseOrError := hooks.AuthHookErrorResponse{} - err = json.Unmarshal(response, &hookResponseOrError) - if err == nil && hookResponseOrError.IsError() { - return &hookResponseOrError.AuthHookError + if err = json.Unmarshal(response, &output); err != nil { + return nil, hooks.HookError(err.Error()) } - if err = json.Unmarshal(response, output); err != nil { - return hooks.HookError("error unmarshalling response") + if output.IsError() { + return nil, hooks.HookError(output.Error()) } - return nil + return output, nil default: panic("invalid extensibility point") } @@ -307,14 +304,19 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { Valid: valid, } output := &hooks.MFAVerificationAttemptOutput{} - if err := a.invokeHook(ctx, input, output); err != nil { + response, err := a.invokeHook(ctx, input, output) + if err != nil { return errors.New(err.Error()) } - if output.Decision == hooks.MFAHookRejection { + mfaOutput, ok := response.(*hooks.MFAVerificationAttemptOutput) + if !ok { + return errors.New("unexpected response type") + } + if mfaOutput.Decision == hooks.MFAHookRejection { if err := models.Logout(a.db, user.ID); err != nil { return err } - if output.Message == "" { + if mfaOutput.Message == "" { output.Message = hooks.DefaultMFAHookRejectionMessage } return forbiddenError(output.Message) diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 22d4f3d084..58b57c31e5 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -22,6 +22,11 @@ const ( MFAHookContinue = "continue" ) +type HookOutput interface { + IsError() bool + Error() string +} + type MFAVerificationAttemptInput struct { UserID uuid.UUID `json:"user_id"` FactorID uuid.UUID `json:"factor_id"` @@ -29,8 +34,9 @@ type MFAVerificationAttemptInput struct { } type MFAVerificationAttemptOutput struct { - Decision string `json:"decision"` - Message string `json:"message"` + Decision string `json:"decision"` + Message string `json:"message"` + HookError AuthHookError `json:"hook_error" split_words:"true"` } type AuthHookError struct { @@ -63,6 +69,9 @@ func HookError(message string, args ...interface{}) *AuthHookError { } -func (hookError *AuthHookErrorResponse) IsError() bool { - return hookError.Message != "" +func (mf *MFAVerificationAttemptOutput) IsError() bool { + return mf.HookError.Message != "" +} +func (mf *MFAVerificationAttemptOutput) Error() string { + return mf.HookError.Message } From 447de5f08e88c634a7d801b0b5c0f8b8529a2530 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Thu, 30 Nov 2023 01:42:35 +0800 Subject: [PATCH 36/42] refactor: shove config back to configuration --- internal/api/mfa.go | 13 +------------ internal/api/mfa_test.go | 1 + internal/conf/configuration.go | 10 ++++++---- 3 files changed, 8 insertions(+), 16 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 93f35ed44b..9b763f270d 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -8,7 +8,6 @@ import ( "fmt" "net/http" "net/url" - "strings" "github.com/aaronarduino/goqrsvg" svg "github.com/ajstarks/svgo" @@ -209,23 +208,13 @@ func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput panic(err) } - // TODO: maybe populate this on Config load instead - u, err := url.Parse(a.config.Hook.MFAVerificationAttempt.URI) - if err != nil { - return nil, hooks.HookError(err.Error()) - } - pathParts := strings.Split(u.Path, "/") - schema := pathParts[1] - table := pathParts[2] - hookName := fmt.Sprintf("%s.%s", schema, table) - if err := a.db.Transaction(func(tx *storage.Connection) error { // We rely on Postgres timeouts to ensure the function doesn't overrun timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)) if terr := timeoutQuery.Exec(); terr != nil { return terr } - query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", hookName), payload) + query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", a.config.Hook.MFAVerificationAttempt.HookName), payload) terr := query.First(&response) if terr != nil { return terr diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 4079eb5a2f..80abfcd3ba 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -633,6 +633,7 @@ func (ts *MFATestSuite) TestVerificationHooks() { ts.T().Run(c.desc, func(t *testing.T) { ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled ts.Config.Hook.MFAVerificationAttempt.URI = c.uri + require.NoError(ts.T(), ts.Config.Hook.MFAVerificationAttempt.ValidateAndPopulateExtensibilityPoint()) err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() require.NoError(t, err) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 3426f13bcb..c7d0cbafa4 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -422,8 +422,9 @@ type HookConfiguration struct { } type ExtensibilityPointConfiguration struct { - URI string `json:"uri"` - Enabled bool `json:"true"` + URI string `json:"uri"` + Enabled bool `json:"enabled"` + HookName string `json:"hook_name"` } func (h *HookConfiguration) Validate() error { @@ -431,14 +432,14 @@ func (h *HookConfiguration) Validate() error { h.MFAVerificationAttempt, } for _, point := range points { - if err := point.ValidateExtensibilityPoint(); err != nil { + if err := point.ValidateAndPopulateExtensibilityPoint(); err != nil { return err } } return nil } -func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { +func (e *ExtensibilityPointConfiguration) ValidateAndPopulateExtensibilityPoint() error { if e.URI != "" { regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` re, err := regexp.Compile(regExp) @@ -463,6 +464,7 @@ func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { if !re.MatchString(table) { return fmt.Errorf("invalid table name: %s", table) } + e.HookName = fmt.Sprintf("%s.%s", schema, table) } return nil } From 18ccfc66ceab29a5e1854d935349f48a693c911e Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Thu, 30 Nov 2023 09:04:08 +0800 Subject: [PATCH 37/42] Update internal/conf/configuration.go Co-authored-by: Stojan Dimitrovski --- internal/conf/configuration.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index c7d0cbafa4..76e906edd5 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -464,7 +464,7 @@ func (e *ExtensibilityPointConfiguration) ValidateAndPopulateExtensibilityPoint( if !re.MatchString(table) { return fmt.Errorf("invalid table name: %s", table) } - e.HookName = fmt.Sprintf("%s.%s", schema, table) + e.HookName = fmt.Sprintf("%q.%q", schema, table) } return nil } From be4fd32f1d4d2df8fe4c1c8328fa9c5c331946c4 Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Thu, 30 Nov 2023 09:04:18 +0800 Subject: [PATCH 38/42] Update internal/hooks/auth_hooks.go Co-authored-by: Stojan Dimitrovski --- internal/hooks/auth_hooks.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 58b57c31e5..7dd4df9fd4 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -55,7 +55,7 @@ const ( ) const ( - DefaultMFAHookRejectionMessage = "mfa attempt rejected" + DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." ) type AuthHookErrorResponse struct { From a15631b6ff243b9913615e0f054392478234fd49 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Thu, 30 Nov 2023 09:34:13 +0800 Subject: [PATCH 39/42] refactor: use error interface instead --- internal/api/mfa.go | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 9b763f270d..2523b66596 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -199,7 +199,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } -func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput) (hooks.HookOutput, *hooks.AuthHookError) { +func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput) error { var response []byte switch input.(type) { case hooks.MFAVerificationAttemptInput: @@ -221,16 +221,16 @@ func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput } return nil }); err != nil { - return nil, hooks.HookError(err.Error()) + return err } if err = json.Unmarshal(response, &output); err != nil { - return nil, hooks.HookError(err.Error()) + return err } if output.IsError() { - return nil, hooks.HookError(output.Error()) + return &output.(*hooks.MFAVerificationAttemptOutput).HookError } - return output, nil + return nil default: panic("invalid extensibility point") } @@ -293,19 +293,16 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { Valid: valid, } output := &hooks.MFAVerificationAttemptOutput{} - response, err := a.invokeHook(ctx, input, output) + err := a.invokeHook(ctx, input, output) if err != nil { return errors.New(err.Error()) } - mfaOutput, ok := response.(*hooks.MFAVerificationAttemptOutput) - if !ok { - return errors.New("unexpected response type") - } - if mfaOutput.Decision == hooks.MFAHookRejection { + + if output.Decision == hooks.MFAHookRejection { if err := models.Logout(a.db, user.ID); err != nil { return err } - if mfaOutput.Message == "" { + if output.Message == "" { output.Message = hooks.DefaultMFAHookRejectionMessage } return forbiddenError(output.Message) From 2d512cbf11e2c5a27663e8fcfc2cb84ac5c645bc Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Thu, 30 Nov 2023 15:07:19 +0800 Subject: [PATCH 40/42] fix: make schema constant --- internal/conf/configuration.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 76e906edd5..8fe1d94343 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -20,6 +20,8 @@ const defaultMinPasswordLength int = 6 const defaultChallengeExpiryDuration float64 = 300 const defaultFlowStateExpiryDuration time.Duration = 300 * time.Second +var postgresNamesRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]{0,62}$`) + // Time is used to represent timestamps in the configuration, as envconfig has // trouble parsing empty strings, due to time.Time.UnmarshalText(). type Time struct { @@ -441,12 +443,6 @@ func (h *HookConfiguration) Validate() error { func (e *ExtensibilityPointConfiguration) ValidateAndPopulateExtensibilityPoint() error { if e.URI != "" { - regExp := `^[a-zA-Z_][a-zA-Z0-9_]{0,62}$` - re, err := regexp.Compile(regExp) - if err != nil { - return err - } - u, err := url.Parse(e.URI) if err != nil { return err @@ -458,10 +454,10 @@ func (e *ExtensibilityPointConfiguration) ValidateAndPopulateExtensibilityPoint( schema := pathParts[1] table := pathParts[2] // Validate schema and table names - if !re.MatchString(schema) { + if !postgresNamesRegexp.MatchString(schema) { return fmt.Errorf("invalid schema name: %s", schema) } - if !re.MatchString(table) { + if !postgresNamesRegexp.MatchString(table) { return fmt.Errorf("invalid table name: %s", table) } e.HookName = fmt.Sprintf("%q.%q", schema, table) From 3c99bdebedf27d0ecaf13b606a1fa4b2fa81af65 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Thu, 30 Nov 2023 17:17:33 +0800 Subject: [PATCH 41/42] test: reinstate test suite --- internal/conf/configuration_test.go | 30 +++++++++++++++++++++++++++++ internal/hooks/auth_hooks.go | 5 ----- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index 8ba46becd6..5d83362ed9 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -97,3 +97,33 @@ func TestPasswordRequiredCharactersDecode(t *testing.T) { require.Equal(t, []string(into), example.Result, "Example %d got unexpected result", i) } } + +func TestValidateAndPopulateExtensibilityPoint(t *testing.T) { + cases := []struct { + desc string + uri string + expectedResult string + }{ + // Positive test cases + {desc: "Valid URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectedResult: `"auth"."verification_hook_reject"`}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectedResult: `"user_management"."add_user"`}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/MySpeCial/FUNCTION_THAT_YELLS_AT_YOU", expectedResult: `"MySpeCial"."FUNCTION_THAT_YELLS_AT_YOU"`}, + + // Negative test cases + {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectedResult: ""}, + {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectedResult: ""}, + {desc: "Insufficient Path Parts", uri: "pg-functions://postgres/auth", expectedResult: ""}, + } + + for _, tc := range cases { + ep := ExtensibilityPointConfiguration{URI: tc.uri} + err := ep.ValidateAndPopulateExtensibilityPoint() + if tc.expectedResult != "" { + require.NoError(t, err) + require.Equal(t, tc.expectedResult, ep.HookName) + } else { + require.Error(t, err) + require.Empty(t, ep.HookName) + } + } +} diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 7dd4df9fd4..19e9c1c6b9 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -49,11 +49,6 @@ func (a *AuthHookError) Error() string { return fmt.Sprintf("%s: %s", a.Code, a.Message) } -// Hook Events -const ( - MFAVerificationAttempt = "auth.mfa_verfication" -) - const ( DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." ) From 3699c0152d4a7988114b4d48448ae5f35396a174 Mon Sep 17 00:00:00 2001 From: "joel@joellee.org" Date: Thu, 30 Nov 2023 17:30:07 +0800 Subject: [PATCH 42/42] fix: remove unused structs --- internal/api/mfa.go | 2 +- internal/hooks/auth_hooks.go | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 2523b66596..3f253f0e15 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -214,7 +214,7 @@ func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput if terr := timeoutQuery.Exec(); terr != nil { return terr } - query := tx.RawQuery(fmt.Sprintf("SELECT * from %s(?)", a.config.Hook.MFAVerificationAttempt.HookName), payload) + query := tx.RawQuery(fmt.Sprintf("SELECT %s(?)", a.config.Hook.MFAVerificationAttempt.HookName), payload) terr := query.First(&response) if terr != nil { return terr diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index 19e9c1c6b9..0c5c87b819 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -53,10 +53,6 @@ const ( DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." ) -type AuthHookErrorResponse struct { - AuthHookError -} - func HookError(message string, args ...interface{}) *AuthHookError { return &AuthHookError{ Message: fmt.Sprintf(message, args...),