From 72765b0997a6ecbf33463c896cbaf3bcac07e7fc Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Mon, 13 Nov 2023 17:37:11 -0500 Subject: [PATCH] feat: update automatic linking algorithm --- internal/api/external.go | 31 +++---- internal/models/linking.go | 138 +++++++++++++++++++------------- internal/models/linking_test.go | 71 ++++++++++++---- 3 files changed, 146 insertions(+), 94 deletions(-) diff --git a/internal/api/external.go b/internal/api/external.go index 8b4a099bd8..88d0b4ae8d 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -262,8 +262,6 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. aud := a.requestAud(ctx, r) config := a.config - var terr error - var user *models.User var identity *models.Identity var identityData map[string]interface{} @@ -271,19 +269,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. identityData = structs.Map(userData.Metadata) } - var emailData provider.Email - var emails []string - // an oauth identity with an unverified email will not have an email present - for _, email := range userData.Emails { - if email.Verified || config.Mailer.Autoconfirm { - if email.Primary { - emailData = email - } - emails = append(emails, strings.ToLower(email.Email)) - } - } - - decision, terr := models.DetermineAccountLinking(tx, providerType, userData.Metadata.Subject, emails) + decision, terr := models.DetermineAccountLinking(tx, userData.Emails, aud, providerType, userData.Metadata.Subject) if terr != nil { return nil, terr } @@ -307,15 +293,17 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. params := &SignupParams{ Provider: providerType, - Email: emailData.Email, + Email: decision.CandidateEmail.Email, Aud: aud, Data: identityData, } - isSSOUser := strings.HasPrefix(providerType, "sso:") + isSSOUser := false + if decision.LinkingDomain == "sso" { + isSSOUser = true + } - user, terr = a.signupNewUser(ctx, tx, params, isSSOUser) - if terr != nil { + if user, terr = a.signupNewUser(ctx, tx, params, isSSOUser); terr != nil { return nil, terr } @@ -357,7 +345,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. if terr = user.RemoveUnconfirmedIdentities(tx, identity); terr != nil { return nil, internalServerError("Error updating user").WithInternalError(terr) } - if emailData.Verified || config.Mailer.Autoconfirm { + if decision.CandidateEmail.Verified || config.Mailer.Autoconfirm { if terr := models.NewAuditLogEntry(r, tx, user, models.UserSignedUpAction, "", map[string]interface{}{ "provider": providerType, }); terr != nil { @@ -372,8 +360,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. return nil, internalServerError("Error updating user").WithInternalError(terr) } } else { - if user.Email != "" { - // an oauth identity with an unverified email will not have an email present + if decision.CandidateEmail.Email != "" { mailer := a.Mailer(ctx) referrer := utilities.GetReferrer(r, config) externalURL := getExternalHost(ctx) diff --git a/internal/models/linking.go b/internal/models/linking.go index 167da5bbca..9e18e172e2 100644 --- a/internal/models/linking.go +++ b/internal/models/linking.go @@ -3,6 +3,7 @@ package models import ( "strings" + p "github.com/supabase/gotrue/internal/api/provider" "github.com/supabase/gotrue/internal/storage" ) @@ -34,12 +35,11 @@ const ( ) type AccountLinkingResult struct { - Decision AccountLinkingDecision - - User *User - Identities []*Identity - - LinkingDomain string + Decision AccountLinkingDecision + User *User + Identities []*Identity + LinkingDomain string + CandidateEmail p.Email } // DetermineAccountLinking uses the provided data and database state to compute a decision on whether: @@ -49,7 +49,19 @@ type AccountLinkingResult struct { // - It's not possible to decide due to data inconsistency (MultipleAccounts) and the caller should decide // // Errors signal failure in processing only, like database access errors. -func DetermineAccountLinking(tx *storage.Connection, provider, sub string, emails []string) (AccountLinkingResult, error) { +func DetermineAccountLinking(tx *storage.Connection, emails []p.Email, aud, provider, sub string) (AccountLinkingResult, error) { + var verifiedEmails []string + var candidateEmail p.Email + for _, email := range emails { + if email.Verified { + verifiedEmails = append(verifiedEmails, strings.ToLower(email.Email)) + } + if email.Primary { + candidateEmail = email + candidateEmail.Email = strings.ToLower(email.Email) + } + } + if identity, terr := FindIdentityByIdAndProvider(tx, sub, provider); terr == nil { // account exists @@ -59,58 +71,65 @@ func DetermineAccountLinking(tx *storage.Connection, provider, sub string, email } return AccountLinkingResult{ - Decision: AccountExists, - User: user, - Identities: []*Identity{identity}, - LinkingDomain: GetAccountLinkingDomain(provider), + Decision: AccountExists, + User: user, + Identities: []*Identity{identity}, + LinkingDomain: GetAccountLinkingDomain(provider), + CandidateEmail: candidateEmail, }, nil } else if !IsNotFoundError(terr) { return AccountLinkingResult{}, terr } - // account does not exist, identity and user not immediately - // identifiable, look for similar identities based on email - var similarIdentities []*Identity - var similarUsers []*User + // the identity does not exist, so we need to check if we should create a new account + // or link to an existing one - if len(emails) > 0 { - if terr := tx.Q().Eager().Where("email ilike any (?)", emails).All(&similarIdentities); terr != nil { + // this is the linking domain for the new identity + candidateLinkingDomain := GetAccountLinkingDomain(provider) + if len(verifiedEmails) == 0 { + // if there are no verified emails, we always decide to create a new account + user, terr := IsDuplicatedEmail(tx, candidateEmail.Email, aud, nil) + if terr != nil { return AccountLinkingResult{}, terr } - - if !strings.HasPrefix(provider, "sso:") { - // there can be multiple user accounts with the same email when is_sso_user is true - // so we just do not consider those similar user accounts - if terr := tx.Q().Eager().Where("email ilike any (?) and is_sso_user is false", emails).All(&similarUsers); terr != nil { - return AccountLinkingResult{}, terr + if user != nil { + candidateEmail = p.Email{ + Email: "", + Verified: false, + Primary: false, } } - } - - // TODO: determine linking behavior over phone too - if len(similarIdentities) == 0 && len(similarUsers) == 0 { - // there are no similar identities, clearly we have to create a new account - return AccountLinkingResult{ - Decision: CreateAccount, - LinkingDomain: GetAccountLinkingDomain(provider), + Decision: CreateAccount, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, }, nil } - // there are some similar identities, we now need to proceed in - // identifying whether this supposed new identity should be assigned to - // an existing user or to create a new user, according to the automatic - // linking rules + var similarIdentities []*Identity + var similarUsers []*User + // look for similar identities and users based on email + if terr := tx.Q().Eager().Where("email ilike any (?)", verifiedEmails).All(&similarIdentities); terr != nil { + return AccountLinkingResult{}, terr + } - // this is the linking domain for the new identity - newAccountLinkingDomain := GetAccountLinkingDomain(provider) + if !strings.HasPrefix(provider, "sso:") { + // there can be multiple user accounts with the same email when is_sso_user is true + // so we just do not consider those similar user accounts + if terr := tx.Q().Eager().Where("email ilike any (?) and is_sso_user is false", verifiedEmails).All(&similarUsers); terr != nil { + return AccountLinkingResult{}, terr + } + } + // Need to check if the new identity should be assigned to an + // existing user or to create a new user, according to the automatic + // linking rules var linkingIdentities []*Identity // now let's see if there are any existing and similar identities in // the same linking domain for _, identity := range similarIdentities { - if GetAccountLinkingDomain(identity.Provider) == newAccountLinkingDomain { + if GetAccountLinkingDomain(identity.Provider) == candidateLinkingDomain { linkingIdentities = append(linkingIdentities, identity) } } @@ -121,24 +140,27 @@ func DetermineAccountLinking(tx *storage.Connection, provider, sub string, email // so we link this new identity to the user // TODO: Backfill the missing identity for the user return AccountLinkingResult{ - Decision: LinkAccount, - User: similarUsers[0], - Identities: linkingIdentities, - LinkingDomain: newAccountLinkingDomain, + Decision: LinkAccount, + User: similarUsers[0], + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, }, nil } else if len(similarUsers) > 1 { // this shouldn't happen since there is a partial unique index on (email and is_sso_user = false) return AccountLinkingResult{ - Decision: MultipleAccounts, - Identities: linkingIdentities, - LinkingDomain: newAccountLinkingDomain, + Decision: MultipleAccounts, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, }, nil } else { // there are no identities in the linking domain, we have to // create a new identity and new user return AccountLinkingResult{ - Decision: CreateAccount, - LinkingDomain: newAccountLinkingDomain, + Decision: CreateAccount, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, }, nil } } @@ -146,16 +168,17 @@ func DetermineAccountLinking(tx *storage.Connection, provider, sub string, email // there is at least one identity in the linking domain let's do a // sanity check to see if all of the identities in the domain share the // same user ID - + linkingUserId := linkingIdentities[0].UserID for _, identity := range linkingIdentities { - if identity.UserID != linkingIdentities[0].UserID { + if identity.UserID != linkingUserId { // ok this linking domain has more than one user account // caller should decide what to do return AccountLinkingResult{ - Decision: MultipleAccounts, - Identities: linkingIdentities, - LinkingDomain: newAccountLinkingDomain, + Decision: MultipleAccounts, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, }, nil } } @@ -166,14 +189,15 @@ func DetermineAccountLinking(tx *storage.Connection, provider, sub string, email var user *User var terr error - if user, terr = FindUserByID(tx, linkingIdentities[0].UserID); terr != nil { + if user, terr = FindUserByID(tx, linkingUserId); terr != nil { return AccountLinkingResult{}, terr } return AccountLinkingResult{ - Decision: LinkAccount, - User: user, - Identities: linkingIdentities, - LinkingDomain: newAccountLinkingDomain, + Decision: LinkAccount, + User: user, + Identities: linkingIdentities, + LinkingDomain: candidateLinkingDomain, + CandidateEmail: candidateEmail, }, nil } diff --git a/internal/models/linking_test.go b/internal/models/linking_test.go index ecb80dc80e..321cdd00ec 100644 --- a/internal/models/linking_test.go +++ b/internal/models/linking_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "github.com/supabase/gotrue/internal/api/provider" "github.com/supabase/gotrue/internal/conf" "github.com/supabase/gotrue/internal/storage" "github.com/supabase/gotrue/internal/storage/test" @@ -37,13 +38,18 @@ func TestAccountLinking(t *testing.T) { func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionNoAccounts() { // when there are no accounts in the system -- conventional provider - decision, err := DetermineAccountLinking(ts.db, "provider", "abcdefgh", []string{"test@example.com"}) + testEmail := provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + } + decision, err := DetermineAccountLinking(ts.db, []provider.Email{testEmail}, "", "provider", "abcdefgh") require.NoError(ts.T(), err) require.Equal(ts.T(), decision.Decision, CreateAccount) // when there are no accounts in the system -- SSO provider - decision, err = DetermineAccountLinking(ts.db, "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", "abcdefgh", []string{"test@example.com"}) + decision, err = DetermineAccountLinking(ts.db, []provider.Email{testEmail}, "", "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", "abcdefgh") require.NoError(ts.T(), err) require.Equal(ts.T(), decision.Decision, CreateAccount) @@ -72,13 +78,25 @@ func (ts *AccountLinkingTestSuite) TestCreateAccountDecisionWithAccounts() { require.NoError(ts.T(), ts.db.Create(identityB)) // when the email doesn't exist in the system -- conventional provider - decision, err := DetermineAccountLinking(ts.db, "provider", "abcdefgh", []string{"other@example.com"}) + decision, err := DetermineAccountLinking(ts.db, []provider.Email{ + provider.Email{ + Email: "other@example.com", + Verified: true, + Primary: true, + }, + }, "authenticated", "provider", "abcdefgh") require.NoError(ts.T(), err) require.Equal(ts.T(), decision.Decision, CreateAccount) // when looking for an email that doesn't exist in the SSO linking domain - decision, err = DetermineAccountLinking(ts.db, "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", "abcdefgh", []string{"other@samltest.id"}) + decision, err = DetermineAccountLinking(ts.db, []provider.Email{ + provider.Email{ + Email: "other@samltest.id", + Verified: true, + Primary: true, + }, + }, "authenticated", "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", "abcdefgh") require.NoError(ts.T(), err) require.Equal(ts.T(), decision.Decision, CreateAccount) @@ -95,7 +113,13 @@ func (ts *AccountLinkingTestSuite) TestAccountExists() { require.NoError(ts.T(), err) require.NoError(ts.T(), ts.db.Create(identityA)) - decision, err := DetermineAccountLinking(ts.db, "provider", userA.ID.String(), []string{"test@example.com"}) + decision, err := DetermineAccountLinking(ts.db, []provider.Email{ + provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, "authenticated", "provider", userA.ID.String()) require.NoError(ts.T(), err) require.Equal(ts.T(), decision.Decision, AccountExists) @@ -126,29 +150,41 @@ func (ts *AccountLinkingTestSuite) TestLinkAccountExists() { cases := []struct { desc string - email string + email provider.Email sub string provider string decision AccountLinkingDecision }{ { // link decision because the below described identity is in the default linking domain but uses "other-provider" instead of "provder" - desc: "same email address", - email: "test@example.com", + desc: "same email address", + email: provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + }, sub: userA.ID.String(), provider: "other-provider", decision: LinkAccount, }, { - desc: "same email address in uppercase", - email: "TEST@example.com", + desc: "same email address in uppercase", + email: provider.Email{ + Email: "TEST@example.com", + Verified: true, + Primary: true, + }, sub: userA.ID.String(), provider: "other-provider", decision: LinkAccount, }, { - desc: "no link decision because the SSO linking domain is scoped to the provider unique ID", - email: "test@samltest.id", + desc: "no link decision because the SSO linking domain is scoped to the provider unique ID", + email: provider.Email{ + Email: "test@samltest.id", + Verified: true, + Primary: true, + }, sub: userB.ID.String(), provider: "sso:f06f9e3d-ff92-4c47-a179-7acf1fda6387", decision: AccountExists, @@ -157,9 +193,8 @@ func (ts *AccountLinkingTestSuite) TestLinkAccountExists() { for _, c := range cases { ts.Run(c.desc, func() { - decision, err := DetermineAccountLinking(ts.db, c.provider, c.sub, []string{c.email}) + decision, err := DetermineAccountLinking(ts.db, []provider.Email{c.email}, "authenticated", c.provider, c.sub) require.NoError(ts.T(), err) - require.Equal(ts.T(), decision.Decision, c.decision) }) } @@ -190,7 +225,13 @@ func (ts *AccountLinkingTestSuite) TestMultipleAccounts() { // decision is multiple accounts because there are two distinct // identities in the same "default" linking domain with the same email // address pointing to two different user accounts - decision, err := DetermineAccountLinking(ts.db, "provider", "abcdefgh", []string{"test@example.com"}) + decision, err := DetermineAccountLinking(ts.db, []provider.Email{ + provider.Email{ + Email: "test@example.com", + Verified: true, + Primary: true, + }, + }, "authenticated", "provider", "abcdefgh") require.NoError(ts.T(), err) require.Equal(ts.T(), decision.Decision, MultipleAccounts)