From 7843bfb9e3ae20ddce30a0ab2fed5453e6d9683d Mon Sep 17 00:00:00 2001 From: Jesse Geens Date: Fri, 10 Jan 2025 16:38:56 +0100 Subject: [PATCH] ListReceivedShares now uses a JOIN --- share/model.go | 3 +- share/sql/sql.go | 99 ++++++++++++++++++++++--------------------- share/sql/sql_test.go | 1 + 3 files changed, 53 insertions(+), 50 deletions(-) diff --git a/share/model.go b/share/model.go index 56d4609..ac65af2 100644 --- a/share/model.go +++ b/share/model.go @@ -67,7 +67,8 @@ type PublicLink struct { type ShareState struct { gorm.Model - Share Share + ShareID uint `gorm:"foreignKey:ShareID;references:ID"` // Define the foreign key field + Share Share // Define the association // Can not be uid because of lw accs User string Synced bool diff --git a/share/sql/sql.go b/share/sql/sql.go index dd8d443..81c86f7 100644 --- a/share/sql/sql.go +++ b/share/sql/sql.go @@ -201,10 +201,10 @@ func (m *mgr) getByKey(ctx context.Context, key *collaboration.ShareKey, checkOw Where("shared_with_is_group = ?", key.Grantee.Type == provider.GranteeType_GRANTEE_TYPE_GROUP). Where("share_with = ?", strings.ToLower(shareWith)) - if checkOwner { - uid := conversions.FormatUserID(appctx.ContextMustGetUser(ctx).Id) - query = query. - Where("uid_owner = ? or uid_initiator = ?", uid, uid) + uid := conversions.FormatUserID(appctx.ContextMustGetUser(ctx).Id) + // In case the user is not the owner (i.e. in the case of projects) + if checkOwner && owner != uid { + query = query.Where("uid_initiator = ?", uid) } res := query.First(&share) @@ -319,14 +319,6 @@ func (m *mgr) isProjectAdmin(u *userpb.User, path string) bool { if g == adminGroup { // User belongs to the admin group, list all shares for the resource - // TODO: this only works if shares for a single project are requested. - // If shares for multiple projects are requested, then we're not checking if the - // user is an admin for all of those. We can append the query ` or uid_owner=?` - // for all the project owners, which works fine for new reva - // but won't work for revaold since there, we store the uid of the share creator as uid_owner. - // For this to work across the two versions, this change would have to be made in revaold - // but it won't be straightforward as there, the storage provider doesn't return the - // resource owners. return true } } @@ -364,65 +356,67 @@ func (m *mgr) ListShares(ctx context.Context, filters []*collaboration.Filter) ( func (m *mgr) ListReceivedShares(ctx context.Context, filters []*collaboration.Filter) ([]*collaboration.ReceivedShare, error) { user := appctx.ContextMustGetUser(ctx) - query := m.db.Model(&model.Share{}). - Where("orphan = ?", false) + // We need to do this to parse the result + // Normally, GORM would be able to fill in the Share that is referenced in ShareState + // However, in GORM's docs: "Join Preload will loads association data using left join" + // Because we do a RIGHT JOIN, GORM cannot load the data into shareState.Share (in case that ShareState is empty) + // So we load them both separately, and then set ShareState.Share = Share ourselves + var results []struct { + model.ShareState + model.Share + } + + query := m.db.Model(&model.ShareState{}). + Select("share_states.*, shares.*"). + Joins("RIGHT OUTER JOIN shares ON shares.id = share_states.share_id"). + Where("shares.orphan = ?", false) // Also search by all the groups the user is a member of - innerQuery := m.db.Where("share_with = ? and shared_with_is_group = ?", user.Username, false) + innerQuery := m.db.Where("shares.share_with = ? and shares.shared_with_is_group = ?", user.Username, false) for _, group := range user.Groups { - innerQuery = innerQuery.Or("share_with = ? and shared_with_is_group = ?", group, true) + innerQuery = innerQuery.Or("shares.share_with = ? and shares.shared_with_is_group = ?", group, true) } query = query.Where(innerQuery) // Append filters m.appendFiltersToQuery(query, filters) - // Get the shares - var shares []model.Share - res := query.Find(&shares) + // Get the shares + states + res := query.Find(&results) if res.Error != nil { return nil, res.Error } - // Now that we have the shares, we fetch the share state for every share var receivedShares []*collaboration.ReceivedShare - for _, s := range shares { - shareId := &collaboration.ShareId{ - OpaqueId: strconv.FormatUint(uint64(s.ID), 10), - } - shareState, err := m.getShareState(ctx, shareId, user) - if err != nil { - return nil, err - } + // Now we parse everything into the CS3 definition of a CS3ReceivedShare + for _, res := range results { + shareState := res.ShareState + shareState.Share = res.Share + granteeType, _ := m.getUserType(ctx, res.Share.ShareWith) - granteeType, _ := m.getUserType(ctx, s.ShareWith) - - receivedShares = append(receivedShares, s.AsCS3ReceivedShare(shareState, granteeType)) + receivedShares = append(receivedShares, res.Share.AsCS3ReceivedShare(&shareState, granteeType)) } return receivedShares, nil } -func (m *mgr) getShareState(ctx context.Context, shareId *collaboration.ShareId, user *userpb.User) (*model.ShareState, error) { +// shareId *collaboration.ShareId +func (m *mgr) getShareState(ctx context.Context, share *model.Share, user *userpb.User) (*model.ShareState, error) { var shareState model.ShareState query := m.db.Model(&shareState). - Where("share_id = ?", shareId.OpaqueId). + Where("share_id = ?", share.ID). Where("user = ?", user.Username) res := query.First(&shareState) if res.RowsAffected == 0 { - shareIdInt, err := strconv.Atoi(shareId.OpaqueId) - if err != nil { - return nil, errors.New("Failed to fetch shareState, and failed to create one (share_id is not an int)") - } // If no share state has been created yet, we create it now using these defaults shareState = model.ShareState{ - ShareID: uint(shareIdInt), - Hidden: false, - Synced: false, - User: user.Username, + Share: *share, + Hidden: false, + Synced: false, + User: user.Username, } // Does not really matter if it fails, next time the user // lists his shares this will just be called again @@ -439,7 +433,7 @@ func (m *mgr) getReceivedByID(ctx context.Context, id *collaboration.ShareId, gt return nil, err } - shareState, err := m.getShareState(ctx, id, user) + shareState, err := m.getShareState(ctx, share, user) if err != nil { return nil, err } @@ -455,11 +449,7 @@ func (m *mgr) getReceivedByKey(ctx context.Context, key *collaboration.ShareKey, return nil, err } - shareId := &collaboration.ShareId{ - OpaqueId: strconv.Itoa(int(share.ID)), - } - - shareState, err := m.getShareState(ctx, shareId, user) + shareState, err := m.getShareState(ctx, share, user) if err != nil { return nil, err } @@ -496,12 +486,23 @@ func (m *mgr) UpdateReceivedShare(ctx context.Context, share *collaboration.Rece user := appctx.ContextMustGetUser(ctx) - rs, err := m.GetReceivedShare(ctx, &collaboration.ShareReference{Spec: &collaboration.ShareReference_Id{Id: share.Share.Id}}) + rs, err := m.getReceivedByID(ctx, share.Share.Id, user.Id.Type) + if err != nil { + return nil, err + } + + shareId, err := strconv.Atoi(share.Share.Id.OpaqueId) if err != nil { return nil, err } - shareState, err := m.getShareState(ctx, share.Share.Id, user) + shareState, err := m.getShareState(ctx, &model.Share{ + ProtoShare: model.ProtoShare{ + Model: gorm.Model{ + ID: uint(shareId), + }, + }, + }, user) if err != nil { return nil, err } diff --git a/share/sql/sql_test.go b/share/sql/sql_test.go index c85d5a3..dbe738a 100644 --- a/share/sql/sql_test.go +++ b/share/sql/sql_test.go @@ -363,6 +363,7 @@ func TestListReceivedShares(t *testing.T) { if len(receivedShares) != 1 { t.Errorf("Expected 1 received share, got %d", len(receivedShares)) + t.FailNow() } if receivedShares[0].Share.Id.OpaqueId != res.Id.OpaqueId {