From 822fb93faab325ba3d4bb628dff43381d68d0b5d Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Fri, 19 Jul 2024 19:41:33 +0200 Subject: [PATCH] fix: refactor mfa models and add observability to loadFactor (#1669) ## What kind of change does this PR introduce? Makes two changes: - Ensures we track DB calls to loadFactor - Changes the interface for creating a challenge in attempt to make clear that Challenge should always have a Factor --- internal/api/admin.go | 6 ++++-- internal/api/mfa.go | 2 +- internal/api/mfa_test.go | 2 +- internal/models/challenge.go | 11 ----------- internal/models/factor.go | 10 ++++++++++ 5 files changed, 16 insertions(+), 15 deletions(-) diff --git a/internal/api/admin.go b/internal/api/admin.go index f4b774ec7e..7df41fb656 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -70,6 +70,8 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, } func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { + ctx := r.Context() + db := a.db.WithContext(ctx) factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID") @@ -77,14 +79,14 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex observability.LogEntrySetField(r, "factor_id", factorID) - f, err := models.FindFactorByFactorID(a.db, factorID) + f, err := models.FindFactorByFactorID(db, factorID) if err != nil { if models.IsNotFoundError(err) { return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found") } return nil, internalServerError("Database error loading factor").WithInternalError(err) } - return withFactor(r.Context(), f), nil + return withFactor(ctx, f), nil } func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { diff --git a/internal/api/mfa.go b/internal/api/mfa.go index df70c6b51f..c15f4a3b4d 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -182,7 +182,7 @@ func (a *API) ChallengeFactor(w http.ResponseWriter, r *http.Request) error { user := getUser(ctx) factor := getFactor(ctx) ipAddress := utilities.GetIPAddress(r) - challenge := models.NewChallenge(factor, ipAddress) + challenge := factor.CreateChallenge(ipAddress) if err := db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(challenge); terr != nil { diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 63f813249a..101f6c6606 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -265,7 +265,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) testIPAddress := utilities.GetIPAddress(req) - c := models.NewChallenge(f, testIPAddress) + c := f.CreateChallenge(testIPAddress) require.NoError(ts.T(), ts.API.db.Create(c), "Error saving new test challenge") if !v.validChallenge { // Set challenge creation so that it has expired in present time. diff --git a/internal/models/challenge.go b/internal/models/challenge.go index d52132aaad..c088f4b999 100644 --- a/internal/models/challenge.go +++ b/internal/models/challenge.go @@ -22,17 +22,6 @@ func (Challenge) TableName() string { return tableName } -func NewChallenge(factor *Factor, ipAddress string) *Challenge { - id := uuid.Must(uuid.NewV4()) - - challenge := &Challenge{ - ID: id, - FactorID: factor.ID, - IPAddress: ipAddress, - } - return challenge -} - func FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { var challenge Challenge err := conn.Find(&challenge, challengeID) diff --git a/internal/models/factor.go b/internal/models/factor.go index 53fddc260d..6abd00080c 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -187,6 +187,16 @@ func DeleteUnverifiedFactors(tx *storage.Connection, user *User) error { return nil } +func (f *Factor) CreateChallenge(ipAddress string) *Challenge { + id := uuid.Must(uuid.NewV4()) + challenge := &Challenge{ + ID: id, + FactorID: f.ID, + IPAddress: ipAddress, + } + return challenge +} + // UpdateFriendlyName changes the friendly name func (f *Factor) UpdateFriendlyName(tx *storage.Connection, friendlyName string) error { f.FriendlyName = friendlyName