diff --git a/client/testdata/store.sql b/client/testdata/store.sql index ed539548613..597ee6a3c90 100644 --- a/client/testdata/store.sql +++ b/client/testdata/store.sql @@ -31,6 +31,9 @@ INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','','2024-10-0 INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,'0001-01-01 00:00:00+00:00','[]',0,0); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 21:28:24.830506+02:00','api',0,''); +INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,''); +INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); +INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); INSERT INTO installations VALUES(1,''); COMMIT; diff --git a/management/server/account.go b/management/server/account.go index 59c9c7fb089..268e1bdea42 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -66,7 +66,7 @@ func cacheEntryExpiration() time.Duration { } type AccountManager interface { - GetOrCreateAccountByUser(ctx context.Context, userId, domain string) (*Account, error) + GetOrCreateAccountIDByUser(ctx context.Context, userId, domain string) (string, error) GetAccount(ctx context.Context, accountID string) (*Account, error) CreateSetupKey(ctx context.Context, accountID string, keyName string, keyType SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*SetupKey, error) @@ -133,7 +133,7 @@ type AccountManager interface { GetDNSSettings(ctx context.Context, accountID string, userID string) (*DNSSettings, error) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) LoginPeer(ctx context.Context, login PeerLogin) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync PeerSync, accountID string) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) @@ -813,15 +813,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap { return groupList } -func (a *Account) getTakenIPs() []net.IP { - var takenIps []net.IP - for _, existingPeer := range a.Peers { - takenIps = append(takenIps, existingPeer.IP) - } - - return takenIps -} - func (a *Account) getPeerDNSLabels() lookupMap { existingLabels := make(lookupMap) for _, peer := range a.Peers { @@ -1048,39 +1039,21 @@ func BuildManager( metrics: metrics, requestBuffer: NewAccountRequestBuffer(ctx, store), } - allAccounts := store.GetAllAccounts(ctx) + totalAccounts, err := store.GetTotalAccounts(ctx) + if err != nil { + return nil, err + } + // enable single account mode only if configured by user and number of existing accounts is not grater than 1 - am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 + am.singleAccountMode = singleAccountModeDomain != "" && totalAccounts <= 1 if am.singleAccountMode { if !isDomainValid(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain - log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", totalAccounts) } else { - log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts)) - } - - // if account doesn't have a default group - // we create 'all' group and add all peers into it - // also we create default rule with source as destination - for _, account := range allAccounts { - shouldSave := false - - _, err := account.GetGroupAll() - if err != nil { - if err := addAllGroup(account); err != nil { - return nil, err - } - shouldSave = true - } - - if shouldSave { - err = store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } + log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", totalAccounts) } goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) @@ -1122,7 +1095,20 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. // Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Account, error) { +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *Settings) (*Settings, error) { + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return nil, err + } + + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() + } + + if !user.HasAdminPower() { + return nil, status.NewAdminPermissionError() + } + halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -1132,29 +1118,46 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() + var oldSettings *Settings - account, err := am.Store.GetAccount(ctx, accountID) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + oldSettings, err = transaction.GetAccountSettings(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + if err = am.validateExtraSettings(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { + return err + } + + return transaction.SaveAccountSettings(ctx, LockingStrengthUpdate, accountID, newSettings) + }) if err != nil { return nil, err } - user, err := account.FindUser(userID) + am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) + + return newSettings, nil +} + +// validateExtraSettings validates the extra settings of the account. +func (am *DefaultAccountManager) validateExtraSettings(ctx context.Context, transaction Store, newSettings, oldSettings *Settings, userID, accountID string) error { + peers, err := transaction.GetAccountPeers(ctx, LockingStrengthShare, accountID) if err != nil { - return nil, err + return err } - if !user.HasAdminPower() { - return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer } - err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) - if err != nil { - return nil, err - } + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID) +} - oldSettings := account.Settings +func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { @@ -1170,23 +1173,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - - err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) - if err != nil { - return nil, err - } - - updatedAccount := account.UpdateSettings(newSettings) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - - return updatedAccount, nil } -func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *Settings, userID, accountID string) { if oldSettings.PeerInactivityExpirationEnabled != newSettings.PeerInactivityExpirationEnabled { event := activity.AccountPeerInactivityExpirationEnabled if !newSettings.PeerInactivityExpirationEnabled { @@ -1202,8 +1191,6 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context. am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } - - return nil } func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { @@ -1271,26 +1258,39 @@ func (am *DefaultAccountManager) checkAndSchedulePeerInactivityExpiration(ctx co // newAccount creates a new Account with a generated ID and generated default setup keys. // If ID is already in use (due to collision) we try one more time before returning error -func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (*Account, error) { - for i := 0; i < 2; i++ { - accountId := xid.New().String() - - _, err := am.Store.GetAccount(ctx, accountId) - statusErr, _ := status.FromError(err) - switch { - case err == nil: +func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain string) (string, error) { + var accountID string + + err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + for i := 0; i < 2; i++ { + accountID = xid.New().String() + + exists, err := transaction.AccountExists(ctx, LockingStrengthShare, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to check account existence: %v", err) + return err + } + + if !exists { + if err = newAccountWithId(ctx, transaction, accountID, userID, domain); err != nil { + log.WithContext(ctx).Errorf("failed to create new account: %v", err) + return err + } + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountCreated, nil) + + return nil + } + log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") - continue - case statusErr.Type() == status.NotFound: - newAccount := newAccountWithId(ctx, accountId, userID, domain) - am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) - return newAccount, nil - default: - return nil, err } + + return nil + }) + if err != nil { + return "", status.Errorf(status.Internal, "failed to create new account") } - return nil, status.Errorf(status.Internal, "error while creating new account") + return accountID, nil } func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { @@ -1404,15 +1404,15 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) + accountID, err = am.GetOrCreateAccountIDByUser(ctx, userID, domain) if err != nil { return "", status.Errorf(status.NotFound, "account not found or created for user id: %s", userID) } - if err = am.addAccountIDToIDPAppMeta(ctx, userID, account.Id); err != nil { + if err = am.addAccountIDToIDPAppMeta(ctx, userID, accountID); err != nil { return "", err } - return account.Id, nil + return accountID, nil } return "", err } @@ -1713,7 +1713,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx newCategoty = claims.DomainCategory } - return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) + return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategoty, &primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. @@ -1751,29 +1751,26 @@ func (am *DefaultAccountManager) addNewPrivateAccount(ctx context.Context, domai } lowerDomain := strings.ToLower(claims.Domain) + isPrimaryDomain := true - newAccount, err := am.newAccount(ctx, claims.UserId, lowerDomain) + newAccountID, err := am.newAccount(ctx, claims.UserId, lowerDomain) if err != nil { return "", err } - newAccount.Domain = lowerDomain - newAccount.DomainCategory = claims.DomainCategory - newAccount.IsDomainPrimaryAccount = true - - err = am.Store.SaveAccount(ctx, newAccount) + err = am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, newAccountID, lowerDomain, claims.DomainCategory, &isPrimaryDomain) if err != nil { return "", err } - err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccount.Id) + err = am.addAccountIDToIDPAppMeta(ctx, claims.UserId, newAccountID) if err != nil { return "", err } - am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccount.Id, activity.UserJoined, nil) + am.StoreEvent(ctx, claims.UserId, claims.UserId, newAccountID, activity.UserJoined, nil) - return newAccount.Id, nil + return newAccountID, nil } func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context, domainAccountID string, claims jwtclaims.AuthorizationClaims) (string, error) { @@ -2412,100 +2409,93 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account return nil, err } - if user.AccountID != accountID || (!user.HasAdminPower() && !user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "the user has no permission to access account data") + if user.AccountID != accountID { + return nil, status.NewUserNotPartOfAccountError() } - return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) -} - -// addAllGroup to account object if it doesn't exist -func addAllGroup(account *Account) error { - if len(account.Groups) == 0 { - allGroup := &nbgroup.Group{ - ID: xid.New().String(), - Name: "All", - Issued: nbgroup.GroupIssuedAPI, - } - for _, peer := range account.Peers { - allGroup.Peers = append(allGroup.Peers, peer.ID) - } - account.Groups = map[string]*nbgroup.Group{allGroup.ID: allGroup} - - id := xid.New().String() - - defaultPolicy := &Policy{ - ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, - Enabled: true, - Rules: []*PolicyRule{ - { - ID: id, - Name: DefaultRuleName, - Description: DefaultRuleDescription, - Enabled: true, - Sources: []string{allGroup.ID}, - Destinations: []string{allGroup.ID}, - Bidirectional: true, - Protocol: PolicyRuleProtocolALL, - Action: PolicyTrafficActionAccept, - }, - }, - } - - account.Policies = []*Policy{defaultPolicy} + if user.IsRegularUser() { + return nil, status.NewAdminPermissionError() } - return nil + + return am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) } -// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id -func newAccountWithId(ctx context.Context, accountID, userID, domain string) *Account { +// newAccountWithId initializes a new Account instance with the provided account ID, user ID, and domain. +// It creates default settings and establishes an initial user, group, and policy. +func newAccountWithId(ctx context.Context, transaction Store, accountID, userID, domain string) error { log.WithContext(ctx).Debugf("creating new account") - network := NewNetwork() - peers := make(map[string]*nbpeer.Peer) - users := make(map[string]*User) - routes := make(map[route.ID]*route.Route) - setupKeys := map[string]*SetupKey{} - nameServersGroups := make(map[string]*nbdns.NameServerGroup) + acc := &Account{ + Id: accountID, + CreatedAt: time.Now().UTC(), + Network: NewNetwork(), + CreatedBy: userID, + Domain: domain, + DNSSettings: DNSSettings{ + DisabledManagementGroups: make([]string, 0), + }, + Settings: &Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, + Extra: &account.ExtraSettings{ + PeerApprovalEnabled: false, + IntegratedValidatorGroups: make([]string, 0), + }, + }, + } + if err := transaction.CreateAccount(ctx, LockingStrengthUpdate, acc); err != nil { + return err + } owner := NewOwnerUser(userID) owner.AccountID = accountID - users[userID] = owner - - dnsSettings := DNSSettings{ - DisabledManagementGroups: make([]string, 0), + if err := transaction.SaveUser(ctx, LockingStrengthUpdate, owner); err != nil { + return err } - log.WithContext(ctx).Debugf("created new account %s", accountID) - acc := &Account{ - Id: accountID, - CreatedAt: time.Now().UTC(), - SetupKeys: setupKeys, - Network: network, - Peers: peers, - Users: users, - CreatedBy: userID, - Domain: domain, - Routes: routes, - NameServerGroups: nameServersGroups, - DNSSettings: dnsSettings, - Settings: &Settings{ - PeerLoginExpirationEnabled: true, - PeerLoginExpiration: DefaultPeerLoginExpiration, - GroupsPropagationEnabled: true, - RegularUsersViewBlocked: true, + allGroup := &nbgroup.Group{ + ID: xid.New().String(), + AccountID: accountID, + Name: "All", + Issued: nbgroup.GroupIssuedAPI, + } + if err := transaction.SaveGroup(ctx, LockingStrengthUpdate, allGroup); err != nil { + return err + } - PeerInactivityExpirationEnabled: false, - PeerInactivityExpiration: DefaultPeerInactivityExpiration, + policyID := xid.New().String() + defaultPolicy := &Policy{ + ID: policyID, + AccountID: accountID, + Name: DefaultPolicyName, + Description: DefaultPolicyDescription, + Enabled: true, + Rules: []*PolicyRule{ + { + ID: xid.New().String(), + PolicyID: policyID, + Name: DefaultRuleName, + Description: DefaultRuleDescription, + Enabled: true, + Sources: []string{allGroup.ID}, + Destinations: []string{allGroup.ID}, + Bidirectional: true, + Protocol: PolicyRuleProtocolALL, + Action: PolicyTrafficActionAccept, + }, }, } - - if err := addAllGroup(acc); err != nil { - log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err) + if err := transaction.CreatePolicy(ctx, LockingStrengthUpdate, defaultPolicy); err != nil { + return err } - return acc + + log.WithContext(ctx).Debugf("created new account %s", accountID) + + return nil } // extractJWTGroups extracts the group names from a JWT token's claims. diff --git a/management/server/account_test.go b/management/server/account_test.go index 650e8de6949..4bf73852340 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -401,7 +401,14 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } for _, testCase := range tt { - account := newAccountWithId(context.Background(), "account-1", userID, "netbird.io") + store := newStore(t) + + err := newAccountWithId(context.Background(), store, "account-1", userID, "netbird.io") + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), "account-1") + require.NoError(t, err, "failed to get account") + account.UpdateSettings(&testCase.accountSettings) account.Network = network account.Peers = testCase.peers @@ -419,6 +426,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) + + store.Close(context.Background()) } } @@ -426,27 +435,35 @@ func TestNewAccount(t *testing.T) { domain := "netbird.io" userId := "account_creator" accountID := "account_id" - account := newAccountWithId(context.Background(), accountID, userId, domain) + + store := newStore(t) + defer store.Close(context.Background()) + + err := newAccountWithId(context.Background(), store, accountID, userId, domain) + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") verifyNewAccountHasDefaultFields(t, account, userId, domain, []string{userId}) } -func TestAccountManager_GetOrCreateAccountByUser(t *testing.T) { +func TestAccountManager_GetOrCreateAccountIDByUser(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userID) return } - account, err = manager.Store.GetAccountByUser(context.Background(), userID) + account, err := manager.Store.GetAccountByUser(context.Background(), userID) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userID) return @@ -669,15 +686,12 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { userId := "user-id" domain := "test.domain" - _ = newAccountWithId(context.Background(), "", userId, domain) manager, err := createManager(t) require.NoError(t, err, "unable to create account manager") accountID, err := manager.GetAccountIDByUserID(context.Background(), userId, domain) require.NoError(t, err, "create init user failed") - // as initAccount was created without account id we have to take the id after account initialization - // that happens inside the GetAccountIDByUserID where the id is getting generated - // it is important to set the id as it help to avoid creating additional account with empty Id and re-pointing indices to it + initAccount, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "get init account failed") @@ -693,44 +707,53 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account groups") - require.Len(t, account.Groups, 1, "only ALL group should exists") + require.Len(t, accountGroups, 1, "only ALL group should exists") }) t.Run("JWT groups enabled without claim name", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true - err := manager.Store.SaveAccount(context.Background(), initAccount) - require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings) + require.NoError(t, err, "failed to update account settings") + + totalAccounts, err := manager.Store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(1), totalAccounts, "only one account should exist") accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + accountGroups, err := manager.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account groups") - require.Len(t, account.Groups, 1, "if group claim is not set no group added from JWT") + require.Len(t, accountGroups, 1, "if group claim is not set no group added from JWT") }) t.Run("JWT groups enabled", func(t *testing.T) { initAccount.Settings.JWTGroupsEnabled = true initAccount.Settings.JWTGroupsClaimName = "idp-groups" - err := manager.Store.SaveAccount(context.Background(), initAccount) - require.NoError(t, err, "save account failed") - require.Len(t, manager.Store.GetAllAccounts(context.Background()), 1, "only one account should exist") + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userId, initAccount.Settings) + require.NoError(t, err, "failed to update account settings") + + totalAccounts, err := manager.Store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(1), totalAccounts, "only one account should exist") accountID, _, err := manager.GetAccountIDFromToken(context.Background(), claims) require.NoError(t, err, "get account by token failed") - account, err := manager.Store.GetAccount(context.Background(), accountID) - require.NoError(t, err, "get account failed") + exists, err := manager.Store.AccountExists(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to check account existence") + require.True(t, exists, "account should exist") - require.Len(t, account.Groups, 3, "groups should be added to the account") + accountGroups, err := manager.GetAllGroups(context.Background(), accountID, userId) + require.NoError(t, err, "failed to get account groups") + require.Len(t, accountGroups, 3, "groups should be added to the account") groupsByNames := map[string]*group.Group{} - for _, g := range account.Groups { + for _, g := range accountGroups { groupsByNames[g.Name] = g } @@ -746,27 +769,23 @@ func TestDefaultAccountManager_GetGroupsFromTheToken(t *testing.T) { }) } -func TestAccountManager_GetAccountFromPAT(t *testing.T) { +func TestAccountManager_GetAccountInfoFromPAT(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + err := newAccountWithId(context.Background(), store, "account_id", "testuser", "") + require.NoError(t, err, "failed to create account") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ - Id: "someUser", - PATs: map[string]*PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - UserID: "someUser", - HashedToken: encodedHashedToken, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) + + userPAT := &PersonalAccessToken{ + ID: "tokenId", + UserID: "testuser", + HashedToken: encodedHashedToken, + CreatedAt: time.Now().UTC(), } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT) + require.NoError(t, err, "failed to save PAT") am := DefaultAccountManager{ Store: store, @@ -778,31 +797,27 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) { } assert.Equal(t, "account_id", user.AccountID) - assert.Equal(t, "someUser", user.Id) - assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID) + assert.Equal(t, "testuser", user.Id) + assert.Equal(t, userPAT, pat) } func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + err := newAccountWithId(context.Background(), store, "account_id", "testuser", "") + require.NoError(t, err, "failed to create account") token := "nbp_9999EUDNdkeusjentDLSJEn1902u84390W6W" hashedToken := sha256.Sum256([]byte(token)) encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:]) - account.Users["someUser"] = &User{ - Id: "someUser", - PATs: map[string]*PersonalAccessToken{ - "tokenId": { - ID: "tokenId", - HashedToken: encodedHashedToken, - LastUsed: time.Time{}, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) + + userPAT := &PersonalAccessToken{ + ID: "tokenId", + UserID: "someUser", + HashedToken: encodedHashedToken, + LastUsed: time.Time{}, } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, userPAT) + require.NoError(t, err, "failed to save PAT") am := DefaultAccountManager{ Store: store, @@ -813,11 +828,10 @@ func TestDefaultAccountManager_MarkPATUsed(t *testing.T) { t.Fatalf("Error when marking PAT used: %s", err) } - account, err = am.Store.GetAccount(context.Background(), "account_id") - if err != nil { - t.Fatalf("Error when getting account: %s", err) - } - assert.True(t, !account.Users["someUser"].PATs["tokenId"].LastUsed.IsZero()) + userPAT, err = store.GetPATByID(context.Background(), LockingStrengthShare, userPAT.UserID, userPAT.ID) + require.NoError(t, err, "failed to get PAT") + + assert.True(t, !userPAT.LastUsed.IsZero()) } func TestAccountManager_PrivateAccount(t *testing.T) { @@ -828,15 +842,15 @@ func TestAccountManager_PrivateAccount(t *testing.T) { } userId := "test_user" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, "") if err != nil { t.Fatal(err) } - if account == nil { + if accountID == "" { t.Fatalf("expected to create an account for a user %s", userId) } - account, err = manager.Store.GetAccountByUser(context.Background(), userId) + account, err := manager.Store.GetAccountByUser(context.Background(), userId) if err != nil { t.Errorf("expected to get existing account after creation, no account was found for a user %s", userId) } @@ -855,32 +869,22 @@ func TestAccountManager_SetOrUpdateDomain(t *testing.T) { userId := "test_user" domain := "hotmail.com" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userId, domain) - if err != nil { - t.Fatal(err) - } - if account == nil { - t.Fatalf("expected to create an account for a user %s", userId) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain) + require.NoError(t, err, "failed to get or create account by user") + require.NotEmptyf(t, accountID, "expected to create an account for a user %s", userId) - if account != nil && account.Domain != domain { - t.Errorf("setting account domain failed, expected %s, got %s", domain, account.Domain) - } + accDomain, _, err := manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account domain and category") + require.Equal(t, domain, accDomain, "expected account domain to match") domain = "gmail.com" - account, err = manager.GetOrCreateAccountByUser(context.Background(), userId, domain) - if err != nil { - t.Fatalf("got the following error while retrieving existing acc: %v", err) - } - - if account == nil { - t.Fatalf("expected to get an account for a user %s", userId) - } + accountID, err = manager.GetOrCreateAccountIDByUser(context.Background(), userId, domain) + require.NoError(t, err, "failed to get or create account by user") - if account != nil && account.Domain != domain { - t.Errorf("updating domain. expected %s got %s", domain, account.Domain) - } + accDomain, _, err = manager.Store.GetAccountDomainAndCategory(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account domain and category") + require.Equal(t, domain, accDomain, "expected account domain to match") } func TestAccountManager_GetAccountByUserID(t *testing.T) { @@ -912,12 +916,11 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) { } func createAccount(am *DefaultAccountManager, accountID, userID, domain string) (*Account, error) { - account := newAccountWithId(context.Background(), accountID, userID, domain) - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { return nil, err } - return account, nil + return am.Store.GetAccount(context.Background(), accountID) } func TestAccountManager_GetAccount(t *testing.T) { @@ -1164,23 +1167,18 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { return } - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "netbird.cloud") - if err != nil { - t.Fatal(err) - } + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "netbird.cloud") + require.NoError(t, err, "failed to get or create account by user") - serial := account.Network.CurrentSerial() // should be 0 + network, err := manager.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account network") - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return - } + serial := network.CurrentSerial() // should be 0 + require.Equal(t, 0, int(serial), "expected account network to have an initial Serial=0") key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return - } + require.NoError(t, err, "failed to generate private key") + expectedPeerKey := key.PublicKey().String() expectedUserID := userID @@ -1188,16 +1186,10 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { Key: expectedPeerKey, Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, }) - if err != nil { - t.Errorf("expecting peer to be added, got failure %v, account users: %v", err, account.CreatedBy) - return - } + require.NoError(t, err, "failed to add peer") - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - return - } + account, err := manager.Store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") if peer.Key != expectedPeerKey { t.Errorf("expecting just added peer to have key = %s, got %s", expectedPeerKey, peer.Key) @@ -1856,10 +1848,12 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") - account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: true, - }) + settings, err := manager.GetAccountSettings(context.Background(), accountID, userID) + require.NoError(t, err, "unable to get account settings") + + settings.PeerLoginExpirationEnabled = true + settings.PeerLoginExpiration = time.Hour + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") wg := &sync.WaitGroup{} @@ -1876,11 +1870,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { // disable expiration first update := peer.Copy() update.LoginExpirationEnabled = false - _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), accountID, userID, update) require.NoError(t, err, "unable to update peer") // enabling expiration should trigger the routine update.LoginExpirationEnabled = true - _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), accountID, userID, update) require.NoError(t, err, "unable to update peer") failed := waitTimeout(wg, time.Second) @@ -1904,10 +1898,14 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. LoginExpirationEnabled: true, }) require.NoError(t, err, "unable to add peer") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: true, - }) + + settings, err := manager.GetAccountSettings(context.Background(), accountID, userID) + require.NoError(t, err, "unable to get account settings") + + settings.PeerLoginExpirationEnabled = true + settings.PeerLoginExpiration = time.Hour + + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") wg := &sync.WaitGroup{} @@ -1969,11 +1967,15 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Done() }, } + // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: true, - }) + settings, err := manager.GetAccountSettings(context.Background(), accountID, userID) + require.NoError(t, err, "unable to get account settings") + + settings.PeerLoginExpirationEnabled = true + settings.PeerLoginExpiration = time.Hour + + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") failed := waitTimeout(wg, time.Second) @@ -1983,10 +1985,8 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test wg.Add(1) // disabling PeerLoginExpirationEnabled should trigger cancel - _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: false, - }) + settings.PeerLoginExpirationEnabled = false + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") failed = waitTimeout(wg, time.Second) if failed { @@ -2001,30 +2001,29 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour, - PeerLoginExpirationEnabled: false, - }) + settings, err := manager.GetAccountSettings(context.Background(), accountID, userID) + require.NoError(t, err, "unable to get account settings") + + settings.PeerLoginExpirationEnabled = false + settings.PeerLoginExpiration = time.Hour + + updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.NoError(t, err, "expecting to update account settings successfully but got error") - assert.False(t, updated.Settings.PeerLoginExpirationEnabled) - assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) + assert.False(t, updatedSettings.PeerLoginExpirationEnabled) + assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) - settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + settings, err = manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") assert.False(t, settings.PeerLoginExpirationEnabled) assert.Equal(t, settings.PeerLoginExpiration, time.Hour) - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Second, - PeerLoginExpirationEnabled: false, - }) + settings.PeerLoginExpiration = time.Second + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") - _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &Settings{ - PeerLoginExpiration: time.Hour * 24 * 181, - PeerLoginExpirationEnabled: false, - }) + settings.PeerLoginExpiration = time.Hour * 24 * 181 + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, settings) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c0f..b50a67594a2 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -39,12 +39,12 @@ func TestGetDNSSettings(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Fatal("failed to init testing account") } - dnsSettings, err := am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) + dnsSettings, err := am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID) if err != nil { t.Fatalf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -53,16 +53,12 @@ func TestGetDNSSettings(t *testing.T) { t.Fatal("DNS settings for new accounts shouldn't return nil") } - account.DNSSettings = DNSSettings{ + err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &DNSSettings{ DisabledManagementGroups: []string{group1ID}, - } - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save testing account with new DNS settings") - } + }) + require.NoError(t, err, "failed to update DNS settings") - dnsSettings, err = am.GetDNSSettings(context.Background(), account.Id, dnsAdminUserID) + dnsSettings, err = am.GetDNSSettings(context.Background(), accountID, dnsAdminUserID) if err != nil { t.Errorf("Got an error when trying to retrieve the DNS settings with an admin user, err: %s", err) } @@ -71,7 +67,7 @@ func TestGetDNSSettings(t *testing.T) { t.Errorf("DNS settings should have one disabled mgmt group, groups: %s", dnsSettings.DisabledManagementGroups) } - _, err = am.GetDNSSettings(context.Background(), account.Id, dnsRegularUserID) + _, err = am.GetDNSSettings(context.Background(), accountID, dnsRegularUserID) if err == nil { t.Errorf("An error should be returned when getting the DNS settings with a regular user") } @@ -126,12 +122,12 @@ func TestSaveDNSSettings(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - err = am.SaveDNSSettings(context.Background(), account.Id, testCase.userID, testCase.inputSettings) + err = am.SaveDNSSettings(context.Background(), accountID, testCase.userID, testCase.inputSettings) if err != nil { if testCase.shouldFail { return @@ -139,7 +135,7 @@ func TestSaveDNSSettings(t *testing.T) { t.Error(err) } - updatedAccount, err := am.Store.GetAccount(context.Background(), account.Id) + updatedAccount, err := am.Store.GetAccount(context.Background(), accountID) if err != nil { t.Errorf("should be able to retrieve updated account, got err: %s", err) } @@ -158,17 +154,17 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestDNSAccount(t, am) + accountID, err := initTestDNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - peer1, err := account.FindPeerByPubKey(dnsPeer1Key) + peer1, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer1Key) if err != nil { t.Error("failed to init testing account") } - peer2, err := account.FindPeerByPubKey(dnsPeer2Key) + peer2, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, dnsPeer2Key) if err != nil { t.Error("failed to init testing account") } @@ -179,11 +175,13 @@ func TestGetNetworkMap_DNSConfigSync(t *testing.T) { require.True(t, newAccountDNSConfig.DNSConfig.ServiceEnable, "default DNS config should have local DNS service enabled") require.Len(t, newAccountDNSConfig.DNSConfig.NameServerGroups, 0, "updated DNS config should have no nameserver groups since peer 1 is NS for the only existing NS group") - dnsSettings := account.DNSSettings.Copy() + accountDNSSettings, err := am.Store.GetAccountDNSSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account DNS settings") + + dnsSettings := accountDNSSettings.Copy() dnsSettings.DisabledManagementGroups = append(dnsSettings.DisabledManagementGroups, dnsGroup1ID) - account.DNSSettings = dnsSettings - err = am.Store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = am.Store.SaveDNSSettings(context.Background(), LockingStrengthUpdate, accountID, &dnsSettings) + require.NoError(t, err, "failed to update DNS settings") updatedAccountDNSConfig, err := am.GetNetworkMap(context.Background(), peer1.ID) require.NoError(t, err) @@ -222,7 +220,7 @@ func createDNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() peer1 := &nbpeer.Peer{ Key: dnsPeer1Key, @@ -257,64 +255,65 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro domain := "example.com" - account := newAccountWithId(context.Background(), dnsAccountID, dnsAdminUserID, domain) - - account.Users[dnsRegularUserID] = &User{ - Id: dnsRegularUserID, - Role: UserRoleUser, + err := newAccountWithId(context.Background(), am.Store, dnsAccountID, dnsAdminUserID, domain) + if err != nil { + return "", err } - err := am.Store.SaveAccount(context.Background(), account) + err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: dnsRegularUserID, + AccountID: dnsAccountID, + Role: UserRoleUser, + }) if err != nil { - return nil, err + return "", err } savedPeer1, _, _, err := am.AddPeer(context.Background(), "", dnsAdminUserID, peer1) if err != nil { - return nil, err + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", dnsAdminUserID, peer2) if err != nil { - return nil, err + return "", err } - account, err = am.Store.GetAccount(context.Background(), account.Id) + peer1, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer1.Key) if err != nil { - return nil, err + return "", err } - peer1, err = account.FindPeerByPubKey(peer1.Key) + _, err = am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peer2.Key) if err != nil { - return nil, err + return "", err } - _, err = account.FindPeerByPubKey(peer2.Key) + err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{ + { + ID: dnsGroup1ID, + AccountID: dnsAccountID, + Peers: []string{peer1.ID}, + Name: dnsGroup1ID, + }, + { + ID: dnsGroup2ID, + AccountID: dnsAccountID, + Name: dnsGroup2ID, + }, + }) if err != nil { - return nil, err - } - - newGroup1 := &group.Group{ - ID: dnsGroup1ID, - Peers: []string{peer1.ID}, - Name: dnsGroup1ID, + return "", err } - newGroup2 := &group.Group{ - ID: dnsGroup2ID, - Name: dnsGroup2ID, - } - - account.Groups[newGroup1.ID] = newGroup1 - account.Groups[newGroup2.ID] = newGroup2 - - allGroup, err := account.GetGroupAll() + allGroup, err := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, dnsAccountID, "All") if err != nil { - return nil, err + return "", err } - account.NameServerGroups[dnsNSGroup1] = &dns.NameServerGroup{ - ID: dnsNSGroup1, - Name: "ns-group-1", + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &dns.NameServerGroup{ + ID: dnsNSGroup1, + AccountID: dnsAccountID, + Name: "ns-group-1", NameServers: []dns.NameServer{{ IP: netip.MustParseAddr(savedPeer1.IP.String()), NSType: dns.UDPNameServerType, @@ -323,14 +322,12 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro Primary: true, Enabled: true, Groups: []string{allGroup.ID}, - } - - err = am.Store.SaveAccount(context.Background(), account) + }) if err != nil { - return nil, err + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + return dnsAccountID, nil } func generateTestData(size int) nbdns.Config { diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 00e5d777a79..eae70dae506 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,21 +7,12 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/stretchr/testify/require" ) type MockStore struct { Store - account *Account -} - -func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ LockingStrength) ([]*nbpeer.Peer, error) { - var peers []*nbpeer.Peer - for _, v := range s.account.Peers { - if v.Ephemeral { - peers = append(peers, v) - } - } - return peers, nil + accountID string } type MocAccountManager struct { @@ -29,9 +20,8 @@ type MocAccountManager struct { store *MockStore } -func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error { - delete(a.store.account.Peers, peerID) - return nil //nolint:nil +func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, _ string) error { + return a.store.DeletePeer(context.Background(), LockingStrengthUpdate, accountID, peerID) } func TestNewManager(t *testing.T) { @@ -40,23 +30,26 @@ func TestNewManager(t *testing.T) { return startTime } - store := &MockStore{} + store := &MockStore{ + Store: newStore(t), + } am := MocAccountManager{ store: store, } numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) - if len(store.account.Peers) != numberOfPeers { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers)) - } + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + require.Equal(t, numberOfPeers, len(peers), "failed to cleanup ephemeral peers") } func TestNewManagerPeerConnected(t *testing.T) { @@ -65,26 +58,32 @@ func TestNewManagerPeerConnected(t *testing.T) { return startTime } - store := &MockStore{} + store := &MockStore{ + Store: newStore(t), + } am := MocAccountManager{ store: store, } numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) - mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) + + peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0") + require.NoError(t, err, "failed to get peer") + + mgr.OnPeerConnected(context.Background(), peer) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) - expected := numberOfPeers + 1 - if len(store.account.Peers) != expected { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) - } + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + require.Equal(t, numberOfPeers+1, len(peers), "failed to cleanup ephemeral peers") } func TestNewManagerPeerDisconnected(t *testing.T) { @@ -93,50 +92,73 @@ func TestNewManagerPeerDisconnected(t *testing.T) { return startTime } - store := &MockStore{} + store := &MockStore{ + Store: newStore(t), + } am := MocAccountManager{ store: store, } numberOfPeers := 5 numberOfEphemeralPeers := 3 - seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + err := seedPeers(store, numberOfPeers, numberOfEphemeralPeers) + require.NoError(t, err, "failed to seed peers") mgr := NewEphemeralManager(store, am) mgr.loadEphemeralPeers(context.Background()) - for _, v := range store.account.Peers { - mgr.OnPeerConnected(context.Background(), v) + peers, err := store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") + for _, v := range peers { + mgr.OnPeerConnected(context.Background(), v) } - mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"]) + + peer, err := am.store.GetPeerByID(context.Background(), LockingStrengthShare, store.accountID, "ephemeral_peer_0") + require.NoError(t, err, "failed to get peer") + mgr.OnPeerDisconnected(context.Background(), peer) startTime = startTime.Add(ephemeralLifeTime + 1) mgr.cleanup(context.Background()) + peers, err = store.GetAccountPeers(context.Background(), LockingStrengthShare, store.accountID) + require.NoError(t, err, "failed to get account peers") expected := numberOfPeers + numberOfEphemeralPeers - 1 - if len(store.account.Peers) != expected { - t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers)) - } + require.Equal(t, expected, len(peers), "failed to cleanup ephemeral peers") } -func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) { - store.account = newAccountWithId(context.Background(), "my account", "", "") +func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) error { + accountID := "my account" + err := newAccountWithId(context.Background(), store, accountID, "", "") + if err != nil { + return err + } + store.accountID = accountID for i := 0; i < numberOfPeers; i++ { peerId := fmt.Sprintf("peer_%d", i) p := &nbpeer.Peer{ ID: peerId, + AccountID: accountID, Ephemeral: false, } - store.account.Peers[p.ID] = p + err = store.AddPeerToAccount(context.Background(), p) + if err != nil { + return err + } } for i := 0; i < numberOfEphemeralPeers; i++ { peerId := fmt.Sprintf("ephemeral_peer_%d", i) p := &nbpeer.Peer{ ID: peerId, + AccountID: accountID, Ephemeral: true, } - store.account.Peers[p.ID] = p + err = store.AddPeerToAccount(context.Background(), p) + if err != nil { + return err + } } + + return nil } diff --git a/management/server/group_test.go b/management/server/group_test.go index 0515b9698ee..ca48441dd46 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -328,25 +328,30 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A } routeResource := &route.Route{ - ID: "example route", - Groups: []string{groupForRoute.ID}, + ID: "example route", + AccountID: accountID, + Groups: []string{groupForRoute.ID}, } routePeerGroupResource := &route.Route{ ID: "example route with peer groups", + AccountID: accountID, PeerGroups: []string{groupForRoute2.ID}, } nameServerGroup := &nbdns.NameServerGroup{ - ID: "example name server group", - Groups: []string{groupForNameServerGroups.ID}, + ID: "example name server group", + AccountID: accountID, + Groups: []string{groupForNameServerGroups.ID}, } policy := &Policy{ - ID: "example policy", + ID: "example policy", + AccountID: accountID, Rules: []*PolicyRule{ { ID: "example policy rule", + PolicyID: "example policy", Destinations: []string{groupForPolicies.ID}, }, }, @@ -354,35 +359,60 @@ func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *A setupKey := &SetupKey{ Id: "example setup key", + AccountID: accountID, AutoGroups: []string{groupForSetupKeys.ID}, } user := &User{ Id: "example user", + AccountID: accountID, AutoGroups: []string{groupForUsers.ID}, } - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) - account.Routes[routeResource.ID] = routeResource - account.Routes[routePeerGroupResource.ID] = routePeerGroupResource - account.NameServerGroups[nameServerGroup.ID] = nameServerGroup - account.Policies = append(account.Policies, policy) - account.SetupKeys[setupKey.Id] = setupKey - account.Users[user.Id] = user - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain) if err != nil { return nil, nil, err } - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute2) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForNameServerGroups) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForPolicies) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForSetupKeys) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers) - _ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration) + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routeResource) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, routePeerGroupResource) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameServerGroup) + if err != nil { + return nil, nil, err + } + + err = am.Store.CreatePolicy(context.Background(), LockingStrengthUpdate, policy) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + if err != nil { + return nil, nil, err + } + + err = am.Store.SaveUser(context.Background(), LockingStrengthUpdate, user) + if err != nil { + return nil, nil, err + } + + err = am.SaveGroups(context.Background(), accountID, groupAdminUserID, []*nbgroup.Group{ + groupForRoute, groupForRoute2, groupForNameServerGroups, groupForPolicies, + groupForSetupKeys, groupForUsers, groupForIntegration, + }) + if err != nil { + return nil, nil, err + } - acc, err := am.Store.GetAccount(context.Background(), account.Id) + acc, err := am.Store.GetAccount(context.Background(), accountID) if err != nil { return nil, nil, err } diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index 4baf9c6925f..b418ae02b45 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -100,13 +100,13 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings) + resp := toAccountResponse(accountID, updatedSettings) util.WriteJSONObject(r.Context(), w, &resp) } diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index cacb3d43010..b70d20a7deb 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -29,7 +29,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*server.Settings, error) { return account.Settings, nil }, - UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -39,9 +39,7 @@ func initAccountsTestData(account *server.Account, admin *server.User) *Accounts return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - accCopy := account.Copy() - accCopy.UpdateSettings(newSettings) - return accCopy, nil + return newSettings.Copy(), nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 57ad968b3d7..dc8765e197f 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -246,7 +246,7 @@ func Test_SyncProtocol(t *testing.T) { t.Fatal("expecting SyncResponse to have non-nil NetworkMap") } - if len(networkMap.GetRemotePeers()) != 4 { + if len(networkMap.GetRemotePeers()) != 3 { t.Fatalf("expecting SyncResponse to have NetworkMap with 3 remote peers, got %d", len(networkMap.GetRemotePeers())) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 3e465e32e3d..9889552b8a4 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -22,9 +22,9 @@ import ( ) type MockAccountManager struct { - GetOrCreateAccountByUserFunc func(ctx context.Context, userId, domain string) (*server.Account, error) - GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) - CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, + GetOrCreateAccountIDByUserFunc func(ctx context.Context, userId, domain string) (string, error) + GetAccountFunc func(ctx context.Context, accountID string) (*server.Account, error) + CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool) (*server.SetupKey, error) GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*server.SetupKey, error) AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) @@ -89,7 +89,7 @@ type MockAccountManager struct { GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*server.DNSSettings, error) SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *server.DNSSettings) error GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) LoginPeerFunc func(ctx context.Context, login server.PeerLogin) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync server.PeerSync, accountID string) (*nbpeer.Peer, *server.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error @@ -176,16 +176,16 @@ func (am *MockAccountManager) DeletePeer(ctx context.Context, accountID, peerID, return status.Errorf(codes.Unimplemented, "method DeletePeer is not implemented") } -// GetOrCreateAccountByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface -func (am *MockAccountManager) GetOrCreateAccountByUser( +// GetOrCreateAccountIDByUser mock implementation of GetOrCreateAccountByUser from server.AccountManager interface +func (am *MockAccountManager) GetOrCreateAccountIDByUser( ctx context.Context, userId, domain string, -) (*server.Account, error) { - if am.GetOrCreateAccountByUserFunc != nil { - return am.GetOrCreateAccountByUserFunc(ctx, userId, domain) +) (string, error) { + if am.GetOrCreateAccountIDByUserFunc != nil { + return am.GetOrCreateAccountIDByUserFunc(ctx, userId, domain) } - return nil, status.Errorf( + return "", status.Errorf( codes.Unimplemented, - "method GetOrCreateAccountByUser is not implemented", + "method GetOrCreateAccountIDByUser is not implemented", ) } @@ -672,7 +672,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Settings, error) { if am.UpdateAccountSettingsFunc != nil { return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf02370..6a305e723d3 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/netbirdio/netbird/management/server/status" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -381,14 +382,14 @@ func TestCreateNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } outNSGroup, err := am.CreateNameServerGroup( context.Background(), - account.Id, + accountID, testCase.inputArgs.name, testCase.inputArgs.description, testCase.inputArgs.nameServers, @@ -609,20 +610,16 @@ func TestSaveNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("account should be saved") - } + testCase.existingNSGroup.AccountID = accountID + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testCase.existingNSGroup) + require.NoError(t, err, "failed to save existing nameserver group") var nsGroupToSave *nbdns.NameServerGroup - if !testCase.skipCopying { nsGroupToSave = testCase.existingNSGroup.Copy() @@ -651,22 +648,17 @@ func TestSaveNameServerGroup(t *testing.T) { } } - err = am.SaveNameServerGroup(context.Background(), account.Id, userID, nsGroupToSave) - + err = am.SaveNameServerGroup(context.Background(), accountID, userID, nsGroupToSave) testCase.errFunc(t, err) if !testCase.shouldCreate { return } - account, err = am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - } - - savedNSGroup, saved := account.NameServerGroups[testCase.expectedNSGroup.ID] - require.True(t, saved) + savedNSGroup, err := am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testCase.expectedNSGroup.ID) + require.NoError(t, err, "failed to get saved nameserver group") + testCase.expectedNSGroup.AccountID = accountID if !testCase.expectedNSGroup.IsEqual(savedNSGroup) { t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", savedNSGroup, testCase.expectedNSGroup) } @@ -703,32 +695,25 @@ func TestDeleteNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.NameServerGroups[testingNSGroup.ID] = testingNSGroup - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save account") - } + testingNSGroup.AccountID = accountID + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, testingNSGroup) + require.NoError(t, err, "failed to save nameserver group") - err = am.DeleteNameServerGroup(context.Background(), account.Id, testingNSGroup.ID, userID) + err = am.DeleteNameServerGroup(context.Background(), accountID, testingNSGroup.ID, userID) if err != nil { t.Error("deleting nameserver group failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Error("failed to retrieve saved account with error: ", err) - } - - _, found := savedAccount.NameServerGroups[testingNSGroup.ID] - if found { - t.Error("nameserver group shouldn't be found after delete") - } + _, err = am.Store.GetNameServerGroupByID(context.Background(), LockingStrengthShare, accountID, testingNSGroup.ID) + require.NotNil(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok, "error should be a status error") + assert.Equal(t, status.NotFound, sErr.Type(), "nameserver group shouldn't be found after delete") } func TestGetNameServerGroup(t *testing.T) { @@ -738,12 +723,12 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestNSAccount(t, am) + accountID, err := initTestNSAccount(t, am) if err != nil { t.Error("failed to init testing account") } - foundGroup, err := am.GetNameServerGroup(context.Background(), account.Id, testUserID, existingNSGroupID) + foundGroup, err := am.GetNameServerGroup(context.Background(), accountID, testUserID, existingNSGroupID) if err != nil { t.Error("getting existing nameserver group failed with error: ", err) } @@ -752,7 +737,7 @@ func TestGetNameServerGroup(t *testing.T) { t.Error("got a nil group while getting nameserver group with ID") } - _, err = am.GetNameServerGroup(context.Background(), account.Id, testUserID, "not existing") + _, err = am.GetNameServerGroup(context.Background(), accountID, testUserID, "not existing") if err == nil { t.Error("getting not existing nameserver group should return error, got nil") } @@ -784,8 +769,12 @@ func createNSStore(t *testing.T) (Store, error) { return store, nil } -func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() + accountID := "testingAcc" + userID := testUserID + domain := "example.com" + peer1 := &nbpeer.Peer{ Key: nsGroupPeer1Key, Name: "test-host1@netbird.io", @@ -816,6 +805,7 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error } existingNSGroup := nbdns.NameServerGroup{ ID: existingNSGroupID, + AccountID: accountID, Name: existingNSGroupName, Description: "", NameServers: []nbdns.NameServer{ @@ -834,42 +824,42 @@ func initTestNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, error Enabled: true, } - accountID := "testingAcc" - userID := testUserID - domain := "example.com" - - account := newAccountWithId(context.Background(), accountID, userID, domain) - - account.NameServerGroups[existingNSGroup.ID] = &existingNSGroup - - newGroup1 := &nbgroup.Group{ - ID: group1ID, - Name: group1ID, + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) + if err != nil { + return "", err } - newGroup2 := &nbgroup.Group{ - ID: group2ID, - Name: group2ID, + err = am.Store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, &existingNSGroup) + if err != nil { + return "", err } - account.Groups[newGroup1.ID] = newGroup1 - account.Groups[newGroup2.ID] = newGroup2 - - err := am.Store.SaveAccount(context.Background(), account) + err = am.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*nbgroup.Group{ + { + ID: group1ID, + AccountID: accountID, + Name: group1ID, + }, + { + ID: group2ID, + AccountID: accountID, + Name: group2ID, + }, + }) if err != nil { - return nil, err + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", userID, peer1) if err != nil { - return nil, err + return "", err } _, _, _, err = am.AddPeer(context.Background(), "", userID, peer2) if err != nil { - return nil, err + return "", err } - return account, nil + return accountID, nil } func TestValidateDomain(t *testing.T) { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 0e30a3762e2..fc63156c35a 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -468,21 +468,25 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ - Id: someUser, - Role: UserRoleUser, - } - account.Settings.RegularUsersViewBlocked = false + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + require.NoError(t, err, "failed to create account") - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - return - } + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: someUser, + AccountID: accountID, + Role: UserRoleUser, + }) + require.NoError(t, err, "failed to create user") + + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = false + err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings) + require.NoError(t, err, "failed to save account settings") // two peers one added by a regular user and one with a setup key - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) + setupKey, err := manager.CreateSetupKey(context.Background(), accountID, "test-key", SetupKeyReusable, time.Hour, nil, 999, adminUser, false) if err != nil { t.Fatal("error creating setup key") return @@ -536,7 +540,10 @@ func TestDefaultAccountManager_GetPeer(t *testing.T) { assert.NotNil(t, peer) // delete the all-to-all policy so that user's peer1 has no access to peer2 - for _, policy := range account.Policies { + accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account policies") + + for _, policy := range accountPolicies { err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) if err != nil { t.Fatal(err) @@ -655,21 +662,33 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) { accountID := "test_account" adminUser := "account_creator" someUser := "some_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[someUser] = &User{ + + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + require.NoError(t, err, "failed to create account") + + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: someUser, + AccountID: accountID, Role: testCase.role, IsServiceUser: testCase.isServiceUser, - } - account.Policies = []*Policy{} - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings + }) + require.NoError(t, err, "failed to create user") - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) - return + accountPolicies, err := manager.Store.GetAccountPolicies(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account policies") + + for _, policy := range accountPolicies { + err = manager.DeletePolicy(context.Background(), accountID, policy.ID, adminUser) + require.NoError(t, err, "failed to delete policy") } + settings, err := manager.Store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = testCase.limitedViewSettings + err = manager.Store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings) + require.NoError(t, err, "failed to save account settings") + peerKey1, err := wgtypes.GeneratePrivateKey() if err != nil { t.Fatal(err) @@ -725,10 +744,18 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou adminUser := "account_creator" regularUser := "regular_user" - account := newAccountWithId(context.Background(), accountID, adminUser, "") - account.Users[regularUser] = &User{ - Id: regularUser, - Role: UserRoleUser, + err = newAccountWithId(context.Background(), manager.Store, accountID, adminUser, "") + if err != nil { + return nil, "", "", err + } + + err = manager.Store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: regularUser, + AccountID: accountID, + Role: UserRoleUser, + }) + if err != nil { + return nil, "", "", err } // Create peers @@ -742,31 +769,40 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou Status: &nbpeer.PeerStatus{}, UserID: regularUser, } - account.Peers[peer.ID] = peer + err = manager.Store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, peer) + if err != nil { + return nil, "", "", err + } } // Create groups and policies - account.Policies = make([]*Policy, 0, groups) for i := 0; i < groups; i++ { groupID := fmt.Sprintf("group-%d", i) group := &nbgroup.Group{ - ID: groupID, - Name: fmt.Sprintf("Group %d", i), + ID: groupID, + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), } for j := 0; j < peers/groups; j++ { peerIndex := i*(peers/groups) + j group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex)) } - account.Groups[groupID] = group + + err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + if err != nil { + return nil, "", "", err + } // Create a policy for this group policy := &Policy{ - ID: fmt.Sprintf("policy-%d", i), - Name: fmt.Sprintf("Policy for Group %d", i), - Enabled: true, + ID: fmt.Sprintf("policy-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Policy for Group %d", i), + Enabled: true, Rules: []*PolicyRule{ { ID: fmt.Sprintf("rule-%d", i), + PolicyID: fmt.Sprintf("policy-%d", i), Name: fmt.Sprintf("Rule for Group %d", i), Enabled: true, Sources: []string{groupID}, @@ -777,22 +813,23 @@ func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccou }, }, } - account.Policies = append(account.Policies, policy) + + err = manager.Store.SavePolicy(context.Background(), LockingStrengthUpdate, policy) + if err != nil { + return nil, "", "", err + } } - account.PostureChecks = []*posture.Checks{ - { - ID: "PostureChecksAll", - Name: "All", - Checks: posture.ChecksDefinition{ - NBVersionCheck: &posture.NBVersionCheck{ - MinVersion: "0.0.1", - }, + err = manager.Store.SavePostureChecks(context.Background(), LockingStrengthUpdate, &posture.Checks{ + ID: "PostureChecksAll", + AccountID: accountID, + Name: "All", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.0.1", }, }, - } - - err = manager.Store.SaveAccount(context.Background(), account) + }) if err != nil { return nil, "", "", err } diff --git a/management/server/personal_access_token.go b/management/server/personal_access_token.go index e4b19da76a5..a135ad4af50 100644 --- a/management/server/personal_access_token.go +++ b/management/server/personal_access_token.go @@ -41,6 +41,7 @@ type PersonalAccessToken struct { func (t *PersonalAccessToken) Copy() *PersonalAccessToken { return &PersonalAccessToken{ ID: t.ID, + UserID: t.UserID, Name: t.Name, HashedToken: t.HashedToken, ExpirationDate: t.ExpirationDate, diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index 93e5741cf28..aa99188eb5f 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -25,22 +25,22 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestPostureChecksAccount(am) + accountID, err := initTestPostureChecksAccount(am) if err != nil { t.Error("failed to init testing account") } t.Run("Generic posture check flow", func(t *testing.T) { // regular users can not create checks - _, err = am.SavePostureChecks(context.Background(), account.Id, regularUserID, &posture.Checks{}) + _, err = am.SavePostureChecks(context.Background(), accountID, regularUserID, &posture.Checks{}) assert.Error(t, err) // regular users cannot list check - _, err = am.ListPostureChecks(context.Background(), account.Id, regularUserID) + _, err = am.ListPostureChecks(context.Background(), accountID, regularUserID) assert.Error(t, err) // should be possible to create posture check with uniq name - postureCheck, err := am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ + postureCheck, err := am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{ @@ -51,12 +51,12 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { assert.NoError(t, err) // admin users can list check - checks, err := am.ListPostureChecks(context.Background(), account.Id, adminUserID) + checks, err := am.ListPostureChecks(context.Background(), accountID, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 1) // should not be possible to create posture check with non uniq name - _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, &posture.Checks{ + _, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, &posture.Checks{ Name: postureCheckName, Checks: posture.ChecksDefinition{ GeoLocationCheck: &posture.GeoLocationCheck{ @@ -76,45 +76,48 @@ func TestDefaultAccountManager_PostureCheck(t *testing.T) { MinVersion: "0.27.0", }, } - _, err = am.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheck) + _, err = am.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheck) assert.NoError(t, err) // users should not be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, regularUserID) + err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, regularUserID) assert.Error(t, err) // admin should be able to delete posture checks - err = am.DeletePostureChecks(context.Background(), account.Id, postureCheck.ID, adminUserID) + err = am.DeletePostureChecks(context.Background(), accountID, postureCheck.ID, adminUserID) assert.NoError(t, err) - checks, err = am.ListPostureChecks(context.Background(), account.Id, adminUserID) + checks, err = am.ListPostureChecks(context.Background(), accountID, adminUserID) assert.NoError(t, err) assert.Len(t, checks, 0) }) } -func initTestPostureChecksAccount(am *DefaultAccountManager) (*Account, error) { +func initTestPostureChecksAccount(am *DefaultAccountManager) (string, error) { accountID := "testingAccount" domain := "example.com" - admin := &User{ - Id: adminUserID, - Role: UserRoleAdmin, - } - user := &User{ - Id: regularUserID, - Role: UserRoleUser, + err := newAccountWithId(context.Background(), am.Store, accountID, groupAdminUserID, domain) + if err != nil { + return "", err } - account := newAccountWithId(context.Background(), accountID, groupAdminUserID, domain) - account.Users[admin.Id] = admin - account.Users[user.Id] = user - - err := am.Store.SaveAccount(context.Background(), account) + err = am.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: adminUserID, + AccountID: accountID, + Role: UserRoleAdmin, + }, + { + Id: regularUserID, + AccountID: accountID, + Role: UserRoleUser, + }, + }) if err != nil { - return nil, err + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + return accountID, nil } func TestPostureCheckAccountPeersUpdate(t *testing.T) { @@ -440,18 +443,18 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { manager, err := createManager(t) require.NoError(t, err, "failed to create account manager") - account, err := initTestPostureChecksAccount(manager) + accountID, err := initTestPostureChecksAccount(manager) require.NoError(t, err, "failed to init testing account") groupA := &group.Group{ ID: "groupA", - AccountID: account.Id, + AccountID: accountID, Peers: []string{"peer1"}, } groupB := &group.Group{ ID: "groupB", - AccountID: account.Id, + AccountID: accountID, Peers: []string{}, } err = manager.Store.SaveGroups(context.Background(), LockingStrengthUpdate, []*group.Group{groupA, groupB}) @@ -459,26 +462,26 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { postureCheckA := &posture.Checks{ Name: "checkA", - AccountID: account.Id, + AccountID: accountID, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, } - postureCheckA, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckA) + postureCheckA, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckA) require.NoError(t, err, "failed to save postureCheckA") postureCheckB := &posture.Checks{ Name: "checkB", - AccountID: account.Id, + AccountID: accountID, Checks: posture.ChecksDefinition{ NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.33.1"}, }, } - postureCheckB, err = manager.SavePostureChecks(context.Background(), account.Id, adminUserID, postureCheckB) + postureCheckB, err = manager.SavePostureChecks(context.Background(), accountID, adminUserID, postureCheckB) require.NoError(t, err, "failed to save postureCheckB") policy := &Policy{ - AccountID: account.Id, + AccountID: accountID, Rules: []*PolicyRule{ { Enabled: true, @@ -489,23 +492,23 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { SourcePostureChecks: []string{postureCheckA.ID}, } - policy, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + policy, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to save policy") t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) { - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.True(t, result) }) t.Run("posture check exists but is not linked to any policy", func(t *testing.T) { - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckB.ID) require.NoError(t, err) assert.False(t, result) }) t.Run("posture check does not exist", func(t *testing.T) { - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown") + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, "unknown") require.NoError(t, err) assert.False(t, result) }) @@ -513,10 +516,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupB"} policy.Rules[0].Destinations = []string{"groupA"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to update policy") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.True(t, result) }) @@ -524,10 +527,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) { policy.Rules[0].Sources = []string{"groupA"} policy.Rules[0].Destinations = []string{"groupB"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to update policy") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.True(t, result) }) @@ -537,7 +540,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { err = manager.Store.SaveGroup(context.Background(), LockingStrengthUpdate, groupA) require.NoError(t, err, "failed to save groups") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.False(t, result) }) @@ -545,10 +548,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) { t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) { policy.Rules[0].Sources = []string{"nonExistentGroup"} policy.Rules[0].Destinations = []string{"nonExistentGroup"} - _, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy) + _, err = manager.SavePolicy(context.Background(), accountID, adminUserID, policy) require.NoError(t, err, "failed to update policy") - result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID) + result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, accountID, postureCheckA.ID) require.NoError(t, err) assert.False(t, result) }) diff --git a/management/server/route_test.go b/management/server/route_test.go index 108f791e02c..41a8a03ae1d 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -5,9 +5,11 @@ import ( "fmt" "net" "net/netip" + "strings" "testing" "time" + "github.com/netbirdio/netbird/management/server/status" "github.com/rs/xid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -427,22 +429,22 @@ func TestCreateRoute(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Errorf("failed to init testing account: %s", err) } if testCase.createInitRoute { - groupAll, errInit := account.GetGroupAll() + groupAll, errInit := am.Store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) + + _, errInit = am.CreateRoute(context.Background(), accountID, existingNetwork, 1, nil, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{}, true, userID, false) require.NoError(t, errInit) - _, errInit = am.CreateRoute(context.Background(), account.Id, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) + _, errInit = am.CreateRoute(context.Background(), accountID, netip.Prefix{}, 3, existingDomains, "", []string{routeGroup3, routeGroup4}, "", existingRouteID, false, 1000, []string{groupAll.ID}, []string{groupAll.ID}, true, userID, false) require.NoError(t, errInit) } - outRoute, err := am.CreateRoute(context.Background(), account.Id, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) - + outRoute, err := am.CreateRoute(context.Background(), accountID, testCase.inputArgs.network, testCase.inputArgs.networkType, testCase.inputArgs.domains, testCase.inputArgs.peerKey, testCase.inputArgs.peerGroupIDs, testCase.inputArgs.description, testCase.inputArgs.netID, testCase.inputArgs.masquerade, testCase.inputArgs.metric, testCase.inputArgs.groups, testCase.inputArgs.accessControlGroups, testCase.inputArgs.enabled, userID, testCase.inputArgs.keepRoute) testCase.errFunc(t, err) if !testCase.shouldCreate { @@ -451,6 +453,7 @@ func TestCreateRoute(t *testing.T) { // assign generated ID testCase.expectedRoute.ID = outRoute.ID + testCase.expectedRoute.AccountID = accountID if !testCase.expectedRoute.IsEqual(outRoute) { t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", outRoute, testCase.expectedRoute) @@ -917,14 +920,15 @@ func TestSaveRoute(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } if testCase.createInitRoute { - account.Routes["initRoute"] = &route.Route{ + initRoute := &route.Route{ ID: "initRoute", + AccountID: accountID, Network: existingNetwork, NetID: existingRouteID, NetworkType: route.IPv4Network, @@ -935,14 +939,13 @@ func TestSaveRoute(t *testing.T) { Enabled: true, Groups: []string{routeGroup1}, } + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, initRoute) + require.NoError(t, err, "failed to save init route") } - account.Routes[testCase.existingRoute.ID] = testCase.existingRoute - - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("account should be saved") - } + testCase.existingRoute.AccountID = accountID + err = am.Store.SaveRoute(context.Background(), LockingStrengthUpdate, testCase.existingRoute) + require.NoError(t, err, "failed to save existing route") var routeToSave *route.Route @@ -977,7 +980,7 @@ func TestSaveRoute(t *testing.T) { } } - err = am.SaveRoute(context.Background(), account.Id, userID, routeToSave) + err = am.SaveRoute(context.Background(), accountID, userID, routeToSave) testCase.errFunc(t, err) @@ -985,14 +988,10 @@ func TestSaveRoute(t *testing.T) { return } - account, err = am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) - } - - savedRoute, saved := account.Routes[testCase.expectedRoute.ID] - require.True(t, saved) + savedRoute, err := am.GetRoute(context.Background(), accountID, testCase.existingRoute.ID, userID) + require.NoError(t, err, "failed to get saved route") + testCase.expectedRoute.AccountID = accountID if !testCase.expectedRoute.IsEqual(savedRoute) { t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", savedRoute, testCase.expectedRoute) } @@ -1001,50 +1000,48 @@ func TestSaveRoute(t *testing.T) { } func TestDeleteRoute(t *testing.T) { - testingRoute := &route.Route{ - ID: "testingRoute", - Network: netip.MustParsePrefix("192.168.0.0/16"), - Domains: domain.List{"domain1", "domain2"}, - KeepRoute: true, - NetworkType: route.IPv4Network, - Peer: peer1Key, - Description: "super", - Masquerade: false, - Metric: 9999, - Enabled: true, - } - am, err := createRouterManager(t) if err != nil { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } - account.Routes[testingRoute.ID] = testingRoute + err = am.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ + ID: "GroupA", + AccountID: accountID, + Name: "GroupA", + }) + require.NoError(t, err, "failed to save group") - err = am.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Error("failed to save account") + testingRoute := &route.Route{ + Network: netip.MustParsePrefix("192.168.0.0/16"), + NetID: route.NetID("12345678901234567890qw"), + Groups: []string{"GroupA"}, + KeepRoute: true, + NetworkType: route.IPv4Network, + Peer: peer1ID, + Description: "super", + Masquerade: false, + Metric: 9999, + Enabled: true, } + createdRoute, err := am.CreateRoute(context.Background(), accountID, testingRoute.Network, testingRoute.NetworkType, testingRoute.Domains, peer1ID, []string{}, testingRoute.Description, testingRoute.NetID, testingRoute.Masquerade, testingRoute.Metric, testingRoute.Groups, testingRoute.AccessControlGroups, true, userID, testingRoute.KeepRoute) + require.NoError(t, err, "failed to create route") - err = am.DeleteRoute(context.Background(), account.Id, testingRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, createdRoute.ID, userID) if err != nil { t.Error("deleting route failed with error: ", err) } - savedAccount, err := am.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Error("failed to retrieve saved account with error: ", err) - } - - _, found := savedAccount.Routes[testingRoute.ID] - if found { - t.Error("route shouldn't be found after delete") - } + _, err = am.GetRoute(context.Background(), accountID, testingRoute.ID, userID) + require.NotNil(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, status.NotFound, sErr.Type()) } func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { @@ -1066,16 +1063,14 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) - if err != nil { - t.Error("failed to init testing account") - } + accountID, err := initTestRouteAccount(t, am) + require.NoError(t, err, "failed to init testing account") newAccountRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - newRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) + newRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, baseRoute.Enabled, userID, baseRoute.KeepRoute) require.NoError(t, err) require.Equal(t, newRoute.Enabled, true) @@ -1091,7 +1086,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") - groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, account.Id) + groups, err := am.Store.GetAccountGroups(context.Background(), LockingStrengthShare, accountID) require.NoError(t, err) var groupHA1, groupHA2 *nbgroup.Group for _, group := range groups { @@ -1103,21 +1098,21 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { } } - err = am.GroupDeletePeer(context.Background(), account.Id, groupHA1.ID, peer2ID) + err = am.GroupDeletePeer(context.Background(), accountID, groupHA1.ID, peer2ID) require.NoError(t, err) peer2RoutesAfterDelete, err := am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 2, "after peer deletion group should have 2 client routes") - err = am.GroupDeletePeer(context.Background(), account.Id, groupHA2.ID, peer4ID) + err = am.GroupDeletePeer(context.Background(), accountID, groupHA2.ID, peer4ID) require.NoError(t, err) peer2RoutesAfterDelete, err = am.GetNetworkMap(context.Background(), peer2ID) require.NoError(t, err) assert.Len(t, peer2RoutesAfterDelete.Routes, 1, "after peer deletion group should have only 1 route") - err = am.GroupAddPeer(context.Background(), account.Id, groupHA2.ID, peer4ID) + err = am.GroupAddPeer(context.Background(), accountID, groupHA2.ID, peer4ID) require.NoError(t, err) peer1RoutesAfterAdd, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1128,7 +1123,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) { require.NoError(t, err) assert.Len(t, peer2RoutesAfterAdd.Routes, 2, "HA route should have 2 client routes") - err = am.DeleteRoute(context.Background(), account.Id, newRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, newRoute.ID, userID) require.NoError(t, err) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1158,7 +1153,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { t.Error("failed to create account manager") } - account, err := initTestRouteAccount(t, am) + accountID, err := initTestRouteAccount(t, am) if err != nil { t.Error("failed to init testing account") } @@ -1167,7 +1162,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, newAccountRoutes.Routes, 0, "new accounts should have no routes") - createdRoute, err := am.CreateRoute(context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) + createdRoute, err := am.CreateRoute(context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, peer1ID, []string{}, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, baseRoute.AccessControlGroups, false, userID, baseRoute.KeepRoute) require.NoError(t, err) noDisabledRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1181,7 +1176,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { expectedRoute := enabledRoute.Copy() expectedRoute.Peer = peer1Key - err = am.SaveRoute(context.Background(), account.Id, userID, enabledRoute) + err = am.SaveRoute(context.Background(), accountID, userID, enabledRoute) require.NoError(t, err) peer1Routes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1193,7 +1188,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, peer2Routes.Routes, 0, "no routes for peers not in the distribution group") - err = am.GroupAddPeer(context.Background(), account.Id, routeGroup1, peer2ID) + err = am.GroupAddPeer(context.Background(), accountID, routeGroup1, peer2ID) require.NoError(t, err) peer2Routes, err = am.GetNetworkMap(context.Background(), peer2ID) @@ -1206,10 +1201,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { Name: "peer1 group", Peers: []string{peer1ID}, } - err = am.SaveGroup(context.Background(), account.Id, userID, newGroup) + err = am.SaveGroup(context.Background(), accountID, userID, newGroup) require.NoError(t, err) - rules, err := am.ListPolicies(context.Background(), account.Id, "testingUser") + rules, err := am.ListPolicies(context.Background(), accountID, "testingUser") require.NoError(t, err) defaultRule := rules[0] @@ -1218,10 +1213,10 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { newPolicy.Rules[0].Sources = []string{newGroup.ID} newPolicy.Rules[0].Destinations = []string{newGroup.ID} - _, err = am.SavePolicy(context.Background(), account.Id, userID, newPolicy) + _, err = am.SavePolicy(context.Background(), accountID, userID, newPolicy) require.NoError(t, err) - err = am.DeletePolicy(context.Background(), account.Id, defaultRule.ID, userID) + err = am.DeletePolicy(context.Background(), accountID, defaultRule.ID, userID) require.NoError(t, err) peer1GroupRoutes, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1232,7 +1227,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { require.NoError(t, err) require.Len(t, peer2GroupRoutes.Routes, 0, "we should not receive routes for peer2") - err = am.DeleteRoute(context.Background(), account.Id, enabledRoute.ID, userID) + err = am.DeleteRoute(context.Background(), accountID, enabledRoute.ID, userID) require.NoError(t, err) peer1DeletedRoute, err := am.GetNetworkMap(context.Background(), peer1ID) @@ -1266,179 +1261,104 @@ func createRouterStore(t *testing.T) (Store, error) { return store, nil } -func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, error) { +func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (string, error) { t.Helper() accountID := "testingAcc" domain := "example.com" - account := newAccountWithId(context.Background(), accountID, userID, domain) - err := am.Store.SaveAccount(context.Background(), account) + err := newAccountWithId(context.Background(), am.Store, accountID, userID, domain) if err != nil { - return nil, err + return "", err } - ips := account.getTakenIPs() - peer1IP, err := AllocatePeerIP(account.Network.Net, ips) - if err != nil { - return nil, err - } + createPeer := func(peerID, peerKey, peerName, dnsLabel, kernel, core, platform, os string) (*nbpeer.Peer, error) { + ips, err := am.Store.GetTakenIPs(context.Background(), LockingStrengthShare, accountID) + if err != nil { + return nil, err + } - peer1 := &nbpeer.Peer{ - IP: peer1IP, - ID: peer1ID, - Key: peer1Key, - Name: "test-host1@netbird.io", - DNSLabel: "test-host1", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host1@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer1.ID] = peer1 + network, err := am.Store.GetAccountNetwork(context.Background(), LockingStrengthShare, accountID) + if err != nil { + return nil, err + } - ips = account.getTakenIPs() - peer2IP, err := AllocatePeerIP(account.Network.Net, ips) - if err != nil { - return nil, err - } + peerIP, err := AllocatePeerIP(network.Net, ips) + if err != nil { + return nil, err + } - peer2 := &nbpeer.Peer{ - IP: peer2IP, - ID: peer2ID, - Key: peer2Key, - Name: "test-host2@netbird.io", - DNSLabel: "test-host2", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host2@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, + peer := &nbpeer.Peer{ + IP: peerIP, + AccountID: accountID, + ID: peerID, + Key: peerKey, + Name: peerName, + DNSLabel: dnsLabel, + UserID: userID, + Meta: nbpeer.PeerSystemMeta{ + Hostname: peerName, + GoOS: strings.ToLower(kernel), + Kernel: kernel, + Core: core, + Platform: platform, + OS: os, + WtVersion: "development", + UIVersion: "development", + }, + Status: &nbpeer.PeerStatus{}, + } + if err := am.Store.AddPeerToAccount(context.Background(), peer); err != nil { + return nil, err + } + return peer, nil } - account.Peers[peer2.ID] = peer2 - ips = account.getTakenIPs() - peer3IP, err := AllocatePeerIP(account.Network.Net, ips) + // Create peers + peer1, err := createPeer(peer1ID, peer1Key, "test-host1@netbird.io", "test-host1", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer3 := &nbpeer.Peer{ - IP: peer3IP, - ID: peer3ID, - Key: peer3Key, - Name: "test-host3@netbird.io", - DNSLabel: "test-host3", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host3@netbird.io", - GoOS: "darwin", - Kernel: "Darwin", - Core: "13.4.1", - Platform: "arm64", - OS: "darwin", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer3.ID] = peer3 - - ips = account.getTakenIPs() - peer4IP, err := AllocatePeerIP(account.Network.Net, ips) + peer2, err := createPeer(peer2ID, peer2Key, "test-host2@netbird.io", "test-host2", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - - peer4 := &nbpeer.Peer{ - IP: peer4IP, - ID: peer4ID, - Key: peer4Key, - Name: "test-host4@netbird.io", - DNSLabel: "test-host4", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host4@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, - } - account.Peers[peer4.ID] = peer4 - - ips = account.getTakenIPs() - peer5IP, err := AllocatePeerIP(account.Network.Net, ips) + peer3, err := createPeer(peer3ID, peer3Key, "test-host3@netbird.io", "test-host3", "Darwin", "13.4.1", "arm64", "darwin") if err != nil { - return nil, err + return "", err } - - peer5 := &nbpeer.Peer{ - IP: peer5IP, - ID: peer5ID, - Key: peer5Key, - Name: "test-host5@netbird.io", - DNSLabel: "test-host5", - UserID: userID, - Meta: nbpeer.PeerSystemMeta{ - Hostname: "test-host5@netbird.io", - GoOS: "linux", - Kernel: "Linux", - Core: "21.04", - Platform: "x86_64", - OS: "Ubuntu", - WtVersion: "development", - UIVersion: "development", - }, - Status: &nbpeer.PeerStatus{}, + peer4, err := createPeer(peer4ID, peer4Key, "test-host4@netbird.io", "test-host4", "Linux", "21.04", "x86_64", "Ubuntu") + if err != nil { + return "", err } - account.Peers[peer5.ID] = peer5 - - err = am.Store.SaveAccount(context.Background(), account) + peer5, err := createPeer(peer5ID, peer5Key, "test-host5@netbird.io", "test-host5", "Linux", "21.04", "x86_64", "Ubuntu") if err != nil { - return nil, err + return "", err } - groupAll, err := account.GetGroupAll() + + groupAll, err := am.GetGroupByName(context.Background(), "All", accountID) if err != nil { - return nil, err + return "", err } + err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer1ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer2ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer3ID) if err != nil { - return nil, err + return "", err } err = am.GroupAddPeer(context.Background(), accountID, groupAll.ID, peer4ID) if err != nil { - return nil, err + return "", err } - newGroup := []*nbgroup.Group{ + newGroups := []*nbgroup.Group{ { ID: routeGroup1, Name: routeGroup1, @@ -1470,15 +1390,12 @@ func initTestRouteAccount(t *testing.T, am *DefaultAccountManager) (*Account, er Peers: []string{peer1.ID, peer4.ID}, }, } - - for _, group := range newGroup { - err = am.SaveGroup(context.Background(), accountID, userID, group) - if err != nil { - return nil, err - } + err = am.SaveGroups(context.Background(), accountID, userID, newGroups) + if err != nil { + return "", err } - return am.Store.GetAccount(context.Background(), account.Id) + return accountID, nil } func TestAccount_getPeersRoutesFirewall(t *testing.T) { @@ -1782,10 +1699,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) { manager, err := createRouterManager(t) require.NoError(t, err, "failed to create account manager") - account, err := initTestRouteAccount(t, manager) + accountID, err := initTestRouteAccount(t, manager) require.NoError(t, err, "failed to init testing account") - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{ { ID: "groupA", Name: "GroupA", @@ -1831,7 +1748,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() _, err := manager.CreateRoute( - context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.Groups, []string{}, true, userID, route.KeepRoute, ) @@ -1867,7 +1784,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() _, err := manager.CreateRoute( - context.Background(), account.Id, route.Network, route.NetworkType, route.Domains, route.Peer, + context.Background(), accountID, route.Network, route.NetworkType, route.Domains, route.Peer, route.PeerGroups, route.Description, route.NetID, route.Masquerade, route.Metric, route.Groups, []string{}, true, userID, route.KeepRoute, ) @@ -1903,7 +1820,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { }() newRoute, err := manager.CreateRoute( - context.Background(), account.Id, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, + context.Background(), accountID, baseRoute.Network, baseRoute.NetworkType, baseRoute.Domains, baseRoute.Peer, baseRoute.PeerGroups, baseRoute.Description, baseRoute.NetID, baseRoute.Masquerade, baseRoute.Metric, baseRoute.Groups, []string{}, true, userID, baseRoute.KeepRoute, ) @@ -1927,7 +1844,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute) + err := manager.SaveRoute(context.Background(), accountID, userID, &baseRoute) require.NoError(t, err) select { @@ -1945,7 +1862,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err := manager.DeleteRoute(context.Background(), account.Id, baseRoute.ID, userID) + err := manager.DeleteRoute(context.Background(), accountID, baseRoute.ID, userID) require.NoError(t, err) select { @@ -1969,7 +1886,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Groups: []string{routeGroup1}, } _, err := manager.CreateRoute( - context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) @@ -1981,7 +1898,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "groupB", Name: "GroupB", Peers: []string{peer1ID}, @@ -2009,7 +1926,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { Groups: []string{"groupC"}, } _, err := manager.CreateRoute( - context.Background(), account.Id, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, + context.Background(), accountID, newRoute.Network, newRoute.NetworkType, newRoute.Domains, newRoute.Peer, newRoute.PeerGroups, newRoute.Description, newRoute.NetID, newRoute.Masquerade, newRoute.Metric, newRoute.Groups, []string{}, true, userID, newRoute.KeepRoute, ) @@ -2021,7 +1938,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { close(done) }() - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "groupC", Name: "GroupC", Peers: []string{peer1ID}, diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index ea239ec0c63..4ef765a51b7 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -25,12 +25,12 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{ + err = manager.SaveGroups(context.Background(), accountID, userID, []*nbgroup.Group{ { ID: "group_1", Name: "group_name_1", @@ -49,7 +49,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { expiresIn := time.Hour keyName := "my-test-key" - key, err := manager.CreateSetupKey(context.Background(), account.Id, keyName, SetupKeyReusable, expiresIn, []string{}, + key, err := manager.CreateSetupKey(context.Background(), accountID, keyName, SetupKeyReusable, expiresIn, []string{}, SetupKeyUnlimitedUsage, userID, false) if err != nil { t.Fatal(err) @@ -58,7 +58,7 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { autoGroups := []string{"group_1", "group_2"} newKeyName := "my-new-test-key" revoked := true - newKey, err := manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + newKey, err := manager.SaveSetupKey(context.Background(), accountID, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -72,22 +72,22 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) { key.Id, time.Now().UTC(), autoGroups, true) // check the corresponding events that should have been generated - ev := getEvent(t, account.Id, manager, activity.SetupKeyRevoked) + ev := getEvent(t, accountID, manager, activity.SetupKeyRevoked) assert.NotNil(t, ev) - assert.Equal(t, account.Id, ev.AccountID) + assert.Equal(t, accountID, ev.AccountID) assert.Equal(t, newKeyName, ev.Meta["name"]) assert.Equal(t, fmt.Sprint(key.Type), fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) assert.Equal(t, userID, ev.InitiatorID) assert.Equal(t, key.Id, ev.TargetID) - groupAll, err := account.GetGroupAll() + groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID) assert.NoError(t, err) // saving setup key with All group assigned to auto groups should return error autoGroups = append(autoGroups, groupAll.ID) - _, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{ + _, err = manager.SaveSetupKey(context.Background(), accountID, &SetupKey{ Id: key.Id, Name: newKeyName, Revoked: revoked, @@ -103,12 +103,12 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -117,7 +117,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, @@ -126,7 +126,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { t.Fatal(err) } - groupAll, err := account.GetGroupAll() + groupAll, err := manager.GetGroupByName(context.Background(), "All", accountID) assert.NoError(t, err) type testCase struct { @@ -170,7 +170,7 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { for _, tCase := range []testCase{testCase1, testCase2, testCase3} { t.Run(tCase.name, func(t *testing.T) { - key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn, + key, err := manager.CreateSetupKey(context.Background(), accountID, tCase.expectedKeyName, SetupKeyReusable, expiresIn, tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false) if tCase.expectedFailure { @@ -189,10 +189,10 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) { tCase.expectedUpdatedAt, tCase.expectedGroups, false) // check the corresponding events that should have been generated - ev := getEvent(t, account.Id, manager, activity.SetupKeyCreated) + ev := getEvent(t, accountID, manager, activity.SetupKeyCreated) assert.NotNil(t, ev) - assert.Equal(t, account.Id, ev.AccountID) + assert.Equal(t, accountID, ev.AccountID) assert.Equal(t, tCase.expectedKeyName, ev.Meta["name"]) assert.Equal(t, tCase.expectedType, fmt.Sprint(ev.Meta["type"])) assert.NotEmpty(t, ev.Meta["key"]) @@ -208,12 +208,12 @@ func TestGetSetupKeys(t *testing.T) { } userID := "testingUser" - account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), userID, "") if err != nil { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_1", Name: "group_name_1", Peers: []string{}, @@ -222,7 +222,7 @@ func TestGetSetupKeys(t *testing.T) { t.Fatal(err) } - err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + err = manager.SaveGroup(context.Background(), accountID, userID, &nbgroup.Group{ ID: "group_2", Name: "group_name_2", Peers: []string{}, diff --git a/management/server/sql_store.go b/management/server/sql_store.go index c9fc51c9e70..7f09f5dba6a 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -261,7 +261,7 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *Account) error { return result.Error } - result = tx.Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) + result = tx.Debug().Select(clause.Associations).Delete(account.UsersG, "account_id = ?", account.Id) if result.Error != nil { return result.Error } @@ -332,24 +332,27 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a return nil } -func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error { accountCopy := Account{ - Domain: domain, - DomainCategory: category, - IsDomainPrimaryAccount: isPrimaryDomain, + Domain: domain, + DomainCategory: category, } - fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.Model(&Account{}). - Select(fieldsToUpdate). - Where(idQueryCondition, accountID). - Updates(&accountCopy) + fieldsToUpdate := []string{"domain", "domain_category"} + if isPrimaryDomain != nil { + accountCopy.IsDomainPrimaryAccount = *isPrimaryDomain + fieldsToUpdate = append(fieldsToUpdate, "is_domain_primary_account") + } + + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select(fieldsToUpdate). + Where(idQueryCondition, accountID).Updates(&accountCopy) if result.Error != nil { - return result.Error + log.WithContext(ctx).Errorf("failed to update account domain attributes in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to update account domain attributes in store") } if result.RowsAffected == 0 { - return status.Errorf(status.NotFound, "account %s", accountID) + return status.NewAccountNotFoundError(accountID) } return nil @@ -437,16 +440,6 @@ func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, return nil } -// DeleteHashedPAT2TokenIDIndex is noop in SqlStore -func (s *SqlStore) DeleteHashedPAT2TokenIDIndex(hashedToken string) error { - return nil -} - -// DeleteTokenID2UserIDIndex is noop in SqlStore -func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error { - return nil -} - func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error) { accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) if err != nil { @@ -877,6 +870,17 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking return createdBy, nil } +func (s *SqlStore) GetTotalAccounts(ctx context.Context) (int64, error) { + var count int64 + result := s.db.Model(&Account{}).Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get total accounts from store: %s", result.Error) + return 0, status.Errorf(status.Internal, "failed to get total accounts from store") + } + + return count, nil +} + // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user User @@ -1445,17 +1449,21 @@ func (s *SqlStore) SavePolicy(ctx context.Context, lockStrength LockingStrength, } func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) error { - result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). - Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID) - if err := result.Error; err != nil { + err := s.db.Transaction(func(tx *gorm.DB) error { + result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&PolicyRule{}, "policy_id = ?", policyID) + if result.Error != nil { + return result.Error + } + + return tx.Clauses(clause.Locking{Strength: string(lockStrength)}). + Delete(&Policy{}, accountAndIDQueryCondition, accountID, policyID).Error + }) + if err != nil { log.WithContext(ctx).Errorf("failed to delete policy from store: %s", err) return status.Errorf(status.Internal, "failed to delete policy from store") } - if result.RowsAffected == 0 { - return status.NewPolicyNotFoundError(policyID) - } - return nil } @@ -1711,6 +1719,31 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre return nil } +// SaveAccountSettings stores the account settings in DB. +func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). + Select("*").Where(idQueryCondition, accountID).Updates(&AccountSettings{Settings: settings}) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save account settings to store") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} + +func (s *SqlStore) CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Create(&account) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save new account in store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save new account in store") + } + return nil +} + // GetPATByHashedToken returns a PersonalAccessToken by its hashed token. func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*PersonalAccessToken, error) { var pat PersonalAccessToken diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 0eeb806dbce..0547c16bc41 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -15,6 +15,7 @@ import ( "github.com/google/uuid" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/account" nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/posture" "github.com/stretchr/testify/assert" @@ -68,20 +69,27 @@ func TestSqlite_SaveAccount_Large(t *testing.T) { func runLargeTest(t *testing.T, store Store) { t.Helper() - account := newAccountWithId(context.Background(), "account_id", "testuser", "") - groupALL, err := account.GetGroupAll() - if err != nil { - t.Fatal(err) - } + accountID := "account_id" + + err := newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") + + groupAll, err := store.GetGroupByName(context.Background(), LockingStrengthShare, accountID, "All") + require.NoError(t, err, "failed to get group All") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + assert.NoError(t, err, "failed to save setup key") + const numPerAccount = 6000 for n := 0; n < numPerAccount; n++ { netIP := randomIPv4() - peerID := fmt.Sprintf("%s-peer-%d", account.Id, n) + peerID := fmt.Sprintf("%s-peer-%d", accountID, n) peer := &nbpeer.Peer{ ID: peerID, + AccountID: accountID, Key: peerID, IP: netIP, Name: peerID, @@ -90,16 +98,21 @@ func runLargeTest(t *testing.T, store Store) { Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, SSHEnabled: false, } - account.Peers[peerID] = peer - group, _ := account.GetGroupAll() - group.Peers = append(group.Peers, peerID) - user := &User{ - Id: fmt.Sprintf("%s-user-%d", account.Id, n), - AccountID: account.Id, - } - account.Users[user.Id] = user + err = store.AddPeerToAccount(context.Background(), peer) + require.NoError(t, err, "failed to add peer") + + err = store.AddPeerToAllGroup(context.Background(), accountID, peerID) + require.NoError(t, err, "failed to add peer to all group") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: fmt.Sprintf("%s-user-%d", accountID, n), + AccountID: accountID, + }) + require.NoError(t, err, "failed to save user") + route := &route2.Route{ ID: route2.ID(fmt.Sprintf("network-id-%d", n)), + AccountID: accountID, Description: "base route", NetID: route2.NetID(fmt.Sprintf("network-id-%d", n)), Network: netip.MustParsePrefix(netIP.String() + "/24"), @@ -107,22 +120,24 @@ func runLargeTest(t *testing.T, store Store) { Metric: 9999, Masquerade: false, Enabled: true, - Groups: []string{groupALL.ID}, + Groups: []string{groupAll.ID}, } - account.Routes[route.ID] = route + err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route) + require.NoError(t, err, "failed to save route") - group = &nbgroup.Group{ + group := &nbgroup.Group{ ID: fmt.Sprintf("group-id-%d", n), - AccountID: account.Id, + AccountID: accountID, Name: fmt.Sprintf("group-id-%d", n), Issued: "api", Peers: nil, } - account.Groups[group.ID] = group + err = store.SaveGroup(context.Background(), LockingStrengthUpdate, group) + require.NoError(t, err, "failed to save group") nameserver := &nbdns.NameServerGroup{ ID: fmt.Sprintf("nameserver-id-%d", n), - AccountID: account.Id, + AccountID: accountID, Name: fmt.Sprintf("nameserver-id-%d", n), Description: "", NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr(netIP.String()), NSType: nbdns.UDPNameServerType}}, @@ -132,20 +147,20 @@ func runLargeTest(t *testing.T, store Store) { Enabled: false, SearchDomainsEnabled: false, } - account.NameServerGroups[nameserver.ID] = nameserver + err = store.SaveNameServerGroup(context.Background(), LockingStrengthUpdate, nameserver) + require.NoError(t, err, "failed to save nameserver group") - setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey + setupKey, _ = GenerateDefaultSetupKey() + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") } - err = store.SaveAccount(context.Background(), account) + totalAccounts, err := store.GetTotalAccounts(context.Background()) require.NoError(t, err) + require.Equal(t, int64(1), totalAccounts, "expected 1 account") - if len(store.GetAllAccounts(context.Background())) != 1 { - t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") - } - - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -213,41 +228,53 @@ func TestSqlite_SaveAccount(t *testing.T) { t.Setenv("NETBIRD_STORE_ENGINE", string(SqliteStoreEngine)) store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir()) t.Cleanup(cleanUp) - assert.NoError(t, err) + require.NoError(t, err) + + accountID := "account_id" + err = newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") - account := newAccountWithId(context.Background(), "account_id", "testuser", "") setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer", + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + AccountID: accountID, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + accountID2 := "account_id2" + err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "") + require.NoError(t, err, "failed to create account") - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + setupKey, _ = GenerateDefaultSetupKey() + setupKey.AccountID = accountID2 + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer2", + Key: "peerkey2", + AccountID: accountID2, + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") if len(store.GetAllAccounts(context.Background())) != 2 { t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") } - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -295,7 +322,12 @@ func TestSqlite_DeleteAccount(t *testing.T) { Name: "test token", }} - account := newAccountWithId(context.Background(), "account_id", testUserID, "") + err = newAccountWithId(context.Background(), store, "account_id", testUserID, "") + require.NoError(t, err, "failed to create account") + + account, err := store.GetAccount(context.Background(), "account_id") + require.NoError(t, err, "failed to get account") + setupKey, _ := GenerateDefaultSetupKey() account.SetupKeys[setupKey.Key] = setupKey account.Peers["testpeer"] = &nbpeer.Peer{ @@ -685,19 +717,29 @@ func newSqliteStore(t *testing.T) *SqlStore { } func newAccount(store Store, id int) error { - str := fmt.Sprintf("%s-%d", uuid.New().String(), id) - account := newAccountWithId(context.Background(), str, str+"-testuser", "example.com") + accountID := fmt.Sprintf("%s-%d", uuid.New().String(), id) + userID := accountID + "-testuser" + + err := newAccountWithId(context.Background(), store, accountID, userID, "example.com") + if err != nil { + return err + } + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["p"+str] = &nbpeer.Peer{ - Key: "peerkey" + str, + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + if err != nil { + return err + } + + return store.SavePeer(context.Background(), LockingStrengthUpdate, accountID, &nbpeer.Peer{ + ID: "p" + accountID, + Key: accountID + "-peerkey", IP: net.IP{127, 0, 0, 1}, Meta: nbpeer.PeerSystemMeta{}, Name: "peer name", Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - - return store.SaveAccount(context.Background(), account) + }) } func TestPostgresql_NewStore(t *testing.T) { @@ -725,39 +767,53 @@ func TestPostgresql_SaveAccount(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) - account := newAccountWithId(context.Background(), "account_id", "testuser", "") + accountID := "account_id" + + err = newAccountWithId(context.Background(), store, accountID, "testuser", "") + require.NoError(t, err, "failed to create account") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer", + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + AccountID: accountID, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") - account2 := newAccountWithId(context.Background(), "account_id2", "testuser2", "") - setupKey, _ = GenerateDefaultSetupKey() - account2.SetupKeys[setupKey.Key] = setupKey - account2.Peers["testpeer2"] = &nbpeer.Peer{ - Key: "peerkey2", - IP: net.IP{127, 0, 0, 2}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name 2", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } + accountID2 := "account_id2" - err = store.SaveAccount(context.Background(), account2) - require.NoError(t, err) + err = newAccountWithId(context.Background(), store, accountID2, "testuser2", "") + require.NoError(t, err, "failed to create account") - if len(store.GetAllAccounts(context.Background())) != 2 { - t.Errorf("expecting 2 Accounts to be stored after SaveAccount()") - } + setupKey, _ = GenerateDefaultSetupKey() + setupKey.AccountID = accountID2 + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") + + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testpeer2", + Key: "peerkey2", + AccountID: accountID2, + IP: net.IP{127, 0, 0, 2}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name 2", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") + + totalAccounts, err := store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(2), totalAccounts, "expecting 2 Accounts to be stored after SaveAccount()") - a, err := store.GetAccount(context.Background(), account.Id) + a, err := store.GetAccount(context.Background(), accountID) if a == nil { t.Errorf("expecting Account to be stored after SaveAccount(): %v", err) } @@ -798,31 +854,41 @@ func TestPostgresql_DeleteAccount(t *testing.T) { t.Cleanup(cleanUp) assert.NoError(t, err) + accountID := "account_id" testUserID := "testuser" - user := NewAdminUser(testUserID) - user.PATs = map[string]*PersonalAccessToken{"testtoken": { - ID: "testtoken", - Name: "test token", - }} - account := newAccountWithId(context.Background(), "account_id", testUserID, "") + err = newAccountWithId(context.Background(), store, accountID, testUserID, "") + require.NoError(t, err, "failed to create account") + setupKey, _ := GenerateDefaultSetupKey() - account.SetupKeys[setupKey.Key] = setupKey - account.Peers["testpeer"] = &nbpeer.Peer{ - Key: "peerkey", - IP: net.IP{127, 0, 0, 1}, - Meta: nbpeer.PeerSystemMeta{}, - Name: "peer name", - Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, - } - account.Users[testUserID] = user + setupKey.AccountID = accountID + err = store.SaveSetupKey(context.Background(), LockingStrengthUpdate, setupKey) + require.NoError(t, err, "failed to save setup key") - err = store.SaveAccount(context.Background(), account) - require.NoError(t, err) + err = store.AddPeerToAccount(context.Background(), &nbpeer.Peer{ + ID: "testingpeer", + AccountID: accountID, + Key: "peerkey", + IP: net.IP{127, 0, 0, 1}, + Meta: nbpeer.PeerSystemMeta{}, + Name: "peer name", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}, + }) + require.NoError(t, err, "failed to save peer") - if len(store.GetAllAccounts(context.Background())) != 1 { - t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") - } + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: "testtoken", + UserID: testUserID, + Name: "test token", + }) + require.NoError(t, err, "failed to save personal access token") + + totalAccounts, err := store.GetTotalAccounts(context.Background()) + require.NoError(t, err, "failed to get total accounts") + require.Equal(t, int64(1), totalAccounts, "expecting 1 Accounts to be stored after SaveAccount()") + + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err, "failed to get account") err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) @@ -1172,7 +1238,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "example.com" category := "public" IsDomainPrimaryAccount := false - err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount) require.NoError(t, err) account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) @@ -1186,7 +1252,7 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "test.com" category := "private" IsDomainPrimaryAccount := true - err = store.UpdateAccountDomainAttributes(context.Background(), accountID, domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, &IsDomainPrimaryAccount) require.NoError(t, err) account, err := store.GetAccount(context.Background(), accountID) require.NoError(t, err) @@ -1200,10 +1266,23 @@ func TestSqlStore_UpdateAccountDomainAttributes(t *testing.T) { domain := "test.com" category := "private" IsDomainPrimaryAccount := true - err = store.UpdateAccountDomainAttributes(context.Background(), "non-existing-account-id", domain, category, IsDomainPrimaryAccount) + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, "non-existing-account-id", domain, category, &IsDomainPrimaryAccount) require.Error(t, err) }) + t.Run("Should update domain and category but skip primary account when isPrimary is nil", func(t *testing.T) { + domain := "test.com" + category := "private" + err = store.UpdateAccountDomainAttributes(context.Background(), LockingStrengthUpdate, accountID, domain, category, nil) + require.NoError(t, err) + + account, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, domain, account.Domain) + require.Equal(t, category, account.DomainCategory) + require.True(t, account.IsDomainPrimaryAccount) + }) + } func TestSqlite_GetGroupByName(t *testing.T) { @@ -2589,3 +2668,157 @@ func TestSqlStore_DeleteRoute(t *testing.T) { require.Error(t, err) require.Nil(t, route) } + +func TestSqlStore_GetAccountSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + expectError bool + }{ + { + name: "retrieve existing account settings", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectError: false, + }, + { + name: "retrieve non-existing account settings", + accountID: "non-existing", + expectError: true, + }, + { + name: "retrieve account settings with empty account ID", + accountID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, tt.accountID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, settings) + } else { + require.NoError(t, err) + require.False(t, settings.RegularUsersViewBlocked) + require.False(t, settings.JWTGroupsEnabled) + require.False(t, settings.GroupsPropagationEnabled) + require.False(t, settings.PeerInactivityExpirationEnabled) + require.False(t, settings.PeerLoginExpirationEnabled) + require.False(t, settings.Extra.PeerApprovalEnabled) + require.Equal(t, time.Duration(86400000000000), settings.PeerLoginExpiration) + require.Equal(t, time.Duration(0), settings.PeerInactivityExpiration) + require.Len(t, settings.Extra.IntegratedValidatorGroups, 0) + } + }) + } +} + +func TestSqlStore_SaveAccountSettings(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + + settings.Extra.IntegratedValidatorGroups = []string{"groupA"} + settings.RegularUsersViewBlocked = true + settings.PeerInactivityExpiration = 30 * time.Minute + err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, accountID, settings) + require.NoError(t, err) + + saveSettings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, accountID) + require.NoError(t, err) + require.Equal(t, settings, saveSettings) +} + +func TestSqlStore_GetTotalAccounts(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + totalAccounts, err := store.GetTotalAccounts(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(1), totalAccounts) +} + +func TestSqlStore_AccountExists(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + t.Run("existing account", func(t *testing.T) { + exists, err := store.AccountExists(context.Background(), LockingStrengthShare, "bf1c8084-ba50-4ce7-9439-34653001fc3b") + require.NoError(t, err) + require.True(t, exists) + }) + + t.Run("non-existing account", func(t *testing.T) { + exists, err := store.AccountExists(context.Background(), LockingStrengthShare, "non-existing") + require.NoError(t, err) + require.False(t, exists) + }) +} + +func TestSqlStore_CreateAccount(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + acc := &Account{ + Id: "test-account", + CreatedAt: time.Now().UTC(), + Network: NewNetwork(), + CreatedBy: userID, + Domain: "test-domain", + DNSSettings: DNSSettings{ + DisabledManagementGroups: make([]string, 0), + }, + Settings: &Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: DefaultPeerLoginExpiration, + GroupsPropagationEnabled: true, + RegularUsersViewBlocked: true, + PeerInactivityExpirationEnabled: false, + PeerInactivityExpiration: DefaultPeerInactivityExpiration, + Extra: &account.ExtraSettings{ + PeerApprovalEnabled: false, + IntegratedValidatorGroups: make([]string, 0), + }, + }, + } + err = store.CreateAccount(context.Background(), LockingStrengthUpdate, acc) + require.NoError(t, err) + + storeAccount, err := store.GetAccount(context.Background(), acc.Id) + require.NoError(t, err) + require.NotNil(t, storeAccount) + + require.Equal(t, acc.Id, storeAccount.Id) + require.Equal(t, acc.CreatedBy, storeAccount.CreatedBy) + require.Equal(t, acc.Domain, storeAccount.Domain) + require.WithinDuration(t, acc.CreatedAt, storeAccount.CreatedAt, time.Second) + require.Equal(t, acc.DNSSettings.DisabledManagementGroups, storeAccount.DNSSettings.DisabledManagementGroups) + + require.NotNil(t, storeAccount.Settings) + require.Equal(t, acc.Settings.PeerLoginExpirationEnabled, storeAccount.Settings.PeerLoginExpirationEnabled) + require.Equal(t, acc.Settings.PeerLoginExpiration, storeAccount.Settings.PeerLoginExpiration) + require.Equal(t, acc.Settings.GroupsPropagationEnabled, storeAccount.Settings.GroupsPropagationEnabled) + require.Equal(t, acc.Settings.RegularUsersViewBlocked, storeAccount.Settings.RegularUsersViewBlocked) + require.Equal(t, acc.Settings.PeerInactivityExpirationEnabled, storeAccount.Settings.PeerInactivityExpirationEnabled) + require.Equal(t, acc.Settings.PeerInactivityExpiration, storeAccount.Settings.PeerInactivityExpiration) + + require.NotNil(t, storeAccount.Settings.Extra) + require.Equal(t, acc.Settings.Extra.PeerApprovalEnabled, storeAccount.Settings.Extra.PeerApprovalEnabled) + require.Equal(t, acc.Settings.Extra.IntegratedValidatorGroups, storeAccount.Settings.Extra.IntegratedValidatorGroups) +} diff --git a/management/server/store.go b/management/server/store.go index 440caad75dc..7ecad3208a9 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -58,10 +58,13 @@ type Store interface { GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) + GetTotalAccounts(ctx context.Context) (int64, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error - UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain *bool) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *DNSSettings) error + SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error + CreateAccount(ctx context.Context, lockStrength LockingStrength, account *Account) error GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) @@ -71,8 +74,6 @@ type Store interface { SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error DeleteUser(ctx context.Context, lockStrength LockingStrength, accountID, userID string) error GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) - DeleteHashedPAT2TokenIDIndex(hashedToken string) error - DeleteTokenID2UserIDIndex(tokenID string) error GetPATByID(ctx context.Context, lockStrength LockingStrength, userID, patID string) (*PersonalAccessToken, error) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*PersonalAccessToken, error) @@ -83,7 +84,7 @@ type Store interface { GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbgroup.Group, error) GetGroupByID(ctx context.Context, lockStrength LockingStrength, groupID, accountID string) (*nbgroup.Group, error) - GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*nbgroup.Group, error) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*nbgroup.Group, error) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error SaveGroup(ctx context.Context, lockStrength LockingStrength, group *nbgroup.Group) error diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index 54b946b5ab7..64e47ff69fa 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -33,4 +33,7 @@ INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,'0001-01-01 00:00:00+00:00','2024-10-02 17:00:32.528196+02:00','api',0,''); +INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','["cfvprsrlo1hqoo49ohog", "cg3161rlo1hs9cq94gdg", "cg05lnblo1hkg2j514p0"]',0,''); +INSERT INTO policies VALUES('cs1tnh0hhcjnqoiuebf0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Default','This is a default rule that allows connections between all the resources',1,'[]'); +INSERT INTO policy_rules VALUES('cs387mkv2d4bgq41b6n0','cs1tnh0hhcjnqoiuebf0','Default','This is a default rule that allows connections between all the resources',1,'accept','["cs1tnh0hhcjnqoiuebeg"]','["cs1tnh0hhcjnqoiuebeg"]',1,'all',NULL,NULL); INSERT INTO installations VALUES(1,''); diff --git a/management/server/user.go b/management/server/user.go index 1639ec50f21..ac4db48c536 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -953,8 +953,9 @@ func validateUserUpdate(groupsMap map[string]*nbgroup.Group, initiatorUser, oldU return nil } -// GetOrCreateAccountByUser returns an existing account for a given user id or creates a new one if doesn't exist -func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, userID, domain string) (*Account, error) { +// GetOrCreateAccountIDByUser returns the account ID for a given user ID. +// If no account exists for the user, it creates a new one using the specified domain. +func (am *DefaultAccountManager) GetOrCreateAccountIDByUser(ctx context.Context, userID, domain string) (string, error) { start := time.Now() unlock := am.Store.AcquireGlobalLock(ctx) defer unlock() @@ -962,34 +963,39 @@ func (am *DefaultAccountManager) GetOrCreateAccountByUser(ctx context.Context, u lowerDomain := strings.ToLower(domain) - account, err := am.Store.GetAccountByUser(ctx, userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { - account, err = am.newAccount(ctx, userID, lowerDomain) + accountID, err = am.newAccount(ctx, userID, lowerDomain) if err != nil { - return nil, err - } - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err + return "", err } + return accountID, nil } else { // other error - return nil, err + return "", err } } - userObj := account.Users[userID] + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + user, err := transaction.GetUserByUserID(ctx, LockingStrengthShare, userID) + if err != nil { + return err + } - if lowerDomain != "" && account.Domain != lowerDomain && userObj.Role == UserRoleOwner { - account.Domain = lowerDomain - err = am.Store.SaveAccount(ctx, account) + accDomain, accCategory, err := transaction.GetAccountDomainAndCategory(ctx, LockingStrengthUpdate, accountID) if err != nil { - return nil, status.Errorf(status.Internal, "failed updating account with domain") + return err } - } - return account, nil + if lowerDomain != "" && accDomain != lowerDomain && user.Role == UserRoleOwner { + return transaction.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, lowerDomain, accCategory, nil) + } + + return nil + }) + + return accountID, err } // GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return diff --git a/management/server/user_test.go b/management/server/user_test.go index 2f8c1bf705b..cd43aab319e 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -43,37 +43,34 @@ const ( func TestUser_CreatePAT_ForSameUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, eventStore: &activity.InMemoryEventStore{}, } - pat, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) + newPAT, err := am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockTokenName, mockExpiresIn) if err != nil { t.Fatalf("Error when adding PAT to user: %s", err) } - assert.Equal(t, pat.CreatedBy, mockUserID) + assert.Equal(t, newPAT.CreatedBy, mockUserID) - tokenID, err := am.Store.GetTokenIDByHashedToken(context.Background(), pat.HashedToken) + pat, err := am.Store.GetPATByHashedToken(context.Background(), LockingStrengthShare, newPAT.HashedToken) if err != nil { t.Fatalf("Error when getting token ID by hashed token: %s", err) } - if tokenID == "" { + if pat.ID == "" { t.Fatal("GetTokenIDByHashedToken failed after adding PAT") } - assert.Equal(t, pat.ID, tokenID) + assert.Equal(t, newPAT.ID, pat.ID) - user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, tokenID) + user, err := am.Store.GetUserByPATID(context.Background(), LockingStrengthShare, pat.ID) if err != nil { t.Fatalf("Error when getting user by token ID: %s", err) } @@ -84,15 +81,16 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) { func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockTargetUserId, + AccountID: mockAccountID, IsServiceUser: false, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -106,15 +104,16 @@ func TestUser_CreatePAT_ForDifferentUser(t *testing.T) { func TestUser_CreatePAT_ForServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockTargetUserId] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockTargetUserId, + AccountID: mockAccountID, IsServiceUser: true, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -132,12 +131,9 @@ func TestUser_CreatePAT_ForServiceUser(t *testing.T) { func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -151,12 +147,9 @@ func TestUser_CreatePAT_WithWrongExpiration(t *testing.T) { func TestUser_CreatePAT_WithEmptyName(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -164,26 +157,22 @@ func TestUser_CreatePAT_WithEmptyName(t *testing.T) { } _, err = am.CreatePAT(context.Background(), mockAccountID, mockUserID, mockUserID, mockEmptyTokenName, mockExpiresIn) - assert.Errorf(t, err, "Wrong expiration should thorw error") + assert.Errorf(t, err, "Wrong expiration should throw error") } func TestUser_DeletePAT(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -195,7 +184,7 @@ func TestUser_DeletePAT(t *testing.T) { t.Fatalf("Error when adding PAT to user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) if err != nil { t.Fatalf("Error when getting account: %s", err) } @@ -206,21 +195,16 @@ func TestUser_DeletePAT(t *testing.T) { func TestUser_GetPAT(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -239,25 +223,23 @@ func TestUser_GetPAT(t *testing.T) { func TestUser_GetAllPATs(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockUserID] = &User{ - Id: mockUserID, - AccountID: mockAccountID, - PATs: map[string]*PersonalAccessToken{ - mockTokenID1: { - ID: mockTokenID1, - HashedToken: mockToken1, - }, - mockTokenID2: { - ID: mockTokenID2, - HashedToken: mockToken2, - }, - }, - } - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID1, + UserID: mockUserID, + HashedToken: mockToken1, + }) + assert.NoError(t, err, "failed to create PAT") + + err = store.SavePAT(context.Background(), LockingStrengthUpdate, &PersonalAccessToken{ + ID: mockTokenID2, + UserID: mockUserID, + HashedToken: mockToken2, + }) + assert.NoError(t, err, "failed to create PAT") am := DefaultAccountManager{ Store: store, @@ -342,12 +324,9 @@ func validateStruct(s interface{}) (err error) { func TestUser_CreateServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -359,7 +338,7 @@ func TestUser_CreateServiceUser(t *testing.T) { t.Fatalf("Error when creating service user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) assert.Equal(t, 2, len(account.Users)) @@ -383,12 +362,9 @@ func TestUser_CreateServiceUser(t *testing.T) { func TestUser_CreateUser_ServiceUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -406,7 +382,7 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { t.Fatalf("Error when creating user: %s", err) } - account, err = store.GetAccount(context.Background(), mockAccountID) + account, err := store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) assert.True(t, user.IsServiceUser) @@ -425,12 +401,9 @@ func TestUser_CreateUser_ServiceUser(t *testing.T) { func TestUser_CreateUser_RegularUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -450,12 +423,9 @@ func TestUser_CreateUser_RegularUser(t *testing.T) { func TestUser_InviteNewUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -549,13 +519,13 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = tt.serviceUser - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + tt.serviceUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, tt.serviceUser) + assert.NoError(t, err, "failed to create service user") am := DefaultAccountManager{ Store: store, @@ -582,12 +552,9 @@ func TestUser_DeleteUser_ServiceUser(t *testing.T) { func TestUser_DeleteUser_SelfDelete(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -603,39 +570,38 @@ func TestUser_DeleteUser_SelfDelete(t *testing.T) { func TestUser_DeleteUser_regularUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - - targetId := "user2" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: true, - ServiceUserName: "user2username", - } - targetId = "user3" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - } - targetId = "user4" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedIntegration, - } - targetId = "user5" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: "user2", + AccountID: mockAccountID, + IsServiceUser: true, + ServiceUserName: "user2username", + }, + { + Id: "user3", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user4", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedIntegration, + }, + { + Id: "user5", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleOwner, + }, + }) + assert.NoError(t, err, "failed to save users") am := DefaultAccountManager{ Store: store, @@ -685,61 +651,64 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { func TestUser_DeleteUser_RegularUsers(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - targetId := "user2" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: true, - ServiceUserName: "user2username", - } - targetId = "user3" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - } - targetId = "user4" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedIntegration, - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") - targetId = "user5" - account.Users[targetId] = &User{ - Id: targetId, - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleOwner, - } - account.Users["user6"] = &User{ - Id: "user6", - IsServiceUser: false, - Issued: UserIssuedAPI, - } - account.Users["user7"] = &User{ - Id: "user7", - IsServiceUser: false, - Issued: UserIssuedAPI, - } - account.Users["user8"] = &User{ - Id: "user8", - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, - } - account.Users["user9"] = &User{ - Id: "user9", - IsServiceUser: false, - Issued: UserIssuedAPI, - Role: UserRoleAdmin, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err = store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{ + { + Id: "user2", + AccountID: mockAccountID, + IsServiceUser: true, + ServiceUserName: "user2username", + }, + { + Id: "user3", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user4", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedIntegration, + }, + { + Id: "user5", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleOwner, + }, + { + Id: "user6", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user7", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + }, + { + Id: "user8", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + }, + { + Id: "user9", + AccountID: mockAccountID, + IsServiceUser: false, + Issued: UserIssuedAPI, + Role: UserRoleAdmin, + }, + }) + assert.NoError(t, err) am := DefaultAccountManager{ Store: store, @@ -816,7 +785,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { assert.NoError(t, err) } - acc, err := am.Store.GetAccount(context.Background(), account.Id) + acc, err := am.Store.GetAccount(context.Background(), mockAccountID) assert.NoError(t, err) for _, id := range tc.expectedDeleted { @@ -836,12 +805,9 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { func TestDefaultAccountManager_GetUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") am := DefaultAccountManager{ Store: store, @@ -865,14 +831,19 @@ func TestDefaultAccountManager_GetUser(t *testing.T) { func TestDefaultAccountManager_ListUsers(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewRegularUser("normal_user1") - account.Users["normal_user2"] = NewRegularUser("normal_user2") - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + newUser := NewRegularUser("normal_user1") + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") + + newUser = NewRegularUser("normal_user2") + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -946,15 +917,25 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { store := newStore(t) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users["normal_user1"] = NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) - account.Settings.RegularUsersViewBlocked = testCase.limitedViewSettings - delete(account.Users, mockUserID) - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + newUser := NewUser("normal_user1", testCase.role, false, false, "", []string{}, UserIssuedAPI) + newUser.AccountID = mockAccountID + err = store.SaveUser(context.Background(), LockingStrengthUpdate, newUser) + assert.NoError(t, err, "failed to create user") + + settings, err := store.GetAccountSettings(context.Background(), LockingStrengthShare, mockAccountID) + assert.NoError(t, err, "failed to get account settings") + + settings.RegularUsersViewBlocked = testCase.limitedViewSettings + + err = store.SaveAccountSettings(context.Background(), LockingStrengthUpdate, mockAccountID, settings) + assert.NoError(t, err, "failed to save account settings") + + err = store.DeleteUser(context.Background(), LockingStrengthUpdate, mockAccountID, mockUserID) + assert.NoError(t, err, "failed to delete user") am := DefaultAccountManager{ Store: store, @@ -968,7 +949,7 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { assert.Equal(t, 1, len(users)) - userInfo, _ := users[0].ToUserInfo(nil, account.Settings) + userInfo, _ := users[0].ToUserInfo(nil, settings) assert.Equal(t, testCase.expectedDashboardPermissions, userInfo.Permissions.DashboardView) }) } @@ -978,22 +959,21 @@ func TestDefaultAccountManager_ListUsers_DashboardPermissions(t *testing.T) { func TestDefaultAccountManager_ExternalCache(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - externalUser := &User{ - Id: "externalUser", - Role: UserRoleUser, - Issued: UserIssuedIntegration, + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ + Id: "externalUser", + AccountID: mockAccountID, + Role: UserRoleUser, + Issued: UserIssuedIntegration, IntegrationReference: integration_reference.IntegrationReference{ ID: 1, IntegrationType: "external", }, - } - account.Users[externalUser.Id] = externalUser - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1013,6 +993,10 @@ func TestDefaultAccountManager_ExternalCache(t *testing.T) { assert.NoError(t, err) cacheManager := am.GetExternalCacheManager() + + externalUser, err := store.GetUserByUserID(context.Background(), LockingStrengthShare, "externalUser") + assert.NoError(t, err, "failed to get user") + cacheKey := externalUser.IntegrationReference.CacheKey(mockAccountID, externalUser.Id) err = cacheManager.Set(context.Background(), cacheKey, &idp.UserData{ID: externalUser.Id, Name: "Test User", Email: "user@example.com"}) assert.NoError(t, err) @@ -1042,17 +1026,17 @@ func TestUser_IsAdmin(t *testing.T) { func TestUser_GetUsersFromAccount_ForAdmin(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockServiceUserID, + AccountID: mockAccountID, Role: "user", IsServiceUser: true, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1071,17 +1055,16 @@ func TestUser_GetUsersFromAccount_ForUser(t *testing.T) { store := newStore(t) defer store.Close(context.Background()) - account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "") - account.Users[mockServiceUserID] = &User{ + err := newAccountWithId(context.Background(), store, mockAccountID, mockUserID, "") + assert.NoError(t, err, "failed to create account") + + err = store.SaveUser(context.Background(), LockingStrengthUpdate, &User{ Id: mockServiceUserID, + AccountID: mockAccountID, Role: "user", IsServiceUser: true, - } - - err := store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatalf("Error when saving account: %s", err) - } + }) + assert.NoError(t, err, "failed to create user") am := DefaultAccountManager{ Store: store, @@ -1240,21 +1223,30 @@ func TestDefaultAccountManager_SaveUser(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // create an account and an admin user - account, err := manager.GetOrCreateAccountByUser(context.Background(), ownerUserID, "netbird.io") + accountID, err := manager.GetOrCreateAccountIDByUser(context.Background(), ownerUserID, "netbird.io") if err != nil { t.Fatal(err) } // create other users - account.Users[regularUserID] = NewRegularUser(regularUserID) - account.Users[adminUserID] = NewAdminUser(adminUserID) - account.Users[serviceUserID] = &User{IsServiceUser: true, Id: serviceUserID, Role: UserRoleAdmin, ServiceUserName: "service"} - err = manager.Store.SaveAccount(context.Background(), account) - if err != nil { - t.Fatal(err) + regularUser := NewRegularUser(regularUserID) + regularUser.AccountID = accountID + + adminUser := NewAdminUser(adminUserID) + adminUser.AccountID = accountID + + serviceUser := &User{ + Id: serviceUserID, + AccountID: accountID, + IsServiceUser: true, + Role: UserRoleAdmin, + ServiceUserName: "service", } - updated, err := manager.SaveUser(context.Background(), account.Id, tc.initiatorID, tc.update) + err = manager.Store.SaveUsers(context.Background(), LockingStrengthUpdate, []*User{regularUser, adminUser, serviceUser}) + assert.NoError(t, err, "failed to save users") + + updated, err := manager.SaveUser(context.Background(), accountID, tc.initiatorID, tc.update) if tc.expectedErr { require.Errorf(t, err, "expecting SaveUser to throw an error") } else {