diff --git a/api/custom/purge.go b/api/custom/purge.go index 4111d3f7..4d4ff3b9 100644 --- a/api/custom/purge.go +++ b/api/custom/purge.go @@ -1,7 +1,6 @@ package custom import ( - "database/sql" "net/http" "strconv" @@ -9,15 +8,13 @@ import ( "github.com/turt2live/matrix-media-repo/api/_apimeta" "github.com/turt2live/matrix-media-repo/api/_responses" "github.com/turt2live/matrix-media-repo/api/_routers" + "github.com/turt2live/matrix-media-repo/database" "github.com/turt2live/matrix-media-repo/tasks/task_runner" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/common/rcontext" - "github.com/turt2live/matrix-media-repo/controllers/maintenance_controller" "github.com/turt2live/matrix-media-repo/matrix" - "github.com/turt2live/matrix-media-repo/storage" - "github.com/turt2live/matrix-media-repo/types" "github.com/turt2live/matrix-media-repo/util" ) @@ -51,8 +48,7 @@ func PurgeRemoteMedia(r *http.Request, rctx rcontext.RequestContext, user _apime } func PurgeIndividualRecord(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user) - localServerName := r.Host + authCtx, _, _ := getPurgeAuthContext(rctx, r, user) server := _routers.GetParam("server", r) mediaId := _routers.GetParam("mediaId", r) @@ -66,66 +62,54 @@ func PurgeIndividualRecord(r *http.Request, rctx rcontext.RequestContext, user _ "mediaId": mediaId, }) - // If the user is NOT a global admin, ensure they are speaking to the right server - if !isGlobalAdmin { - if server != localServerName { + _, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{ + Single: &task_runner.QuarantineRecord{ + Origin: server, + MediaId: mediaId, + }, + }) + if err != nil { + if err == common.ErrWrongUser { return _responses.AuthFailed() } - // If the user is NOT a local admin, ensure they uploaded the content in the first place - if !isLocalAdmin { - db := storage.GetDatabase().GetMediaStore(rctx) - m, err := db.Get(server, mediaId) - if err == sql.ErrNoRows { - return _responses.NotFoundError() - } - if err != nil { - rctx.Log.Error("Error checking ownership of media: ", err) - sentry.CaptureException(err) - return _responses.InternalServerError("error checking media ownership") - } - if m.UserId != user.UserId { - return _responses.AuthFailed() - } - } - } - - err := maintenance_controller.PurgeMedia(server, mediaId, rctx) - if err == sql.ErrNoRows || err == common.ErrMediaNotFound { - return _responses.NotFoundError() - } - if err != nil { - rctx.Log.Error("Error purging media: ", err) + rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("error purging media") + return _responses.InternalServerError("unexpected error") } return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true}} } func PurgeQuarantined(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user) - localServerName := r.Host + authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user) - var affected []*types.Media + var affected []*database.DbMedia var err error + mediaDb := database.GetInstance().Media.Prepare(rctx) if isGlobalAdmin { - affected, err = maintenance_controller.PurgeQuarantined(rctx) + affected, err = mediaDb.GetByQuarantine() } else if isLocalAdmin { - affected, err = maintenance_controller.PurgeQuarantinedFor(localServerName, rctx) + affected, err = mediaDb.GetByOriginQuarantine(r.Host) } else { return _responses.AuthFailed() } - if err != nil { - rctx.Log.Error("Error purging media: ", err) + rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("error purging media") + return _responses.InternalServerError("error fetching media records") } - mxcs := make([]string, 0) - for _, a := range affected { - mxcs = append(mxcs, a.MxcUri()) + mxcs, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{ + DbMedia: affected, + }) + if err != nil { + if err == common.ErrWrongUser { + return _responses.AuthFailed() + } + rctx.Log.Error(err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") } return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} @@ -156,24 +140,36 @@ func PurgeOldMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta. "include_local": includeLocal, }) - affected, err := maintenance_controller.PurgeOldMedia(beforeTs, includeLocal, rctx) + domains := make([]string, 0) + if !includeLocal { + domains = util.GetOurDomains() + } + mediaDb := database.GetInstance().Media.Prepare(rctx) + records, err := mediaDb.GetOldExcluding(domains, beforeTs) if err != nil { - rctx.Log.Error("Error purging media: ", err) + rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("error purging media") + return _responses.InternalServerError("error fetching media records") } - mxcs := make([]string, 0) - for _, a := range affected { - mxcs = append(mxcs, a.MxcUri()) + mxcs, err := task_runner.PurgeMedia(rctx, &task_runner.PurgeAuthContext{}, &task_runner.QuarantineThis{ + DbMedia: records, + }) + if err != nil { + if err == common.ErrWrongUser { + return _responses.AuthFailed() + } + rctx.Log.Error(err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") } return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } func PurgeUserMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user) + authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user) if !isGlobalAdmin && !isLocalAdmin { return _responses.AuthFailed() } @@ -206,24 +202,31 @@ func PurgeUserMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta return _responses.AuthFailed() } - affected, err := maintenance_controller.PurgeUserMedia(userId, beforeTs, rctx) - + mediaDb := database.GetInstance().Media.Prepare(rctx) + records, err := mediaDb.GetOldByUserId(userId, beforeTs) if err != nil { - rctx.Log.Error("Error purging media: ", err) + rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("error purging media") + return _responses.InternalServerError("error fetching media records") } - mxcs := make([]string, 0) - for _, a := range affected { - mxcs = append(mxcs, a.MxcUri()) + mxcs, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{ + DbMedia: records, + }) + if err != nil { + if err == common.ErrWrongUser { + return _responses.AuthFailed() + } + rctx.Log.Error(err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") } return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } func PurgeRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user) + authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user) if !isGlobalAdmin && !isLocalAdmin { return _responses.AuthFailed() } @@ -276,32 +279,27 @@ func PurgeRoomMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta mxcs = append(mxcs, mxc) } } else { - for _, mxc := range allMedia.LocalMxcs { - mxcs = append(mxcs, mxc) - } - for _, mxc := range allMedia.RemoteMxcs { - mxcs = append(mxcs, mxc) - } + mxcs = append(mxcs, allMedia.LocalMxcs...) + mxcs = append(mxcs, allMedia.RemoteMxcs...) } - affected, err := maintenance_controller.PurgeRoomMedia(mxcs, beforeTs, rctx) - + mxcs2, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{ + MxcUris: mxcs, + }) if err != nil { - rctx.Log.Error("Error purging media: ", err) + if err == common.ErrWrongUser { + return _responses.AuthFailed() + } + rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("error purging media") + return _responses.InternalServerError("unexpected error") } - mxcs = make([]string, 0) - for _, a := range affected { - mxcs = append(mxcs, a.MxcUri()) - } - - return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} + return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs2}} } func PurgeDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _apimeta.UserInfo) interface{} { - isGlobalAdmin, isLocalAdmin := _apimeta.GetRequestUserAdminStatus(r, rctx, user) + authCtx, isGlobalAdmin, isLocalAdmin := getPurgeAuthContext(rctx, r, user) if !isGlobalAdmin && !isLocalAdmin { return _responses.AuthFailed() } @@ -331,18 +329,36 @@ func PurgeDomainMedia(r *http.Request, rctx rcontext.RequestContext, user _apime return _responses.AuthFailed() } - affected, err := maintenance_controller.PurgeDomainMedia(serverName, beforeTs, rctx) - + mediaDb := database.GetInstance().Media.Prepare(rctx) + records, err := mediaDb.GetOldByOrigin(serverName, beforeTs) if err != nil { - rctx.Log.Error("Error purging media: ", err) + rctx.Log.Error(err) sentry.CaptureException(err) - return _responses.InternalServerError("error purging media") + return _responses.InternalServerError("error fetching media records") } - mxcs := make([]string, 0) - for _, a := range affected { - mxcs = append(mxcs, a.MxcUri()) + mxcs, err := task_runner.PurgeMedia(rctx, authCtx, &task_runner.QuarantineThis{ + DbMedia: records, + }) + if err != nil { + if err == common.ErrWrongUser { + return _responses.AuthFailed() + } + rctx.Log.Error(err) + sentry.CaptureException(err) + return _responses.InternalServerError("unexpected error") } return &_responses.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } + +func getPurgeAuthContext(ctx rcontext.RequestContext, r *http.Request, user _apimeta.UserInfo) (*task_runner.PurgeAuthContext, bool, bool) { + globalAdmin, localAdmin := _apimeta.GetRequestUserAdminStatus(r, ctx, user) + if globalAdmin { + return &task_runner.PurgeAuthContext{}, true, localAdmin + } + if localAdmin { + return &task_runner.PurgeAuthContext{SourceOrigin: r.Host}, false, true + } + return &task_runner.PurgeAuthContext{UploaderUserId: user.UserId}, false, false +} diff --git a/database/table_media.go b/database/table_media.go index d3f1cbb2..edb811ec 100644 --- a/database/table_media.go +++ b/database/table_media.go @@ -36,7 +36,9 @@ const insertMedia = "INSERT INTO media (origin, media_id, upload_name, content_t const selectMediaExists = "SELECT TRUE FROM media WHERE origin = $1 AND media_id = $2 LIMIT 1;" const selectMediaById = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1 AND media_id = $2;" const selectMediaByUserId = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE user_id = $1;" +const selectOldMediaByUserId = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE user_id = $1 AND creation_ts < $2;" const selectMediaByOrigin = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1;" +const selectOldMediaByOrigin = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1 AND creation_ts < $2;" const selectMediaByLocationExists = "SELECT TRUE FROM media WHERE datastore_id = $1 AND location = $2 LIMIT 1;" const selectMediaByUserCount = "SELECT COUNT(*) FROM media WHERE user_id = $1;" const selectMediaByOriginAndUserIds = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE origin = $1 AND user_id = ANY($2);" @@ -44,23 +46,31 @@ const selectMediaByOriginAndIds = "SELECT origin, media_id, upload_name, content const selectOldMediaExcludingDomains = "SELECT m.origin, m.media_id, m.upload_name, m.content_type, m.user_id, m.sha256_hash, m.size_bytes, m.creation_ts, m.quarantined, m.datastore_id, m.location FROM media AS m WHERE m.origin <> ANY($1) AND m.creation_ts < $2 AND (SELECT COUNT(d.*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.creation_ts >= $2) = 0 AND (SELECT COUNT(d.*) FROM media AS d WHERE d.sha256_hash = m.sha256_hash AND d.origin = ANY($1)) = 0;" const deleteMedia = "DELETE FROM media WHERE origin = $1 AND media_id = $2;" const updateMediaLocation = "UPDATE media SET datastore_id = $3, location = $4 WHERE datastore_id = $1 AND location = $2;" +const selectMediaByLocation = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE datastore_id = $1 AND location = $2;" +const selectMediaByQuarantine = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE quarantined = TRUE;" +const selectMediaByQuarantineAndOrigin = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, creation_ts, quarantined, datastore_id, location FROM media WHERE quarantined = TRUE AND origin = $1;" type mediaTableStatements struct { - selectDistinctMediaDatastoreIds *sql.Stmt - selectMediaIsQuarantinedByHash *sql.Stmt - selectMediaByHash *sql.Stmt - insertMedia *sql.Stmt - selectMediaExists *sql.Stmt - selectMediaById *sql.Stmt - selectMediaByUserId *sql.Stmt - selectMediaByOrigin *sql.Stmt - selectMediaByLocationExists *sql.Stmt - selectMediaByUserCount *sql.Stmt - selectMediaByOriginAndUserIds *sql.Stmt - selectMediaByOriginAndIds *sql.Stmt - selectOldMediaExcludingDomains *sql.Stmt - deleteMedia *sql.Stmt - updateMediaLocation *sql.Stmt + selectDistinctMediaDatastoreIds *sql.Stmt + selectMediaIsQuarantinedByHash *sql.Stmt + selectMediaByHash *sql.Stmt + insertMedia *sql.Stmt + selectMediaExists *sql.Stmt + selectMediaById *sql.Stmt + selectMediaByUserId *sql.Stmt + selectOldMediaByUserId *sql.Stmt + selectMediaByOrigin *sql.Stmt + selectOldMediaByOrigin *sql.Stmt + selectMediaByLocationExists *sql.Stmt + selectMediaByUserCount *sql.Stmt + selectMediaByOriginAndUserIds *sql.Stmt + selectMediaByOriginAndIds *sql.Stmt + selectOldMediaExcludingDomains *sql.Stmt + deleteMedia *sql.Stmt + updateMediaLocation *sql.Stmt + selectMediaByLocation *sql.Stmt + selectMediaByQuarantine *sql.Stmt + selectMediaByQuarantineAndOrigin *sql.Stmt } type mediaTableWithContext struct { @@ -93,9 +103,15 @@ func prepareMediaTables(db *sql.DB) (*mediaTableStatements, error) { if stmts.selectMediaByUserId, err = db.Prepare(selectMediaByUserId); err != nil { return nil, errors.New("error preparing selectMediaByUserId: " + err.Error()) } + if stmts.selectOldMediaByUserId, err = db.Prepare(selectOldMediaByUserId); err != nil { + return nil, errors.New("error preparing selectOldMediaByUserId: " + err.Error()) + } if stmts.selectMediaByOrigin, err = db.Prepare(selectMediaByOrigin); err != nil { return nil, errors.New("error preparing selectMediaByOrigin: " + err.Error()) } + if stmts.selectOldMediaByOrigin, err = db.Prepare(selectOldMediaByOrigin); err != nil { + return nil, errors.New("error preparing selectOldMediaByOrigin: " + err.Error()) + } if stmts.selectMediaByLocationExists, err = db.Prepare(selectMediaByLocationExists); err != nil { return nil, errors.New("error preparing selectMediaByLocationExists: " + err.Error()) } @@ -117,6 +133,15 @@ func prepareMediaTables(db *sql.DB) (*mediaTableStatements, error) { if stmts.updateMediaLocation, err = db.Prepare(updateMediaLocation); err != nil { return nil, errors.New("error preparing updateMediaLocation: " + err.Error()) } + if stmts.selectMediaByLocation, err = db.Prepare(selectMediaByLocation); err != nil { + return nil, errors.New("error preparing selectMediaByLocation: " + err.Error()) + } + if stmts.selectMediaByQuarantine, err = db.Prepare(selectMediaByQuarantine); err != nil { + return nil, errors.New("error preparing selectMediaByQuarantine: " + err.Error()) + } + if stmts.selectMediaByQuarantineAndOrigin, err = db.Prepare(selectMediaByQuarantineAndOrigin); err != nil { + return nil, errors.New("error preparing selectMediaByQuarantineAndOrigin: " + err.Error()) + } return stmts, nil } @@ -188,10 +213,18 @@ func (s *mediaTableWithContext) GetByUserId(userId string) ([]*DbMedia, error) { return s.scanRows(s.statements.selectMediaByUserId.QueryContext(s.ctx, userId)) } +func (s *mediaTableWithContext) GetOldByUserId(userId string, beforeTs int64) ([]*DbMedia, error) { + return s.scanRows(s.statements.selectOldMediaByUserId.QueryContext(s.ctx, userId, beforeTs)) +} + func (s *mediaTableWithContext) GetByOrigin(origin string) ([]*DbMedia, error) { return s.scanRows(s.statements.selectMediaByOrigin.QueryContext(s.ctx, origin)) } +func (s *mediaTableWithContext) GetOldByOrigin(origin string, beforeTs int64) ([]*DbMedia, error) { + return s.scanRows(s.statements.selectOldMediaByOrigin.QueryContext(s.ctx, origin, beforeTs)) +} + func (s *mediaTableWithContext) GetByOriginUsers(origin string, userIds []string) ([]*DbMedia, error) { return s.scanRows(s.statements.selectMediaByOriginAndUserIds.QueryContext(s.ctx, origin, pq.Array(userIds))) } @@ -204,6 +237,18 @@ func (s *mediaTableWithContext) GetOldExcluding(origins []string, beforeTs int64 return s.scanRows(s.statements.selectOldMediaExcludingDomains.QueryContext(s.ctx, pq.Array(origins), beforeTs)) } +func (s *mediaTableWithContext) GetByLocation(datastoreId string, location string) ([]*DbMedia, error) { + return s.scanRows(s.statements.selectMediaByLocation.QueryContext(s.ctx, datastoreId, location)) +} + +func (s *mediaTableWithContext) GetByQuarantine() ([]*DbMedia, error) { + return s.scanRows(s.statements.selectMediaByQuarantine.QueryContext(s.ctx)) +} + +func (s *mediaTableWithContext) GetByOriginQuarantine(origin string) ([]*DbMedia, error) { + return s.scanRows(s.statements.selectMediaByQuarantineAndOrigin.QueryContext(s.ctx, origin)) +} + func (s *mediaTableWithContext) GetById(origin string, mediaId string) (*DbMedia, error) { row := s.statements.selectMediaById.QueryRowContext(s.ctx, origin, mediaId) val := &DbMedia{Locatable: &Locatable{}} diff --git a/database/table_reserved_media.go b/database/table_reserved_media.go index 239b0976..58047c68 100644 --- a/database/table_reserved_media.go +++ b/database/table_reserved_media.go @@ -13,10 +13,12 @@ type DbReservedMedia struct { Reason string } -const insertReservedMedia = "INSERT INTO reserved_media (origin, media_id, reason) VALUES ($1, $2, $3);" +const insertReservedMediaNoConflict = "INSERT INTO reserved_media (origin, media_id, reason) VALUES ($1, $2, $3) ON CONFLICT (origin, media_id) DO NOTHING;" +const selectReservedMediaExists = "SELECT TRUE FROM reserved_media WHERE origin = $1 AND media_id = $2 LIMIT 1;" type reservedMediaTableStatements struct { - insertReservedMedia *sql.Stmt + insertReservedMediaNoConflict *sql.Stmt + selectReservedMediaExists *sql.Stmt } type reservedMediaTableWithContext struct { @@ -28,8 +30,11 @@ func prepareReservedMediaTables(db *sql.DB) (*reservedMediaTableStatements, erro var err error var stmts = &reservedMediaTableStatements{} - if stmts.insertReservedMedia, err = db.Prepare(insertReservedMedia); err != nil { - return nil, errors.New("error preparing insertReservedMedia: " + err.Error()) + if stmts.insertReservedMediaNoConflict, err = db.Prepare(insertReservedMediaNoConflict); err != nil { + return nil, errors.New("error preparing insertReservedMediaNoConflict: " + err.Error()) + } + if stmts.selectReservedMediaExists, err = db.Prepare(selectReservedMediaExists); err != nil { + return nil, errors.New("error preparing selectReservedMediaExists: " + err.Error()) } return stmts, nil @@ -42,7 +47,18 @@ func (s *reservedMediaTableStatements) Prepare(ctx rcontext.RequestContext) *res } } -func (s *reservedMediaTableWithContext) TryInsert(origin string, mediaId string, reason string) error { - _, err := s.statements.insertReservedMedia.ExecContext(s.ctx, origin, mediaId, reason) +func (s *reservedMediaTableWithContext) InsertNoConflict(origin string, mediaId string, reason string) error { + _, err := s.statements.insertReservedMediaNoConflict.ExecContext(s.ctx, origin, mediaId, reason) return err } + +func (s *reservedMediaTableWithContext) IdExists(origin string, mediaId string) (bool, error) { + row := s.statements.selectReservedMediaExists.QueryRowContext(s.ctx, origin, mediaId) + val := false + err := row.Scan(&val) + if err == sql.ErrNoRows { + err = nil + val = false + } + return val, err +} diff --git a/database/table_thumbnails.go b/database/table_thumbnails.go index c431621b..67de1760 100644 --- a/database/table_thumbnails.go +++ b/database/table_thumbnails.go @@ -32,6 +32,7 @@ const selectThumbnailsForMedia = "SELECT origin, media_id, content_type, width, const selectOldThumbnails = "SELECT origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location FROM thumbnails WHERE sha256_hash IN (SELECT t2.sha256_hash FROM thumbnails AS t2 WHERE t2.creation_ts < $1);" const deleteThumbnail = "DELETE FROM thumbnails WHERE origin = $1 AND media_id = $2 AND content_type = $3 AND width = $4 AND height = $5 AND method = $6 AND animated = $7 AND sha256_hash = $8 AND size_bytes = $9 AND creation_ts = $10 AND datastore_id = $11 AND location = $11;" const updateThumbnailLocation = "UPDATE thumbnails SET datastore_id = $3, location = $4 WHERE datastore_id = $1 AND location = $2;" +const selectThumbnailsByLocation = "SELECT origin, media_id, content_type, width, height, method, animated, sha256_hash, size_bytes, creation_ts, datastore_id, location FROM thumbnails WHERE datastore_id = $1 AND location = $2;" type thumbnailsTableStatements struct { selectThumbnailByParams *sql.Stmt @@ -41,6 +42,7 @@ type thumbnailsTableStatements struct { selectOldThumbnails *sql.Stmt deleteThumbnail *sql.Stmt updateThumbnailLocation *sql.Stmt + selectThumbnailsByLocation *sql.Stmt } type thumbnailsTableWithContext struct { @@ -73,6 +75,9 @@ func prepareThumbnailsTables(db *sql.DB) (*thumbnailsTableStatements, error) { if stmts.updateThumbnailLocation, err = db.Prepare(updateThumbnailLocation); err != nil { return nil, errors.New("error preparing updateThumbnailLocation: " + err.Error()) } + if stmts.selectThumbnailsByLocation, err = db.Prepare(selectThumbnailsByLocation); err != nil { + return nil, errors.New("error preparing selectThumbnailsByLocation: " + err.Error()) + } return stmts, nil } @@ -95,9 +100,8 @@ func (s *thumbnailsTableWithContext) GetByParams(origin string, mediaId string, return val, err } -func (s *thumbnailsTableWithContext) GetForMedia(origin string, mediaId string) ([]*DbThumbnail, error) { +func (s *thumbnailsTableWithContext) scanRows(rows *sql.Rows, err error) ([]*DbThumbnail, error) { results := make([]*DbThumbnail, 0) - rows, err := s.statements.selectThumbnailsForMedia.QueryContext(s.ctx, origin, mediaId) if err != nil { if err == sql.ErrNoRows { return results, nil @@ -111,26 +115,20 @@ func (s *thumbnailsTableWithContext) GetForMedia(origin string, mediaId string) } results = append(results, val) } + return results, nil } +func (s *thumbnailsTableWithContext) GetForMedia(origin string, mediaId string) ([]*DbThumbnail, error) { + return s.scanRows(s.statements.selectThumbnailsForMedia.QueryContext(s.ctx, origin, mediaId)) +} + func (s *thumbnailsTableWithContext) GetOlderThan(ts int64) ([]*DbThumbnail, error) { - results := make([]*DbThumbnail, 0) - rows, err := s.statements.selectOldThumbnails.QueryContext(s.ctx, ts) - if err != nil { - if err == sql.ErrNoRows { - return results, nil - } - return nil, err - } - for rows.Next() { - val := &DbThumbnail{Locatable: &Locatable{}} - if err = rows.Scan(&val.Origin, &val.MediaId, &val.ContentType, &val.Width, &val.Height, &val.Method, &val.Animated, &val.Sha256Hash, &val.SizeBytes, &val.CreationTs, &val.DatastoreId, &val.Location); err != nil { - return nil, err - } - results = append(results, val) - } - return results, nil + return s.scanRows(s.statements.selectOldThumbnails.QueryContext(s.ctx, ts)) +} + +func (s *thumbnailsTableWithContext) GetByLocation(datastoreId string, location string) ([]*DbThumbnail, error) { + return s.scanRows(s.statements.selectThumbnailsByLocation.QueryContext(s.ctx, datastoreId, location)) } func (s *thumbnailsTableWithContext) Insert(record *DbThumbnail) error { diff --git a/migrations/23_add_datastore_locations_indexes_down.sql b/migrations/23_add_datastore_locations_indexes_down.sql new file mode 100644 index 00000000..fcdcec09 --- /dev/null +++ b/migrations/23_add_datastore_locations_indexes_down.sql @@ -0,0 +1,2 @@ +DROP INDEX IF EXISTS idx_datastore_id_location_thumbnails; +DROP INDEX IF EXISTS idx_datastore_id_location_media; diff --git a/migrations/23_add_datastore_locations_indexes_up.sql b/migrations/23_add_datastore_locations_indexes_up.sql new file mode 100644 index 00000000..3f9c9e12 --- /dev/null +++ b/migrations/23_add_datastore_locations_indexes_up.sql @@ -0,0 +1,2 @@ +CREATE INDEX IF NOT EXISTS idx_datastore_id_location_thumbnails ON thumbnails(datastore_id, location); +CREATE INDEX IF NOT EXISTS idx_datastore_id_location_media ON media(datastore_id, location); diff --git a/pipelines/_steps/upload/generate_media_id.go b/pipelines/_steps/upload/generate_media_id.go index 24bc90cd..4fd86792 100644 --- a/pipelines/_steps/upload/generate_media_id.go +++ b/pipelines/_steps/upload/generate_media_id.go @@ -15,6 +15,7 @@ func GenerateMediaId(ctx rcontext.RequestContext, origin string) (string, error) } heldDb := database.GetInstance().HeldMedia.Prepare(ctx) mediaDb := database.GetInstance().Media.Prepare(ctx) + reservedDb := database.GetInstance().ReservedMedia.Prepare(ctx) var mediaId string var err error var exists bool @@ -41,6 +42,15 @@ func GenerateMediaId(ctx rcontext.RequestContext, origin string) (string, error) continue } + // Also check to see if the media ID is reserved due to a past action + exists, err = reservedDb.IdExists(origin, mediaId) + if err != nil { + return "", err + } + if exists { + continue + } + return mediaId, nil } return "", errors.New("internal limit reached: fell out of media ID generation loop") diff --git a/tasks/task_runner/purge.go b/tasks/task_runner/purge.go new file mode 100644 index 00000000..121130c0 --- /dev/null +++ b/tasks/task_runner/purge.go @@ -0,0 +1,227 @@ +package task_runner + +import ( + "errors" + "fmt" + + "github.com/turt2live/matrix-media-repo/common" + "github.com/turt2live/matrix-media-repo/common/rcontext" + "github.com/turt2live/matrix-media-repo/database" + "github.com/turt2live/matrix-media-repo/datastores" + "github.com/turt2live/matrix-media-repo/util" +) + +type purgeConfig struct { + IncludeQuarantined bool +} + +type PurgeAuthContext struct { + UploaderUserId string + SourceOrigin string +} + +func (c *PurgeAuthContext) canAffect(media *database.DbMedia) bool { + if c.UploaderUserId != "" && c.UploaderUserId != media.UserId { + return false + } + if c.SourceOrigin != "" && c.SourceOrigin != media.Origin { + return false + } + return true +} + +func PurgeMedia(ctx rcontext.RequestContext, authContext *PurgeAuthContext, toHandle *QuarantineThis) ([]string, error) { + records, err := resolveMedia(ctx, "", toHandle) + if err != nil { + return nil, err + } + + // Check auth on all records before actually processing them + for _, r := range records { + if !authContext.canAffect(r) { + return nil, common.ErrWrongUser + } + } + + // Now we process all the records + return doPurge(ctx, records, &purgeConfig{IncludeQuarantined: true}) +} + +func doPurge(ctx rcontext.RequestContext, records []*database.DbMedia, config *purgeConfig) ([]string, error) { + mediaDb := database.GetInstance().Media.Prepare(ctx) + thumbsDb := database.GetInstance().Thumbnails.Prepare(ctx) + attrsDb := database.GetInstance().MediaAttributes.Prepare(ctx) + reservedDb := database.GetInstance().ReservedMedia.Prepare(ctx) + + // Filter the records early on to remove things we're not going to handle + ctx.Log.Debug("Purge pre-filter") + records2 := make([]*database.DbMedia, 0) + for _, r := range records { + if r.Quarantined && !config.IncludeQuarantined { + continue // skip quarantined media so later loops don't try to purge it + } + attrs, err := attrsDb.Get(r.Origin, r.MediaId) + if err != nil { + return nil, err + } + if attrs != nil && attrs.Purpose == database.PurposePinned { + continue + } + + records2 = append(records2, r) + } + records = records2 + + flagMap := make(map[string]map[string]bool) // outer key = file location, inner key = MXC, value = in records[] + thumbsMap := make(map[string][]*database.DbThumbnail) + + // First, we identify all the media which is using the file references we think we want to delete + // This includes thumbnails (flagged under the original media MXC URI) + ctx.Log.Debug("Stage 1 of purge") + doFlagging := func(datastoreId string, location string) error { + locationId := fmt.Sprintf("%s/%s", datastoreId, location) + if _, ok := flagMap[locationId]; ok { + return nil // we already processed this file location - skip trying to populate from it + } + + flagMap[locationId] = make(map[string]bool) + + // Find media records first + media, err := mediaDb.GetByLocation(datastoreId, location) + if err != nil { + return err + } + for _, r2 := range media { + mxc := util.MxcUri(r2.Origin, r2.MediaId) + flagMap[locationId][mxc] = false + } + + // Now thumbnails + thumbs, err := thumbsDb.GetByLocation(datastoreId, location) + if err != nil { + return err + } + for _, r2 := range thumbs { + mxc := util.MxcUri(r2.Origin, r2.MediaId) + flagMap[locationId][mxc] = false + } + + return nil + } + for _, r := range records { + if err := doFlagging(r.DatastoreId, r.Location); err != nil { + return nil, err + } + + // We also grab all the thumbnails of the proposed media to clear those files out safely too + thumbs, err := thumbsDb.GetForMedia(r.Origin, r.MediaId) + if err != nil { + return nil, err + } + thumbsMap[util.MxcUri(r.Origin, r.MediaId)] = thumbs + for _, t := range thumbs { + if err = doFlagging(t.DatastoreId, t.Location); err != nil { + return nil, err + } + } + } + + // Next, we re-iterate to flag records as being deleted + ctx.Log.Debug("Stage 2 of purge") + markBeingPurged := func(locationId string, mxc string) error { + if m, ok := flagMap[locationId]; !ok { + return errors.New("logic error: missing flag map for location ID in second step") + } else { + if v, ok := m[mxc]; !ok { + return errors.New("logic error: missing flag map value for MXC URI in second step") + } else if !v { // if v is `true` then it's already been processed - skip a write step + m[mxc] = true + } + } + + return nil + } + for _, r := range records { + locationId := fmt.Sprintf("%s/%s", r.DatastoreId, r.Location) + mxc := util.MxcUri(r.Origin, r.MediaId) + if err := markBeingPurged(locationId, mxc); err != nil { + return nil, err + } + + // Mark the thumbnails too + if thumbs, ok := thumbsMap[mxc]; !ok { + return nil, errors.New("logic error: missing thumbnails map value for MXC URI in second step") + } else { + for _, t := range thumbs { + locationId = fmt.Sprintf("%s/%s", t.DatastoreId, t.Location) + mxc = util.MxcUri(t.Origin, t.MediaId) + if err := markBeingPurged(locationId, mxc); err != nil { + return nil, err + } + } + } + } + + // Finally, we can run through the records and start deleting media that's safe to delete + ctx.Log.Debug("Stage 3 of purge") + deletedLocations := make(map[string]bool) + removedMxcs := make([]string, 0) + tryRemoveDsFile := func(datastoreId string, location string) error { + locationId := fmt.Sprintf("%s/%s", datastoreId, location) + if _, ok := deletedLocations[locationId]; ok { + return nil // already deleted/handled + } + if m, ok := flagMap[locationId]; !ok { + return errors.New("logic error: missing flag map value for location ID in third step") + } else { + for _, b := range m { + if !b { + return nil // unsafe to delete, but no error + } + } + } + + // Try deleting the file + err := datastores.RemoveWithDsId(ctx, datastoreId, location) + if err != nil { + return err + } + deletedLocations[locationId] = true + return nil + } + for _, r := range records { + mxc := util.MxcUri(r.Origin, r.MediaId) + + if err := tryRemoveDsFile(r.DatastoreId, r.Location); err != nil { + return nil, err + } + if util.IsServerOurs(r.Origin) { + if err := reservedDb.InsertNoConflict(r.Origin, r.MediaId, "purged / deleted"); err != nil { + return nil, err + } + } + if !r.Quarantined { // keep quarantined flag + if err := mediaDb.Delete(r.Origin, r.MediaId); err != nil { + return nil, err + } + } + removedMxcs = append(removedMxcs, mxc) + + // Remove the thumbnails too + if thumbs, ok := thumbsMap[mxc]; !ok { + return nil, errors.New("logic error: missing thumbnails for MXC URI in third step") + } else { + for _, t := range thumbs { + if err := tryRemoveDsFile(t.DatastoreId, t.Location); err != nil { + return nil, err + } + if err := thumbsDb.Delete(t); err != nil { + return nil, err + } + } + } + } + + // Finally, we're done + return removedMxcs, nil +} diff --git a/tasks/task_runner/purge_remote_media.go b/tasks/task_runner/purge_remote_media.go index 24f9493c..97ee5cf0 100644 --- a/tasks/task_runner/purge_remote_media.go +++ b/tasks/task_runner/purge_remote_media.go @@ -1,13 +1,10 @@ package task_runner import ( - "fmt" - "github.com/getsentry/sentry-go" "github.com/turt2live/matrix-media-repo/common/config" "github.com/turt2live/matrix-media-repo/common/rcontext" "github.com/turt2live/matrix-media-repo/database" - "github.com/turt2live/matrix-media-repo/datastores" "github.com/turt2live/matrix-media-repo/util" ) @@ -29,7 +26,6 @@ func PurgeRemoteMedia(ctx rcontext.RequestContext) { // PurgeRemoteMediaBefore returns (count affected, error) func PurgeRemoteMediaBefore(ctx rcontext.RequestContext, beforeTs int64) (int, error) { mediaDb := database.GetInstance().Media.Prepare(ctx) - thumbsDb := database.GetInstance().Thumbnails.Prepare(ctx) origins := util.GetOurDomains() @@ -38,47 +34,10 @@ func PurgeRemoteMediaBefore(ctx rcontext.RequestContext, beforeTs int64) (int, e return 0, err } - removed := 0 - deletedLocations := make(map[string]bool) - for _, record := range records { - mxc := util.MxcUri(record.Origin, record.MediaId) - if record.Quarantined { - ctx.Log.Debugf("Skipping quarantined media %s", mxc) - continue // skip quarantined media - } - - if exists, err := thumbsDb.LocationExists(record.DatastoreId, record.Location); err != nil { - ctx.Log.Error("Error checking for conflicting thumbnail: ", err) - sentry.CaptureException(err) - } else if !exists { // if exists, skip - locationId := fmt.Sprintf("%s/%s", record.DatastoreId, record.Location) - if _, ok := deletedLocations[locationId]; !ok { - ctx.Log.Debugf("Trying to remove datastore object for %s", mxc) - err = datastores.RemoveWithDsId(ctx, record.DatastoreId, record.Location) - if err != nil { - ctx.Log.Error("Error deleting media from datastore: ", err) - sentry.CaptureException(err) - continue - } - deletedLocations[locationId] = true - } - ctx.Log.Debugf("Trying to database record for %s", mxc) - if err = mediaDb.Delete(record.Origin, record.MediaId); err != nil { - ctx.Log.Error("Error deleting thumbnail record: ", err) - sentry.CaptureException(err) - } - removed = removed + 1 - - thumbs, err := thumbsDb.GetForMedia(record.Origin, record.MediaId) - if err != nil { - ctx.Log.Warn("Error getting thumbnails for media: ", err) - sentry.CaptureException(err) - continue - } - - doPurgeThumbnails(ctx, thumbs) - } + removed, err := doPurge(ctx, records, &purgeConfig{IncludeQuarantined: false}) + if err != nil { + return 0, err } - return removed, nil + return len(removed), nil }