Skip to content

Commit

Permalink
(BIDS-3091) add subscription limit checks (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
LuccaBitfly authored Jun 6, 2024
1 parent f2e9287 commit 0b67cea
Show file tree
Hide file tree
Showing 9 changed files with 222 additions and 87 deletions.
7 changes: 1 addition & 6 deletions backend/pkg/api/data_access/data_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,16 @@ type DataAccessor interface {
ValidatorDashboardRepository
SearchRepository
NetworkRepository
UserRepository

Close()

GetLatestSlot() (uint64, error)
GetLatestExchangeRates() ([]t.EthConversionRate, error)

GetProductSummary() (*t.ProductSummary, error)
// TODO: move to user repository
GetUser(email string) (*t.User, error)
GetUserIdByApiKey(apiKey string) (uint64, error)

GetValidatorsFromSlices(indices []uint64, publicKeys []string) ([]t.VDBValidator, error)

GetUserInfo(id uint64) (*t.UserInfo, error)
GetUserDashboards(userId uint64) (*t.UserDashboardsData, error)
}

type DataAccessService struct {
Expand Down
28 changes: 26 additions & 2 deletions backend/pkg/api/data_access/dummy.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ func (d *DummyService) GetUserInfo(userId uint64) (*t.UserInfo, error) {
return &r, err
}

func (d *DummyService) GetUser(email string) (*t.User, error) {
r := t.User{}
func (d *DummyService) GetUserCredentialInfo(email string) (*t.UserCredentialInfo, error) {
r := t.UserCredentialInfo{}
err := commonFakeData(&r)
return &r, err
}
Expand Down Expand Up @@ -385,3 +385,27 @@ func (d *DummyService) GetSearchValidatorsByGraffiti(ctx context.Context, chainI
err := commonFakeData(&r)
return &r, err
}

func (d *DummyService) GetUserValidatorDashboardCount(userId uint64) (uint64, error) {
r := uint64(0)
err := commonFakeData(&r)
return r, err
}

func (d *DummyService) GetValidatorDashboardGroupCount(dashboardId t.VDBIdPrimary) (uint64, error) {
r := uint64(0)
err := commonFakeData(&r)
return r, err
}

func (d *DummyService) GetValidatorDashboardValidatorsCount(dashboardId t.VDBIdPrimary) (uint64, error) {
r := uint64(0)
err := commonFakeData(&r)
return r, err
}

func (d *DummyService) GetValidatorDashboardPublicIdCount(dashboardId t.VDBIdPrimary) (uint64, error) {
r := uint64(0)
err := commonFakeData(&r)
return r, err
}
86 changes: 84 additions & 2 deletions backend/pkg/api/data_access/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ import (
"github.com/pkg/errors"
)

func (d *DataAccessService) GetUser(email string) (*t.User, error) {
type UserRepository interface {
GetUserCredentialInfo(email string) (*t.UserCredentialInfo, error)
GetUserIdByApiKey(apiKey string) (uint64, error)
GetUserInfo(id uint64) (*t.UserInfo, error)
GetUserDashboards(userId uint64) (*t.UserDashboardsData, error)
GetUserValidatorDashboardCount(userId uint64) (uint64, error)
}

func (d *DataAccessService) GetUserCredentialInfo(email string) (*t.UserCredentialInfo, error) {
// TODO @patrick
result := &t.User{}
result := &t.UserCredentialInfo{}
err := d.userReader.Get(result, `
WITH
latest_and_greatest_sub AS (
Expand Down Expand Up @@ -399,3 +407,77 @@ func (d *DataAccessService) GetProductSummary() (*t.ProductSummary, error) {
},
}, nil
}

func (d *DataAccessService) GetUserDashboards(userId uint64) (*t.UserDashboardsData, error) {
result := &t.UserDashboardsData{}

dbReturn := []struct {
Id uint64 `db:"id"`
Name string `db:"name"`
PublicId sql.NullString `db:"public_id"`
PublicName sql.NullString `db:"public_name"`
SharedGroups sql.NullBool `db:"shared_groups"`
}{}

// Get the validator dashboards including the public ones
err := d.alloyReader.Select(&dbReturn, `
SELECT
uvd.id,
uvd.name,
uvds.public_id,
uvds.name AS public_name,
uvds.shared_groups
FROM users_val_dashboards uvd
LEFT JOIN users_val_dashboards_sharing uvds ON uvd.id = uvds.dashboard_id
WHERE uvd.user_id = $1
`, userId)
if err != nil {
return nil, err
}

// Fill the result
validatorDashboardMap := make(map[uint64]*t.ValidatorDashboard, 0)
for _, row := range dbReturn {
if _, ok := validatorDashboardMap[row.Id]; !ok {
validatorDashboardMap[row.Id] = &t.ValidatorDashboard{
Id: row.Id,
Name: row.Name,
PublicIds: []t.VDBPublicId{},
}
}
if row.PublicId.Valid {
result := t.VDBPublicId{}
result.PublicId = row.PublicId.String
result.Name = row.PublicName.String
result.ShareSettings.ShareGroups = row.SharedGroups.Bool

validatorDashboardMap[row.Id].PublicIds = append(validatorDashboardMap[row.Id].PublicIds, result)
}
}
for _, validatorDashboard := range validatorDashboardMap {
result.ValidatorDashboards = append(result.ValidatorDashboards, *validatorDashboard)
}

// Get the account dashboards
err = d.alloyReader.Select(&result.AccountDashboards, `
SELECT
id,
name
FROM users_acc_dashboards
WHERE user_id = $1
`, userId)
if err != nil {
return nil, err
}

return result, nil
}

func (d *DataAccessService) GetUserValidatorDashboardCount(userId uint64) (uint64, error) {
var count uint64
err := d.alloyReader.Get(&count, `
SELECT COUNT(*) FROM users_val_dashboards
WHERE user_id = $1
`, userId)
return count, err
}
5 changes: 4 additions & 1 deletion backend/pkg/api/data_access/vdb_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ type ValidatorDashboardRepository interface {
CreateValidatorDashboardGroup(dashboardId t.VDBIdPrimary, name string) (*t.VDBPostCreateGroupData, error)
UpdateValidatorDashboardGroup(dashboardId t.VDBIdPrimary, groupId uint64, name string) (*t.VDBPostCreateGroupData, error)
RemoveValidatorDashboardGroup(dashboardId t.VDBIdPrimary, groupId uint64) error

GetValidatorDashboardGroupCount(dashboardId t.VDBIdPrimary) (uint64, error)
GetValidatorDashboardGroupExists(dashboardId t.VDBIdPrimary, groupId uint64) (bool, error)

GetValidatorDashboardExistingValidatorCount(dashboardId t.VDBIdPrimary, validators []t.VDBValidator) (uint64, error)
AddValidatorDashboardValidators(dashboardId t.VDBIdPrimary, groupId uint64, validators []t.VDBValidator) ([]t.VDBPostValidatorsData, error)
AddValidatorDashboardValidatorsByDepositAddress(dashboardId t.VDBIdPrimary, groupId uint64, address string, limit uint64) ([]t.VDBPostValidatorsData, error)
Expand All @@ -33,11 +34,13 @@ type ValidatorDashboardRepository interface {

RemoveValidatorDashboardValidators(dashboardId t.VDBIdPrimary, validators []t.VDBValidator) error
GetValidatorDashboardValidators(dashboardId t.VDBId, groupId int64, cursor string, colSort t.Sort[enums.VDBManageValidatorsColumn], search string, limit uint64) ([]t.VDBManageValidatorsTableRow, *t.Paging, error)
GetValidatorDashboardValidatorsCount(dashboardId t.VDBIdPrimary) (uint64, error)

CreateValidatorDashboardPublicId(dashboardId t.VDBIdPrimary, name string, shareGroups bool) (*t.VDBPublicId, error)
GetValidatorDashboardPublicId(publicDashboardId t.VDBIdPublic) (*t.VDBPublicId, error)
UpdateValidatorDashboardPublicId(publicDashboardId t.VDBIdPublic, name string, shareGroups bool) (*t.VDBPublicId, error)
RemoveValidatorDashboardPublicId(publicDashboardId t.VDBIdPublic) error
GetValidatorDashboardPublicIdCount(dashboardId t.VDBIdPrimary) (uint64, error)

GetValidatorDashboardSlotViz(dashboardId t.VDBId) ([]t.SlotVizEpoch, error)

Expand Down
93 changes: 28 additions & 65 deletions backend/pkg/api/data_access/vdb_management.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,71 +96,6 @@ func (d *DataAccessService) GetValidatorsFromSlices(indices []t.VDBValidator, pu
return result, nil
}

func (d *DataAccessService) GetUserDashboards(userId uint64) (*t.UserDashboardsData, error) {
result := &t.UserDashboardsData{}

dbReturn := []struct {
Id uint64 `db:"id"`
Name string `db:"name"`
PublicId sql.NullString `db:"public_id"`
PublicName sql.NullString `db:"public_name"`
SharedGroups sql.NullBool `db:"shared_groups"`
}{}

// Get the validator dashboards including the public ones
err := d.alloyReader.Select(&dbReturn, `
SELECT
uvd.id,
uvd.name,
uvds.public_id,
uvds.name AS public_name,
uvds.shared_groups
FROM users_val_dashboards uvd
LEFT JOIN users_val_dashboards_sharing uvds ON uvd.id = uvds.dashboard_id
WHERE uvd.user_id = $1
`, userId)
if err != nil {
return nil, err
}

// Fill the result
validatorDashboardMap := make(map[uint64]*t.ValidatorDashboard, 0)
for _, row := range dbReturn {
if _, ok := validatorDashboardMap[row.Id]; !ok {
validatorDashboardMap[row.Id] = &t.ValidatorDashboard{
Id: row.Id,
Name: row.Name,
PublicIds: []t.VDBPublicId{},
}
}
if row.PublicId.Valid {
result := t.VDBPublicId{}
result.PublicId = row.PublicId.String
result.Name = row.PublicName.String
result.ShareSettings.ShareGroups = row.SharedGroups.Bool

validatorDashboardMap[row.Id].PublicIds = append(validatorDashboardMap[row.Id].PublicIds, result)
}
}
for _, validatorDashboard := range validatorDashboardMap {
result.ValidatorDashboards = append(result.ValidatorDashboards, *validatorDashboard)
}

// Get the account dashboards
err = d.alloyReader.Select(&result.AccountDashboards, `
SELECT
id,
name
FROM users_acc_dashboards
WHERE user_id = $1
`, userId)
if err != nil {
return nil, err
}

return result, nil
}

func (d *DataAccessService) CreateValidatorDashboard(userId uint64, name string, network uint64) (*t.VDBPostReturnData, error) {
result := &t.VDBPostReturnData{}

Expand Down Expand Up @@ -488,6 +423,14 @@ func (d *DataAccessService) RemoveValidatorDashboardGroup(dashboardId t.VDBIdPri
return nil
}

func (d *DataAccessService) GetValidatorDashboardGroupCount(dashboardId t.VDBIdPrimary) (uint64, error) {
var count uint64
err := d.alloyReader.Get(&count, `
SELECT COUNT(*) FROM users_val_dashboards_groups WHERE dashboard_id = $1
`, dashboardId)
return count, err
}

func (d *DataAccessService) GetValidatorDashboardValidators(dashboardId t.VDBId, groupId int64, cursor string, colSort t.Sort[enums.VDBManageValidatorsColumn], search string, limit uint64) ([]t.VDBManageValidatorsTableRow, *t.Paging, error) {
// Initialize the cursor
var currentCursor t.ValidatorsCursor
Expand Down Expand Up @@ -878,6 +821,16 @@ func (d *DataAccessService) RemoveValidatorDashboardValidators(dashboardId t.VDB
return err
}

func (d *DataAccessService) GetValidatorDashboardValidatorsCount(dashboardId t.VDBIdPrimary) (uint64, error) {
var count uint64
err := d.alloyReader.Get(&count, `
SELECT COUNT(*)
FROM users_val_dashboards_validators
WHERE dashboard_id = $1
`, dashboardId)
return count, err
}

func (d *DataAccessService) CreateValidatorDashboardPublicId(dashboardId t.VDBIdPrimary, name string, shareGroups bool) (*t.VDBPublicId, error) {
dbReturn := struct {
PublicId string `db:"public_id"`
Expand Down Expand Up @@ -980,3 +933,13 @@ func (d *DataAccessService) RemoveValidatorDashboardPublicId(publicDashboardId t

return err
}

func (d *DataAccessService) GetValidatorDashboardPublicIdCount(dashboardId t.VDBIdPrimary) (uint64, error) {
var count uint64
err := d.alloyReader.Get(&count, `
SELECT COUNT(*)
FROM users_val_dashboards_sharing
WHERE dashboard_id = $1
`, dashboardId)
return count, err
}
20 changes: 15 additions & 5 deletions backend/pkg/api/handlers/auth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package handlers

import (
"context"
"errors"
"net/http"
"strconv"
Expand All @@ -18,19 +19,23 @@ const (
userGroupKey = "user_group"
)

func (h *HandlerService) getUserBySession(r *http.Request) (types.User, error) {
type ctxKet string

const ctxUserIdKey ctxKet = "user_id"

func (h *HandlerService) getUserBySession(r *http.Request) (types.UserCredentialInfo, error) {
authenticated := h.scs.GetBool(r.Context(), authenticatedKey)
if !authenticated {
return types.User{}, newUnauthorizedErr("not authenticated")
return types.UserCredentialInfo{}, newUnauthorizedErr("not authenticated")
}
subscription := h.scs.GetString(r.Context(), subscriptionKey)
userGroup := h.scs.GetString(r.Context(), userGroupKey)
userId, ok := h.scs.Get(r.Context(), userIdKey).(uint64)
if !ok {
return types.User{}, errors.New("error parsind user id from session, not a uint64")
return types.UserCredentialInfo{}, errors.New("error parsind user id from session, not a uint64")
}

return types.User{
return types.UserCredentialInfo{
Id: userId,
ProductId: subscription,
UserGroup: userGroup,
Expand Down Expand Up @@ -91,7 +96,7 @@ func (h *HandlerService) InternalPostLogin(w http.ResponseWriter, r *http.Reques

badCredentialsErr := newUnauthorizedErr("invalid email or password")
// fetch user
user, err := h.dai.GetUser(email)
user, err := h.dai.GetUserCredentialInfo(email)
if err != nil {
if errors.Is(err, dataaccess.ErrNotFound) {
err = badCredentialsErr
Expand Down Expand Up @@ -152,6 +157,11 @@ func (h *HandlerService) GetVDBAuthMiddleware(userIdFunc func(r *http.Request) (
handleErr(w, err)
return
}
// store user id in context
ctx := r.Context()
ctx = context.WithValue(ctx, ctxUserIdKey, userId)
r = r.WithContext(ctx)

dashboard, err := h.dai.GetValidatorDashboardInfo(types.VDBIdPrimary(dashboardId))
if err != nil {
handleErr(w, err)
Expand Down
Loading

0 comments on commit 0b67cea

Please sign in to comment.