diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 57130438c0..3f253f0e15 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -2,9 +2,11 @@ package api import ( "bytes" + "context" "encoding/json" + "errors" + "fmt" "net/http" - "net/url" "github.com/aaronarduino/goqrsvg" @@ -12,6 +14,7 @@ import ( "github.com/boombuler/barcode/qr" "github.com/gofrs/uuid" "github.com/pquerna/otp/totp" + "github.com/supabase/gotrue/internal/hooks" "github.com/supabase/gotrue/internal/metering" "github.com/supabase/gotrue/internal/models" "github.com/supabase/gotrue/internal/storage" @@ -196,6 +199,42 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } +func (a *API) invokeHook(ctx context.Context, input any, output hooks.HookOutput) error { + var response []byte + switch input.(type) { + case hooks.MFAVerificationAttemptInput: + payload, err := json.Marshal(&input) + if err != nil { + panic(err) + } + + if err := a.db.Transaction(func(tx *storage.Connection) error { + // We rely on Postgres timeouts to ensure the function doesn't overrun + timeoutQuery := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)) + if terr := timeoutQuery.Exec(); terr != nil { + return terr + } + query := tx.RawQuery(fmt.Sprintf("SELECT %s(?)", a.config.Hook.MFAVerificationAttempt.HookName), payload) + terr := query.First(&response) + if terr != nil { + return terr + } + return nil + }); err != nil { + return err + } + if err = json.Unmarshal(response, &output); err != nil { + return err + } + if output.IsError() { + return &output.(*hooks.MFAVerificationAttemptOutput).HookError + } + + return nil + default: + panic("invalid extensibility point") + } +} func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { var err error ctx := r.Context() @@ -245,7 +284,31 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) } - if valid := totp.Validate(params.Code, factor.Secret); !valid { + valid := totp.Validate(params.Code, factor.Secret) + + if config.Hook.MFAVerificationAttempt.Enabled { + input := hooks.MFAVerificationAttemptInput{ + UserID: user.ID, + FactorID: factor.ID, + Valid: valid, + } + output := &hooks.MFAVerificationAttemptOutput{} + err := a.invokeHook(ctx, input, output) + if err != nil { + return errors.New(err.Error()) + } + + if output.Decision == hooks.MFAHookRejection { + if err := models.Logout(a.db, user.ID); err != nil { + return err + } + if output.Message == "" { + output.Message = hooks.DefaultMFAHookRejectionMessage + } + return forbiddenError(output.Message) + } + } + if !valid { return badRequestError("Invalid TOTP code entered") } diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 8d4d6ef73d..80abfcd3ba 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -371,7 +371,7 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() { // Integration Tests func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { - resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword) + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) accessTokenResp := &AccessTokenResponse{} require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) @@ -399,7 +399,7 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() { // Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() { - resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword) + resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */) accessTokenResp := &AccessTokenResponse{} require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) @@ -447,10 +447,9 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes return data } -func performTestSignupAndVerify(ts *MFATestSuite, email, password string) *httptest.ResponseRecorder { - +func performTestSignupAndVerify(ts *MFATestSuite, email, password string, requireStatusOK bool) *httptest.ResponseRecorder { signUpResp := signUp(ts, email, password) - resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token) + resp := performEnrollAndVerify(ts, signUpResp.User, signUpResp.Token, requireStatusOK) return resp @@ -468,9 +467,9 @@ func performEnrollFlow(ts *MFATestSuite, token, friendlyName, factorType, issuer ts.API.handler.ServeHTTP(w, req) require.Equal(ts.T(), expectedCode, w.Code) return w - } -func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int) *httptest.ResponseRecorder { + +func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token string, expectedCode int, requireStatusOK bool) *httptest.ResponseRecorder { var verifyBuffer bytes.Buffer y := httptest.NewRecorder() @@ -495,7 +494,9 @@ func performVerifyFlow(ts *MFATestSuite, challengeID, factorID uuid.UUID, token req.Header.Set("Content-Type", "application/json") ts.API.handler.ServeHTTP(y, req) - require.Equal(ts.T(), http.StatusOK, y.Code) + if requireStatusOK { + require.Equal(ts.T(), http.StatusOK, y.Code) + } return y } @@ -511,7 +512,7 @@ func performChallengeFlow(ts *MFATestSuite, factorID uuid.UUID, token string) *h } -func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string) *httptest.ResponseRecorder { +func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string, requireStatusOK bool) *httptest.ResponseRecorder { w := performEnrollFlow(ts, token, "", models.TOTP, ts.TestDomain, http.StatusOK) enrollResp := EnrollFactorResponse{} require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) @@ -525,7 +526,139 @@ func performEnrollAndVerify(ts *MFATestSuite, user *models.User, token string) * challengeID := challengeResp.ID // Verify - y := performVerifyFlow(ts, challengeID, factorID, token, http.StatusOK) + y := performVerifyFlow(ts, challengeID, factorID, token, http.StatusOK, requireStatusOK) return y } + +func (ts *MFATestSuite) TestVerificationHooks() { + type verificationHookTestCase struct { + desc string + enabled bool + uri string + hookFunctionSQL string + emailSuffix string + expectToken bool + expectedCode int + cleanupHookFunction string + } + cases := []verificationHookTestCase{ + { + desc: "Default Success", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook", + hookFunctionSQL: ` + create or replace function verification_hook(input jsonb) + returns json as $$ + begin + return json_build_object('decision', 'continue'); + end; $$ language plpgsql;`, + emailSuffix: "success", + expectToken: true, + expectedCode: http.StatusOK, + cleanupHookFunction: "verification_hook(input jsonb)", + }, + { + desc: "Error", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_error", + hookFunctionSQL: ` + create or replace function test_verification_hook_error(input jsonb) + returns json as $$ + begin + RAISE EXCEPTION 'Intentional Error for Testing'; + end; $$ language plpgsql;`, + emailSuffix: "error", + expectToken: false, + expectedCode: http.StatusInternalServerError, + cleanupHookFunction: "test_verification_hook_error(input jsonb)", + }, + { + desc: "Reject - Enabled", + enabled: true, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_enabled", + expectToken: false, + expectedCode: http.StatusForbidden, + cleanupHookFunction: "verification_hook_reject(input jsonb)", + }, + { + desc: "Reject - Disabled", + enabled: false, + uri: "pg-functions://postgres/auth/verification_hook_reject", + hookFunctionSQL: ` + create or replace function verification_hook_reject(input jsonb) + returns json as $$ + begin + return json_build_object( + 'decision', 'reject', + 'message', 'authentication attempt rejected' + ); + end; $$ language plpgsql;`, + emailSuffix: "reject_disabled", + expectToken: true, + expectedCode: http.StatusOK, + cleanupHookFunction: "verification_hook_reject(input jsonb)", + }, + { + desc: "Timeout", + enabled: true, + uri: "pg-functions://postgres/auth/test_verification_hook_timeout", + hookFunctionSQL: ` + create or replace function test_verification_hook_timeout(input jsonb) + returns json as $$ + begin + PERFORM pg_sleep(3); + return json_build_object( + 'decision', 'continue' + ); + end; $$ language plpgsql;`, + emailSuffix: "timeout", + expectToken: false, + expectedCode: http.StatusInternalServerError, + cleanupHookFunction: "test_verification_hook_timeout(input jsonb)", + }, + } + + for _, c := range cases { + ts.T().Run(c.desc, func(t *testing.T) { + ts.Config.Hook.MFAVerificationAttempt.Enabled = c.enabled + ts.Config.Hook.MFAVerificationAttempt.URI = c.uri + require.NoError(ts.T(), ts.Config.Hook.MFAVerificationAttempt.ValidateAndPopulateExtensibilityPoint()) + + err := ts.API.db.RawQuery(c.hookFunctionSQL).Exec() + require.NoError(t, err) + + email := fmt.Sprintf("testemail_%s@gmail.com", c.emailSuffix) + password := "testpassword" + resp := performTestSignupAndVerify(ts, email, password, c.expectToken) + require.Equal(ts.T(), c.expectedCode, resp.Code) + accessTokenResp := &AccessTokenResponse{} + require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp)) + + if c.expectToken { + require.NotEqual(t, "", accessTokenResp.Token) + } else { + require.Equal(t, "", accessTokenResp.Token) + } + + cleanupHook(ts, c.cleanupHookFunction) + }) + } +} + +func cleanupHook(ts *MFATestSuite, hookName string) { + cleanupHookSQL := fmt.Sprintf("drop function if exists %s", hookName) + err := ts.API.db.RawQuery(cleanupHookSQL).Exec() + require.NoError(ts.T(), err) + +} diff --git a/internal/api/settings.go b/internal/api/settings.go index c2ab0df39a..38c8a82413 100644 --- a/internal/api/settings.go +++ b/internal/api/settings.go @@ -36,6 +36,11 @@ type Settings struct { SmsProvider string `json:"sms_provider"` MFAEnabled bool `json:"mfa_enabled"` SAMLEnabled bool `json:"saml_enabled"` + HookConfiguration HookSettings `json:"hook"` +} + +type HookSettings struct { + MFAVerification bool `json:"mfa"` } func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { @@ -67,6 +72,9 @@ func (a *API) Settings(w http.ResponseWriter, r *http.Request) error { Phone: config.External.Phone.Enabled, Zoom: config.External.Zoom.Enabled, }, + HookConfiguration: HookSettings{ + MFAVerification: config.Hook.MFAVerificationAttempt.Enabled, + }, DisableSignup: config.DisableSignup, MailerAutoconfirm: config.Mailer.Autoconfirm, diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 2cda8a0815..6006e69926 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -6,6 +6,7 @@ import ( "fmt" "net/url" "os" + "regexp" "strings" "text/template" "time" @@ -19,6 +20,8 @@ const defaultMinPasswordLength int = 6 const defaultChallengeExpiryDuration float64 = 300 const defaultFlowStateExpiryDuration time.Duration = 300 * time.Second +var postgresNamesRegexp = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]{0,62}$`) + // Time is used to represent timestamps in the configuration, as envconfig has // trouble parsing empty strings, due to time.Time.UnmarshalText(). type Time struct { @@ -213,6 +216,7 @@ type GlobalConfiguration struct { Sms SmsProviderConfiguration `json:"sms"` DisableSignup bool `json:"disable_signup" split_words:"true"` Webhook WebhookConfig `json:"webhook" split_words:"true"` + Hook HookConfiguration `json:"hook" split_words:"true"` Security SecurityConfiguration `json:"security"` Sessions SessionsConfiguration `json:"sessions"` MFA MFAConfiguration `json:"MFA"` @@ -435,6 +439,53 @@ type WebhookConfig struct { Events []string `json:"events"` } +// Moving away from the existing HookConfig so we can get a fresh start. +type HookConfiguration struct { + MFAVerificationAttempt ExtensibilityPointConfiguration `json:"mfa_verification_attempt" split_words:"true"` +} + +type ExtensibilityPointConfiguration struct { + URI string `json:"uri"` + Enabled bool `json:"enabled"` + HookName string `json:"hook_name"` +} + +func (h *HookConfiguration) Validate() error { + points := []ExtensibilityPointConfiguration{ + h.MFAVerificationAttempt, + } + for _, point := range points { + if err := point.ValidateAndPopulateExtensibilityPoint(); err != nil { + return err + } + } + return nil +} + +func (e *ExtensibilityPointConfiguration) ValidateAndPopulateExtensibilityPoint() error { + if e.URI != "" { + u, err := url.Parse(e.URI) + if err != nil { + return err + } + pathParts := strings.Split(u.Path, "/") + if len(pathParts) < 3 { + return fmt.Errorf("URI path does not contain enough parts") + } + schema := pathParts[1] + table := pathParts[2] + // Validate schema and table names + if !postgresNamesRegexp.MatchString(schema) { + return fmt.Errorf("invalid schema name: %s", schema) + } + if !postgresNamesRegexp.MatchString(table) { + return fmt.Errorf("invalid table name: %s", table) + } + e.HookName = fmt.Sprintf("%q.%q", schema, table) + } + return nil +} + func (w *WebhookConfig) HasEvent(event string) bool { for _, name := range w.Events { if event == name { @@ -480,7 +531,6 @@ func LoadGlobal(filename string) (*GlobalConfiguration, error) { } config.Sms.SMSTemplate = template } - return config, nil } @@ -602,6 +652,7 @@ func (c *GlobalConfiguration) Validate() error { &c.SAML, &c.Security, &c.Sessions, + &c.Hook, } for _, validatable := range validatables { diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index a9178085c9..5d83362ed9 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -15,16 +15,18 @@ func TestMain(m *testing.M) { func TestGlobal(t *testing.T) { os.Setenv("GOTRUE_SITE_URL", "http://localhost:8080") - os.Setenv("GOTRUE_DB_DRIVER", "mysql") + os.Setenv("GOTRUE_DB_DRIVER", "postgres") os.Setenv("GOTRUE_DB_DATABASE_URL", "fake") os.Setenv("GOTRUE_OPERATOR_TOKEN", "token") os.Setenv("GOTRUE_API_REQUEST_ID_HEADER", "X-Request-ID") os.Setenv("GOTRUE_JWT_SECRET", "secret") os.Setenv("API_EXTERNAL_URL", "http://localhost:9999") + os.Setenv("GOTRUE_HOOK_MFA_VERIFICATION_ATTEMPT_URI", "pg-functions://postgres/auth/count_failed_attempts") gc, err := LoadGlobal("") require.NoError(t, err) require.NotNil(t, gc) assert.Equal(t, "X-Request-ID", gc.API.RequestIDHeader) + assert.Equal(t, "pg-functions://postgres/auth/count_failed_attempts", gc.Hook.MFAVerificationAttempt.URI) } func TestPasswordRequiredCharactersDecode(t *testing.T) { @@ -95,3 +97,33 @@ func TestPasswordRequiredCharactersDecode(t *testing.T) { require.Equal(t, []string(into), example.Result, "Example %d got unexpected result", i) } } + +func TestValidateAndPopulateExtensibilityPoint(t *testing.T) { + cases := []struct { + desc string + uri string + expectedResult string + }{ + // Positive test cases + {desc: "Valid URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectedResult: `"auth"."verification_hook_reject"`}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectedResult: `"user_management"."add_user"`}, + {desc: "Another Valid URI", uri: "pg-functions://postgres/MySpeCial/FUNCTION_THAT_YELLS_AT_YOU", expectedResult: `"MySpeCial"."FUNCTION_THAT_YELLS_AT_YOU"`}, + + // Negative test cases + {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectedResult: ""}, + {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectedResult: ""}, + {desc: "Insufficient Path Parts", uri: "pg-functions://postgres/auth", expectedResult: ""}, + } + + for _, tc := range cases { + ep := ExtensibilityPointConfiguration{URI: tc.uri} + err := ep.ValidateAndPopulateExtensibilityPoint() + if tc.expectedResult != "" { + require.NoError(t, err) + require.Equal(t, tc.expectedResult, ep.HookName) + } else { + require.Error(t, err) + require.Empty(t, ep.HookName) + } + } +} diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go new file mode 100644 index 0000000000..0c5c87b819 --- /dev/null +++ b/internal/hooks/auth_hooks.go @@ -0,0 +1,68 @@ +package hooks + +import ( + "fmt" + "github.com/gofrs/uuid" +) + +type HookType string + +const ( + PostgresHook HookType = "pg-functions" +) + +const ( + // In Miliseconds + DefaultTimeout = 2000 +) + +// Hook Names +const ( + MFAHookRejection = "reject" + MFAHookContinue = "continue" +) + +type HookOutput interface { + IsError() bool + Error() string +} + +type MFAVerificationAttemptInput struct { + UserID uuid.UUID `json:"user_id"` + FactorID uuid.UUID `json:"factor_id"` + Valid bool `json:"valid"` +} + +type MFAVerificationAttemptOutput struct { + Decision string `json:"decision"` + Message string `json:"message"` + HookError AuthHookError `json:"hook_error" split_words:"true"` +} + +type AuthHookError struct { + Code string `json:"code"` + Message string `json:"msg"` + ErrorID string `json:"error_id,omitempty"` +} + +func (a *AuthHookError) Error() string { + return fmt.Sprintf("%s: %s", a.Code, a.Message) +} + +const ( + DefaultMFAHookRejectionMessage = "Further MFA verification attempts will be rejected." +) + +func HookError(message string, args ...interface{}) *AuthHookError { + return &AuthHookError{ + Message: fmt.Sprintf(message, args...), + } + +} + +func (mf *MFAVerificationAttemptOutput) IsError() bool { + return mf.HookError.Message != "" +} +func (mf *MFAVerificationAttemptOutput) Error() string { + return mf.HookError.Message +}