From a856e6aef235acd0aada71a55b10067a5e0acda9 Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Fri, 12 Jan 2024 22:47:15 +0800 Subject: [PATCH] refactor: move hooks from mfa.go to hooks.go (#1373) ## What kind of change does this PR introduce? Moves the existing `hooks.go` logic to the `hooks` package. Co-authored-by: joel@joellee.org --- internal/api/hooks.go | 138 ++++++++++++++++++++++++++++++++++++++++++ internal/api/mfa.go | 138 ------------------------------------------ 2 files changed, 138 insertions(+), 138 deletions(-) diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 5c73981dda..2fd31983e2 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -6,6 +6,7 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "fmt" "io" "net" "net/http" @@ -17,6 +18,7 @@ import ( jwt "github.com/golang-jwt/jwt" "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" @@ -287,3 +289,139 @@ type connectionWatcher struct { func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) { c.gotConn = true } + +func (a *API) runHook(ctx context.Context, name string, input, output any) ([]byte, error) { + db := a.db.WithContext(ctx) + + request, err := json.Marshal(input) + if err != nil { + panic(err) + } + + var response []byte + if err := db.Transaction(func(tx *storage.Connection) error { + // We rely on Postgres timeouts to ensure the function doesn't overrun + if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil { + return terr + } + + if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", name), request).First(&response); terr != nil { + return terr + } + + // reset the timeout + if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil { + return terr + } + + return nil + }); err != nil { + return nil, err + } + + if err := json.Unmarshal(response, output); err != nil { + return response, err + } + + return response, nil +} + +func (a *API) invokeHook(ctx context.Context, input, output any) error { + config := a.config + switch input.(type) { + case *hooks.MFAVerificationAttemptInput: + hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) + if !ok { + panic("output should be *hooks.MFAVerificationAttemptOutput") + } + + if _, err := a.runHook(ctx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { + return internalServerError("Error invoking MFA verification hook.").WithInternalError(err) + } + + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + Code: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + + return nil + case *hooks.PasswordVerificationAttemptInput: + hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput) + if !ok { + panic("output should be *hooks.PasswordVerificationAttemptOutput") + } + + if _, err := a.runHook(ctx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { + return internalServerError("Error invoking password verification hook.").WithInternalError(err) + } + + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + Code: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + + return nil + case *hooks.CustomAccessTokenInput: + hookOutput, ok := output.(*hooks.CustomAccessTokenOutput) + if !ok { + panic("output should be *hooks.CustomAccessTokenOutput") + } + + if _, err := a.runHook(ctx, config.Hook.CustomAccessToken.HookName, input, output); err != nil { + return internalServerError("Error invoking access token hook.").WithInternalError(err) + } + + if hookOutput.IsError() { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + Code: httpCode, + Message: hookOutput.HookError.Message, + } + + return httpError.WithInternalError(&hookOutput.HookError) + } + if err := validateTokenClaims(hookOutput.Claims); err != nil { + httpCode := hookOutput.HookError.HTTPCode + + if httpCode == 0 { + httpCode = http.StatusInternalServerError + } + + httpError := &HTTPError{ + Code: httpCode, + Message: err.Error(), + } + + return httpError + } + return nil + + default: + panic("unknown hook input type") + } +} diff --git a/internal/api/mfa.go b/internal/api/mfa.go index dc6aa44f81..66505d0d27 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -2,9 +2,7 @@ package api import ( "bytes" - "context" "encoding/json" - "fmt" "net/http" "net/url" @@ -198,142 +196,6 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { }) } -func (a *API) runHook(ctx context.Context, name string, input, output any) ([]byte, error) { - db := a.db.WithContext(ctx) - - request, err := json.Marshal(input) - if err != nil { - panic(err) - } - - var response []byte - if err := db.Transaction(func(tx *storage.Connection) error { - // We rely on Postgres timeouts to ensure the function doesn't overrun - if terr := tx.RawQuery(fmt.Sprintf("set local statement_timeout TO '%d';", hooks.DefaultTimeout)).Exec(); terr != nil { - return terr - } - - if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", name), request).First(&response); terr != nil { - return terr - } - - // reset the timeout - if terr := tx.RawQuery("set local statement_timeout TO default;").Exec(); terr != nil { - return terr - } - - return nil - }); err != nil { - return nil, err - } - - if err := json.Unmarshal(response, output); err != nil { - return response, err - } - - return response, nil -} - -func (a *API) invokeHook(ctx context.Context, input, output any) error { - config := a.config - switch input.(type) { - case *hooks.MFAVerificationAttemptInput: - hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) - if !ok { - panic("output should be *hooks.MFAVerificationAttemptOutput") - } - - if _, err := a.runHook(ctx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { - return internalServerError("Error invoking MFA verification hook.").WithInternalError(err) - } - - if hookOutput.IsError() { - httpCode := hookOutput.HookError.HTTPCode - - if httpCode == 0 { - httpCode = http.StatusInternalServerError - } - - httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, - } - - return httpError.WithInternalError(&hookOutput.HookError) - } - - return nil - case *hooks.PasswordVerificationAttemptInput: - hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput) - if !ok { - panic("output should be *hooks.PasswordVerificationAttemptOutput") - } - - if _, err := a.runHook(ctx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { - return internalServerError("Error invoking password verification hook.").WithInternalError(err) - } - - if hookOutput.IsError() { - httpCode := hookOutput.HookError.HTTPCode - - if httpCode == 0 { - httpCode = http.StatusInternalServerError - } - - httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, - } - - return httpError.WithInternalError(&hookOutput.HookError) - } - - return nil - case *hooks.CustomAccessTokenInput: - hookOutput, ok := output.(*hooks.CustomAccessTokenOutput) - if !ok { - panic("output should be *hooks.CustomAccessTokenOutput") - } - - if _, err := a.runHook(ctx, config.Hook.CustomAccessToken.HookName, input, output); err != nil { - return internalServerError("Error invoking access token hook.").WithInternalError(err) - } - - if hookOutput.IsError() { - httpCode := hookOutput.HookError.HTTPCode - - if httpCode == 0 { - httpCode = http.StatusInternalServerError - } - - httpError := &HTTPError{ - Code: httpCode, - Message: hookOutput.HookError.Message, - } - - return httpError.WithInternalError(&hookOutput.HookError) - } - if err := validateTokenClaims(hookOutput.Claims); err != nil { - httpCode := hookOutput.HookError.HTTPCode - - if httpCode == 0 { - httpCode = http.StatusInternalServerError - } - - httpError := &HTTPError{ - Code: httpCode, - Message: err.Error(), - } - - return httpError - } - return nil - - default: - panic("unknown hook input type") - } -} - func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { var err error ctx := r.Context()