Skip to content

Commit

Permalink
ListReceivedShares now uses a JOIN
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Geens committed Jan 10, 2025
1 parent d2f446e commit 5391b34
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 51 deletions.
3 changes: 2 additions & 1 deletion share/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ type PublicLink struct {

type ShareState struct {
gorm.Model
Share Share
ShareID uint `gorm:"foreignKey:ShareID;references:ID"` // Define the foreign key field
Share Share // Define the association
// Can not be uid because of lw accs
User string
Synced bool
Expand Down
100 changes: 50 additions & 50 deletions share/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ func (m *mgr) getByKey(ctx context.Context, key *collaboration.ShareKey, checkOw
Where("shared_with_is_group = ?", key.Grantee.Type == provider.GranteeType_GRANTEE_TYPE_GROUP).
Where("share_with = ?", strings.ToLower(shareWith))

if checkOwner {
uid := conversions.FormatUserID(appctx.ContextMustGetUser(ctx).Id)
query = query.
Where("uid_owner = ? or uid_initiator = ?", uid, uid)
uid := conversions.FormatUserID(appctx.ContextMustGetUser(ctx).Id)
// In case the user is not the owner (i.e. in the case of projects)
if checkOwner && owner != uid {
query = query.Where("uid_initiator = ?", uid)
}

res := query.First(&share)
Expand Down Expand Up @@ -319,14 +319,6 @@ func (m *mgr) isProjectAdmin(u *userpb.User, path string) bool {
if g == adminGroup {
// User belongs to the admin group, list all shares for the resource

// TODO: this only works if shares for a single project are requested.
// If shares for multiple projects are requested, then we're not checking if the
// user is an admin for all of those. We can append the query ` or uid_owner=?`
// for all the project owners, which works fine for new reva
// but won't work for revaold since there, we store the uid of the share creator as uid_owner.
// For this to work across the two versions, this change would have to be made in revaold
// but it won't be straightforward as there, the storage provider doesn't return the
// resource owners.
return true
}
}
Expand Down Expand Up @@ -364,65 +356,67 @@ func (m *mgr) ListShares(ctx context.Context, filters []*collaboration.Filter) (
func (m *mgr) ListReceivedShares(ctx context.Context, filters []*collaboration.Filter) ([]*collaboration.ReceivedShare, error) {
user := appctx.ContextMustGetUser(ctx)

query := m.db.Model(&model.Share{}).
Where("orphan = ?", false)
// We need to do this to parse the result
// Normally, GORM would be able to fill in the Share that is referenced in ShareState
// However, in GORM's docs: "Join Preload will loads association data using left join"
// Because we do a RIGHT JOIN, GORM cannot load the data into shareState.Share (in case that ShareState is empty)
// So we load them both separately, and then set ShareState.Share = Share ourselves
var results []struct {
model.ShareState
model.Share
}

query := m.db.Model(&model.ShareState{}).
Select("share_states.*, shares.*").
Joins("RIGHT OUTER JOIN shares ON shares.id = share_states.share_id").
Where("shares.orphan = ?", false)

// Also search by all the groups the user is a member of
innerQuery := m.db.Where("share_with = ? and shared_with_is_group = ?", user.Username, false)
innerQuery := m.db.Where("shares.share_with = ? and shares.shared_with_is_group = ?", user.Username, false)
for _, group := range user.Groups {
innerQuery = innerQuery.Or("share_with = ? and shared_with_is_group = ?", group, true)
innerQuery = innerQuery.Or("shares.share_with = ? and shares.shared_with_is_group = ?", group, true)
}
query = query.Where(innerQuery)

// Append filters
m.appendFiltersToQuery(query, filters)

// Get the shares
var shares []model.Share
res := query.Find(&shares)
// Get the shares + states
res := query.Find(&results)
if res.Error != nil {
return nil, res.Error
}

// Now that we have the shares, we fetch the share state for every share
var receivedShares []*collaboration.ReceivedShare

for _, s := range shares {
shareId := &collaboration.ShareId{
OpaqueId: strconv.FormatUint(uint64(s.ID), 10),
}
shareState, err := m.getShareState(ctx, shareId, user)
if err != nil {
return nil, err
}
// Now we parse everything into the CS3 definition of a CS3ReceivedShare
for _, res := range results {
shareState := res.ShareState
shareState.Share = res.Share
granteeType, _ := m.getUserType(ctx, res.Share.ShareWith)

granteeType, _ := m.getUserType(ctx, s.ShareWith)

receivedShares = append(receivedShares, s.AsCS3ReceivedShare(shareState, granteeType))
receivedShares = append(receivedShares, res.Share.AsCS3ReceivedShare(&shareState, granteeType))
}

return receivedShares, nil
}

func (m *mgr) getShareState(ctx context.Context, shareId *collaboration.ShareId, user *userpb.User) (*model.ShareState, error) {
// shareId *collaboration.ShareId
func (m *mgr) getShareState(ctx context.Context, share *model.Share, user *userpb.User) (*model.ShareState, error) {
var shareState model.ShareState
query := m.db.Model(&shareState).
Where("share_id = ?", shareId.OpaqueId).
Where("share_id = ?", share.ID).
Where("user = ?", user.Username)

res := query.First(&shareState)

if res.RowsAffected == 0 {
shareIdInt, err := strconv.Atoi(shareId.OpaqueId)
if err != nil {
return nil, errors.New("Failed to fetch shareState, and failed to create one (share_id is not an int)")
}
// If no share state has been created yet, we create it now using these defaults
shareState = model.ShareState{
ShareID: uint(shareIdInt),
Hidden: false,
Synced: false,
User: user.Username,
Share: *share,
Hidden: false,
Synced: false,
User: user.Username,
}
// Does not really matter if it fails, next time the user
// lists his shares this will just be called again
Expand All @@ -439,7 +433,7 @@ func (m *mgr) getReceivedByID(ctx context.Context, id *collaboration.ShareId, gt
return nil, err
}

shareState, err := m.getShareState(ctx, id, user)
shareState, err := m.getShareState(ctx, share, user)
if err != nil {
return nil, err
}
Expand All @@ -455,11 +449,7 @@ func (m *mgr) getReceivedByKey(ctx context.Context, key *collaboration.ShareKey,
return nil, err
}

shareId := &collaboration.ShareId{
OpaqueId: strconv.Itoa(int(share.ID)),
}

shareState, err := m.getShareState(ctx, shareId, user)
shareState, err := m.getShareState(ctx, share, user)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -496,18 +486,28 @@ func (m *mgr) UpdateReceivedShare(ctx context.Context, share *collaboration.Rece

user := appctx.ContextMustGetUser(ctx)

rs, err := m.GetReceivedShare(ctx, &collaboration.ShareReference{Spec: &collaboration.ShareReference_Id{Id: share.Share.Id}})
rs, err := m.getReceivedByID(ctx, share.Share.Id, user.Id.Type)
if err != nil {
return nil, err
}

shareId, err := strconv.Atoi(share.Share.Id.OpaqueId)
if err != nil {
return nil, err
}

shareState, err := m.getShareState(ctx, share.Share.Id, user)
shareState, err := m.getShareState(ctx, &model.Share{
ProtoShare: model.ProtoShare{
Model: gorm.Model{
ID: uint(shareId),
},
},
}, user)
if err != nil {
return nil, err
}

// FieldMask determines which parts of the share we actually update
// Right now, only updating the state is supported
for _, path := range fieldMask.Paths {
switch path {
case "state":
Expand Down
1 change: 1 addition & 0 deletions share/sql/sql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ func TestListReceivedShares(t *testing.T) {

if len(receivedShares) != 1 {
t.Errorf("Expected 1 received share, got %d", len(receivedShares))
t.FailNow()
}

if receivedShares[0].Share.Id.OpaqueId != res.Id.OpaqueId {
Expand Down

0 comments on commit 5391b34

Please sign in to comment.