Skip to content

Commit

Permalink
fix potential memory leak (#395)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz authored Mar 12, 2022
1 parent 5779b01 commit 987b5c8
Show file tree
Hide file tree
Showing 12 changed files with 87 additions and 87 deletions.
2 changes: 1 addition & 1 deletion base/search/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func TestHNSW_InnerProduct(t *testing.T) {
idx, recall := builder.Build(0.9, 5, false)
assert.Greater(t, recall, float32(0.9))
recall = builder.evaluateTermSearch(idx, true, "prime")
assert.Greater(t, recall, float32(0.9))
assert.Greater(t, recall, float32(0.85))
}

func TestIVF_Cosine(t *testing.T) {
Expand Down
40 changes: 21 additions & 19 deletions master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import (
"github.com/zhenghaoz/gorse/storage/cache"
"github.com/zhenghaoz/gorse/storage/data"
"go.uber.org/zap"
"modernc.org/mathutil"
"math"
"modernc.org/sortutil"
"sort"
"time"
Expand Down Expand Up @@ -69,11 +69,8 @@ func (m *Master) runLoadDatasetTask() error {

// save popular items to cache
for category, items := range popularItems {
for batchBegin := 0; batchBegin < len(items); batchBegin += batchSize {
batchEnd := mathutil.Min(len(items), batchBegin+batchSize)
if err = m.CacheClient.AddSorted(cache.Key(cache.PopularItems, category), items[batchBegin:batchEnd]); err != nil {
base.Logger().Error("failed to cache popular items", zap.Error(err))
}
if err = m.CacheClient.SetSorted(cache.Key(cache.PopularItems, category), items); err != nil {
base.Logger().Error("failed to cache popular items", zap.Error(err))
}
}
if err = m.CacheClient.SetTime(cache.GlobalMeta, cache.LastUpdatePopularItemsTime, time.Now()); err != nil {
Expand All @@ -85,6 +82,13 @@ func (m *Master) runLoadDatasetTask() error {
if err = m.CacheClient.AddSorted(cache.Key(cache.LatestItems, category), items); err != nil {
base.Logger().Error("failed to cache latest items", zap.Error(err))
}
// reclaim outdated items
if len(items) > 0 {
threshold := items[len(items)-1].Score - 1
if err = m.CacheClient.RemSortedByScore(cache.Key(cache.LatestItems, category), math.Inf(-1), threshold); err != nil {
base.Logger().Error("failed to reclaim outdated items", zap.Error(err))
}
}
}
if err = m.CacheClient.SetTime(cache.GlobalMeta, cache.LastUpdateLatestItemsTime, time.Now()); err != nil {
base.Logger().Error("failed to write latest update latest items time", zap.Error(err))
Expand Down Expand Up @@ -124,12 +128,6 @@ func (m *Master) runLoadDatasetTask() error {
if err = m.CacheClient.SetSet(cache.ItemCategories, rankingDataset.CategorySet.List()...); err != nil {
base.Logger().Error("failed to write categories to cache", zap.Error(err))
}
for i, categories := range rankingDataset.ItemCategories {
itemId := rankingDataset.ItemIndex.ToName(int32(i))
if err = m.CacheClient.SetSet(cache.Key(cache.ItemCategories, itemId), categories...); err != nil {
base.Logger().Error("failed to write categories to cache", zap.Error(err))
}
}

// split ranking dataset
m.rankingModelMutex.Lock()
Expand Down Expand Up @@ -1208,18 +1206,22 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes,
}

// collect popular items
popularItems = make(map[string][]cache.Scored)
popularItemFilters := make(map[string]*heap.TopKStringFilter)
popularItemFilters[""] = heap.NewTopKStringFilter(m.GorseConfig.Database.CacheSize)
for itemIndex, val := range popularCount {
popularItems[""] = append(popularItems[""], cache.Scored{Id: rankingDataset.ItemIndex.ToName(int32(itemIndex)), Score: float64(val)})
itemId := rankingDataset.ItemIndex.ToName(int32(itemIndex))
popularItemFilters[""].Push(itemId, float64(val))
for _, category := range rankingDataset.ItemCategories[itemIndex] {
if _, exist := popularItems[category]; !exist {
popularItems[category] = make([]cache.Scored, 0)
if _, exist := popularItemFilters[category]; !exist {
popularItemFilters[category] = heap.NewTopKStringFilter(m.GorseConfig.Database.CacheSize)
}
popularItems[category] = append(popularItems[category], cache.Scored{Id: rankingDataset.ItemIndex.ToName(int32(itemIndex)), Score: float64(val)})
popularItemFilters[category].Push(itemId, float64(val))
}
}
for _, items := range popularItems {
cache.SortScores(items)
popularItems = make(map[string][]cache.Scored)
for category, popularItemFilter := range popularItemFilters {
items, scores := popularItemFilter.PopAll()
popularItems[category] = cache.CreateScoredItems(items, scores)
}

m.taskMonitor.Finish(TaskLoadDataset)
Expand Down
4 changes: 0 additions & 4 deletions master/tasks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,4 @@ func TestMaster_LoadDataFromDatabase(t *testing.T) {
categories, err := m.CacheClient.GetSet(cache.ItemCategories)
assert.NoError(t, err)
assert.Equal(t, []string{"0", "1", "2"}, categories)
categories, err = m.CacheClient.GetSet(cache.Key(cache.ItemCategories, "2"))
assert.NoError(t, err)
assert.Equal(t, []string{"2"}, categories)

}
69 changes: 28 additions & 41 deletions server/rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ type recommendContext struct {

func (s *RestServer) createRecommendContext(userId, category string, n int) (*recommendContext, error) {
// pull ignored items
ignoreItems, err := s.CacheClient.GetSortedByScore(cache.Key(cache.IgnoreItems, userId), math.Inf(-1), float64(float32(time.Now().Unix())))
ignoreItems, err := s.CacheClient.GetSortedByScore(cache.Key(cache.IgnoreItems, userId), math.Inf(-1), float64(time.Now().Unix()))
if err != nil {
return nil, errors.Trace(err)
}
Expand Down Expand Up @@ -770,11 +770,11 @@ func (s *RestServer) RecommendUserBased(ctx *recommendContext) error {
// add unseen items
for _, feedback := range feedbacks {
if !ctx.excludeSet.Has(feedback.ItemId) {
categories, err := s.CacheClient.GetSet(cache.Key(cache.ItemCategories, feedback.ItemId))
item, err := s.DataClient.GetItem(feedback.ItemId)
if err != nil {
return errors.Trace(err)
}
if ctx.category == "" || funk.ContainsString(categories, ctx.category) {
if ctx.category == "" || funk.ContainsString(item.Categories, ctx.category) {
candidates[feedback.ItemId] += user.Score
}
}
Expand Down Expand Up @@ -1153,6 +1153,10 @@ func (s *RestServer) batchInsertItems(response *restful.Response, temp []Item) {
})
}
}
if err = s.deleteItemFromLatestPopularCache(item.ItemId, false); err != nil {
InternalServerError(response, err)
return
}
count++
}
err := s.DataClient.BatchInsertItems(items)
Expand All @@ -1170,14 +1174,6 @@ func (s *RestServer) batchInsertItems(response *restful.Response, temp []Item) {
InternalServerError(response, err)
return
}
if err = s.deleteItemFromLatestPopularCache(item.ItemId, false); err != nil {
InternalServerError(response, err)
return
}
if err = s.CacheClient.AddSet(cache.Key(cache.ItemCategories, item.ItemId), item.Categories...); err != nil {
InternalServerError(response, err)
return
}
}
// insert timestamp score
for category, score := range timeScores {
Expand Down Expand Up @@ -1225,6 +1221,13 @@ func (s *RestServer) modifyItem(request *restful.Request, response *restful.Resp
BadRequest(response, err)
return
}
// refresh category cache
if patch.Categories != nil {
if err := s.deleteItemFromLatestPopularCache(itemId, false); err != nil {
InternalServerError(response, err)
return
}
}
if err := s.DataClient.ModifyItem(itemId, patch); err != nil {
InternalServerError(response, err)
return
Expand All @@ -1242,17 +1245,6 @@ func (s *RestServer) modifyItem(request *restful.Request, response *restful.Resp
return
}
}
// refresh category cache
if patch.Categories != nil {
if err := s.deleteItemFromLatestPopularCache(itemId, false); err != nil {
InternalServerError(response, err)
return
}
if err := s.CacheClient.AddSet(cache.Key(cache.ItemCategories, itemId), patch.Categories...); err != nil {
InternalServerError(response, err)
return
}
}
// insert new timestamp to the latest scores
if patch.Timestamp != nil || patch.Categories != nil {
item, err := s.DataClient.GetItem(itemId)
Expand Down Expand Up @@ -1320,17 +1312,17 @@ func (s *RestServer) getItem(request *restful.Request, response *restful.Respons

func (s *RestServer) deleteItem(request *restful.Request, response *restful.Response) {
itemId := request.PathParameter("item-id")
if err := s.DataClient.DeleteItem(itemId); err != nil {
// delete items from latest and popular
if err := s.deleteItemFromLatestPopularCache(itemId, true); err != nil {
InternalServerError(response, err)
return
}
// insert deleted item to cache
if err := s.CacheClient.SetInt(cache.HiddenItems, itemId, 1); err != nil {
if err := s.DataClient.DeleteItem(itemId); err != nil {
InternalServerError(response, err)
return
}
// delete items from latest and popular
if err := s.deleteItemFromLatestPopularCache(itemId, true); err != nil {
// insert deleted item to cache
if err := s.CacheClient.SetInt(cache.HiddenItems, itemId, 1); err != nil {
InternalServerError(response, err)
return
}
Expand All @@ -1342,10 +1334,15 @@ func (s *RestServer) deleteItemFromLatestPopularCache(itemId string, deleteItem
if deleteItem {
deleteKeys = []string{cache.LatestItems, cache.PopularItems}
}
if categories, err := s.CacheClient.GetSet(cache.Key(cache.ItemCategories, itemId)); err != nil {
return err
if item, err := s.DataClient.GetItem(itemId); err != nil {
if errors.IsNotFound(err) {
// do nothing if the item doesn't exist
return nil
} else {
return err
}
} else {
for _, category := range categories {
for _, category := range item.Categories {
deleteKeys = append(deleteKeys, cache.Key(cache.LatestItems, category))
deleteKeys = append(deleteKeys, cache.Key(cache.PopularItems, category))
}
Expand All @@ -1355,7 +1352,7 @@ func (s *RestServer) deleteItemFromLatestPopularCache(itemId string, deleteItem
return err
}
}
return s.CacheClient.Delete(cache.ItemCategories, itemId)
return nil
}

func (s *RestServer) insertItemCategory(request *restful.Request, response *restful.Response) {
Expand Down Expand Up @@ -1389,11 +1386,6 @@ func (s *RestServer) insertItemCategory(request *restful.Request, response *rest
InternalServerError(response, err)
return
}
// add category to cache
if err = s.CacheClient.AddSet(cache.Key(cache.ItemCategories, itemId), category); err != nil {
InternalServerError(response, err)
return
}
Ok(response, Success{RowAffected: 1})
}

Expand Down Expand Up @@ -1429,11 +1421,6 @@ func (s *RestServer) deleteItemCategory(request *restful.Request, response *rest
InternalServerError(response, err)
return
}
// remove category from cache
if err = s.CacheClient.RemSet(cache.Key(cache.ItemCategories, itemId), category); err != nil {
InternalServerError(response, err)
return
}
Ok(response, Success{RowAffected: 1})
}

Expand Down
16 changes: 0 additions & 16 deletions server/rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,6 @@ func TestServer_Items(t *testing.T) {
categories, err := s.CacheClient.GetSet(cache.ItemCategories)
assert.NoError(t, err)
assert.Equal(t, []string{"*"}, categories)
categories, err = s.CacheClient.GetSet(cache.Key(cache.ItemCategories, "2"))
assert.NoError(t, err)
assert.Equal(t, []string{"*"}, categories)

// delete item
apitest.New().
Expand All @@ -342,9 +339,6 @@ func TestServer_Items(t *testing.T) {
isHidden, err := s.CacheClient.GetInt(cache.HiddenItems, "6")
assert.NoError(t, err)
assert.Equal(t, 1, isHidden)
categories, err = s.CacheClient.GetSet(cache.Key(cache.ItemCategories, "6"))
assert.NoError(t, err)
assert.Empty(t, categories)
// get latest items
apitest.New().
Handler(s.handler).
Expand Down Expand Up @@ -525,9 +519,6 @@ func TestServer_Items(t *testing.T) {
Timestamp: timestamp,
})).
End()
categories, err = s.CacheClient.GetSet(cache.Key(cache.ItemCategories, "2"))
assert.NoError(t, err)
assert.Equal(t, []string{"-", "@"}, categories)
// get latest items
apitest.New().
Handler(s.handler).
Expand Down Expand Up @@ -581,9 +572,6 @@ func TestServer_Items(t *testing.T) {
Timestamp: timestamp,
})).
End()
categories, err = s.CacheClient.GetSet(cache.Key(cache.ItemCategories, "2"))
assert.NoError(t, err)
assert.Equal(t, []string{"-"}, categories)
// get latest items
apitest.New().
Handler(s.handler).
Expand Down Expand Up @@ -1244,10 +1232,6 @@ func TestServer_GetRecommends_Fallback_UserBasedSimilar(t *testing.T) {
{ItemId: "48", Categories: []string{"*"}},
})
assert.NoError(t, err)
err = s.CacheClient.AddSet(cache.Key(cache.ItemCategories, "12"), "*")
assert.NoError(t, err)
err = s.CacheClient.AddSet(cache.Key(cache.ItemCategories, "48"), "*")
assert.NoError(t, err)
// test fallback
s.GorseConfig.Recommend.FallbackRecommend = []string{"user_based"}
apitest.New().
Expand Down
2 changes: 1 addition & 1 deletion storage/cache/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ const (

// ItemCategories is the set of item categories. The format of key:
// Global item categories - item_categories
// Categories of an item - item_categories/{item_id}
ItemCategories = "item_categories"

LastModifyItemTime = "last_modify_item_time" // the latest timestamp that a user related data was modified
Expand Down Expand Up @@ -189,6 +188,7 @@ type Database interface {
GetSortedScore(key, member string) (float64, error)
GetSorted(key string, begin, end int) ([]Scored, error)
GetSortedByScore(key string, begin, end float64) ([]Scored, error)
RemSortedByScore(key string, begin, end float64) error
AddSorted(key string, scores []Scored) error
SetSorted(key string, scores []Scored) error
IncrSorted(key, member string) error
Expand Down
12 changes: 12 additions & 0 deletions storage/cache/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package cache

import (
"github.com/juju/errors"
"math"
"testing"
"time"

Expand Down Expand Up @@ -130,6 +131,17 @@ func testSort(t *testing.T, db Database) {
{"2", 1.2},
{"3", 1.3},
}, partItems)
// remove scores by score
err = db.AddSorted("sort", []Scored{
{"5", -5},
{"6", -6},
})
assert.NoError(t, err)
err = db.RemSortedByScore("sort", math.Inf(-1), -1)
assert.NoError(t, err)
partItems, err = db.GetSortedByScore("sort", math.Inf(-1), -1)
assert.NoError(t, err)
assert.Empty(t, partItems)
// Increase score
err = db.IncrSorted("sort", "0")
assert.NoError(t, err)
Expand Down
5 changes: 5 additions & 0 deletions storage/cache/no_database.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ func (NoDatabase) GetSortedByScore(_ string, _, _ float64) ([]Scored, error) {
return nil, ErrNoDatabase
}

// RemSortedByScore always fails with ErrNoDatabase: NoDatabase ignores the
// key and both score bounds and performs no removal.
func (NoDatabase) RemSortedByScore(_ string, _, _ float64) error {
return ErrNoDatabase
}

// AddSorted method of NoDatabase returns ErrNoDatabase.
func (NoDatabase) AddSorted(_ string, _ []Scored) error {
return ErrNoDatabase
Expand Down
2 changes: 2 additions & 0 deletions storage/cache/no_database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func TestNoDatabase(t *testing.T) {
assert.ErrorIs(t, err, ErrNoDatabase)
_, err = database.GetSortedByScore("", 0, 0)
assert.ErrorIs(t, err, ErrNoDatabase)
err = database.RemSortedByScore("", 0, 0)
assert.ErrorIs(t, err, ErrNoDatabase)
err = database.SetSorted("", nil)
assert.ErrorIs(t, err, ErrNoDatabase)
err = database.AddSorted("", nil)
Expand Down
8 changes: 8 additions & 0 deletions storage/cache/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,14 @@ func (r *Redis) GetSortedByScore(key string, begin, end float64) ([]Scored, erro
return results, nil
}

// RemSortedByScore removes the members of the sorted set stored at key whose
// scores lie between begin and end. Each bound is rendered with
// strconv.FormatFloat (shortest 'g' representation) and handed to Redis'
// ZREMRANGEBYSCORE, which treats plain numeric bounds as inclusive.
// It returns the error reported by the Redis command, if any.
func (r *Redis) RemSortedByScore(key string, begin, end float64) error {
	var (
		ctx      = context.Background()
		minScore = strconv.FormatFloat(begin, 'g', -1, 64)
		maxScore = strconv.FormatFloat(end, 'g', -1, 64)
	)
	cmd := r.client.ZRemRangeByScore(ctx, key, minScore, maxScore)
	return cmd.Err()
}

// AddSorted adds scores to the sorted set.
func (r *Redis) AddSorted(key string, scores []Scored) error {
if len(scores) == 0 {
Expand Down
Loading

0 comments on commit 987b5c8

Please sign in to comment.