diff --git a/digestwriter/consumer.go b/digestwriter/consumer.go
index 8904ae47..42dac81e 100644
--- a/digestwriter/consumer.go
+++ b/digestwriter/consumer.go
@@ -88,7 +88,7 @@ type DigestConsumer struct {
 // NewConsumer constructs a new instance of Consumer interface
 // specialized in consuming from SHA extractor's result topic
 func NewConsumer(storage Storage) (*utils.KafkaConsumer, error) {
-	setupLogger()
+	SetupLogger()
 	processor := DigestConsumer{
 		storage,
 		0,
diff --git a/digestwriter/digestwriter.go b/digestwriter/digestwriter.go
index 0cdbca9d..edc53e37 100644
--- a/digestwriter/digestwriter.go
+++ b/digestwriter/digestwriter.go
@@ -23,7 +23,7 @@ const (
 	ExitStatusConsumerError
 )
 
-func setupLogger() {
+func SetupLogger() {
 	if logger == nil {
 		var err error
 		logger, err = utils.CreateLogger(utils.Cfg.LoggingLevel)
@@ -49,7 +49,7 @@ func startConsumer(storage Storage) (*utils.KafkaConsumer, error) {
 
 // Start function tries to start the digest writer service.
 func Start() {
-	setupLogger()
+	SetupLogger()
 	logger.Infoln("Initializing digest writer...")
 
 	RunMetrics()
diff --git a/digestwriter/export_test.go b/digestwriter/export_test.go
index 8c771a1c..1a09ac66 100644
--- a/digestwriter/export_test.go
+++ b/digestwriter/export_test.go
@@ -16,7 +16,6 @@ var (
 	// functions from consumer.go source file
 	ExtractDigestsFromMessage = extractDigestsFromMessage
 	ParseMessage              = parseMessage
-	SetupLogger               = setupLogger
 )
 
 // kafka-related functions
diff --git a/digestwriter/storage.go b/digestwriter/storage.go
index cabbfa11..9322eab7 100644
--- a/digestwriter/storage.go
+++ b/digestwriter/storage.go
@@ -85,7 +85,7 @@ func prepareClusterImageLists(clusterID int64, currentImageIDs map[int64]struct{
 }
 
 // updateClusterCache updates the cache section of cluster row in db
-func (storage *DBStorage) updateClusterCache(tx *gorm.DB, clusterID int64, existingDigests []models.Image) error {
+func (storage *DBStorage) UpdateClusterCache(tx *gorm.DB, clusterID int64, existingDigests []models.Image) error {
 	digestIDs := make([]int64, 0, len(existingDigests))
 	for _, digest := range existingDigests {
 		digestIDs = append(digestIDs, digest.ID)
@@ -177,13 +177,15 @@ func (storage *DBStorage) linkDigestsToCluster(tx *gorm.DB, clusterStr string, c
 		}
 	}
 
-	err := storage.updateClusterCache(tx, clusterID, existingDigests)
-	if err != nil {
-		logger.WithFields(logrus.Fields{
-			errorKey:     err.Error(),
-			clusterIDKey: clusterID,
-		}).Errorln("couldn't update cluster cve cache")
-		return err
+	if len(toInsert) > 0 || len(toDelete) > 0 {
+		err := storage.UpdateClusterCache(tx, clusterID, existingDigests)
+		if err != nil {
+			logger.WithFields(logrus.Fields{
+				errorKey:     err.Error(),
+				clusterIDKey: clusterID,
+			}).Errorln("couldn't update cluster cve cache")
+			return err
+		}
 	}
 
 	logger.Debugln("linked digests to cluster successfully")
diff --git a/test/setup.go b/test/setup.go
index 4e431de1..66f2ecc7 100644
--- a/test/setup.go
+++ b/test/setup.go
@@ -1,6 +1,8 @@
 package test
 
 import (
+	"app/base/models"
+	"app/digestwriter"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -69,6 +71,38 @@ func ReverseWalkFindFile(filename string) (string, error) {
 	return datapath, nil
 }
 
+func PopulateClusterCveCache(DB *gorm.DB) error {
+	digestwriter.SetupLogger()
+	storage, err := digestwriter.NewStorage()
+	if err != nil {
+		return err
+	}
+
+	clusters := []models.Cluster{}
+	if res := DB.Find(&clusters); res.Error != nil {
+		return res.Error
+	}
+
+	for _, cluster := range clusters {
+		clusterDigests := []models.Image{}
+		subq := DB.Table("cluster_image").
+			Joins("JOIN image ON cluster_image.image_id = image.id").
+			Where("cluster_image.cluster_id = ?", cluster.ID)
+		if res := DB.Joins("JOIN (?) AS cluster_image ON image.id = cluster_image.image_id", subq).
+			Find(&clusterDigests); res.Error != nil {
+			return res.Error
+		}
+		if err := storage.UpdateClusterCache(DB, cluster.ID, clusterDigests); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func PopulateCaches(DB *gorm.DB) error {
+	return PopulateClusterCveCache(DB)
+}
+
 func ResetDB() error {
 	if testingDataPath == "" {
 		var err error
@@ -94,5 +128,8 @@ func ResetDB() error {
 		return err
 	}
 	_, err = plainDb.Exec(string(buf))
-	return err
+	if err != nil {
+		return err
+	}
+	return PopulateCaches(DB)
 }