diff --git a/go.mod b/go.mod
index adaf1d73..6595f4e1 100644
--- a/go.mod
+++ b/go.mod
@@ -17,7 +17,6 @@ require (
 	github.com/hashicorp/go-multierror v1.1.1
 	github.com/hashicorp/go-plugin v1.5.2 // indirect
 	github.com/hashicorp/go-version v1.6.0
-	github.com/joncrlsn/dque v0.0.0-20211108142734-c2ef48c5192a
 	github.com/jstemmer/go-junit-report/v2 v2.1.0
 	github.com/mattn/go-isatty v0.0.20
 	github.com/mitchellh/mapstructure v1.5.0
@@ -322,7 +321,7 @@ require (
 	github.com/sivchari/tenv v1.7.1 // indirect
 	github.com/sonatard/noctx v0.0.2 // indirect
 	github.com/sourcegraph/go-diff v0.7.0 // indirect
-	github.com/spf13/afero v1.10.0 // indirect
+	github.com/spf13/afero v1.10.0
 	github.com/spf13/cast v1.5.1 // indirect
 	github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect
 	github.com/stbenjam/no-sprintf-host-port v0.1.1 // indirect
diff --git a/go.sum b/go.sum
index c8784b56..87c47773 100644
--- a/go.sum
+++ b/go.sum
@@ -351,7 +351,6 @@ github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJA
 github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2 h1:ZpnhV/YsD2/4cESfV5+Hoeu/iUR3ruzNvZ+yQfO03a0=
 github.com/godbus/dbus v0.0.0-20190726142602-4481cbc300e2/go.mod h1:bBOAhwG1umN6/6ZUMtDFBMQR8jRg9O75tm9K00oMsK4=
 github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA=
-github.com/gofrs/flock v0.7.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
 github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw=
 github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU=
 github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA=
@@ -548,8 +547,6 @@ github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9Y
 github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
 github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
 github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
-github.com/joncrlsn/dque v0.0.0-20211108142734-c2ef48c5192a h1:sfe532Ipn7GX0V6mHdynBk393rDmqgI0QmjLK7ct7TU=
-github.com/joncrlsn/dque v0.0.0-20211108142734-c2ef48c5192a/go.mod h1:dNKs71rs2VJGBAmttu7fouEsRQlRjxy0p1Sx+T5wbpY=
 github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
 github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
 github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@@ -589,7 +586,6 @@ github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
 github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
 github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
diff --git a/policy/scan/disk_queue.go b/policy/scan/disk_queue.go
index e926ff7a..0adf83ae 100644
--- a/policy/scan/disk_queue.go
+++ b/policy/scan/disk_queue.go
@@ -9,23 +9,23 @@ import (
 	"sync"
 	"time"
 
-	"github.com/joncrlsn/dque"
 	"github.com/rs/zerolog/log"
+	"go.mondoo.com/cnspec/v9/policy/scan/pdque"
 	"google.golang.org/protobuf/proto"
 )
 
 type diskQueueConfig struct {
-	dir         string
-	filename    string
-	segmentSize int
-	sync        bool
+	dir      string
+	filename string
+	maxSize  int
+	sync     bool
 }
 
 var defaultDqueConfig = diskQueueConfig{
-	dir:         "/tmp/cnspec-queue", // TODO: consider configurable path
-	filename:    "disk-queue",
-	segmentSize: 500,
-	sync:        false,
+	dir:      "/tmp/cnspec-queue", // TODO: consider configurable path
+	filename: "disk-queue",
+	maxSize:  500,
+	sync:     false,
 }
 
 // queueMsg is the being stored in disk queue
@@ -40,7 +40,7 @@ type queuePayload struct {
 }
 
 type diskQueueClient struct {
-	queue   *dque.DQue
+	queue   *pdque.Queue
 	once    sync.Once
 	wg      sync.WaitGroup
 	entries chan Job
@@ -68,15 +68,11 @@ func newDqueClient(config diskQueueConfig, handler func(job *Job)) (*diskQueueCl
 		return nil, fmt.Errorf("cannot create queue directory: %s", err)
 	}
 
-	q.queue, err = dque.NewOrOpen(config.filename, config.dir, config.segmentSize, diskQueueEntryBuilder)
+	q.queue, err = pdque.NewOrOpen(config.filename, config.dir, config.maxSize, diskQueueEntryBuilder)
 	if err != nil {
 		return nil, err
 	}
 
-	if !config.sync {
-		_ = q.queue.TurboOn()
-	}
-
 	q.entries = make(chan Job)
 	q.wg.Add(2)
 
@@ -127,7 +123,7 @@ func (c *diskQueueClient) popper() {
 		entry, err := c.queue.DequeueBlock()
 		if err != nil {
 			switch err {
-			case dque.ErrQueueClosed:
+			case pdque.ErrQueueClosed:
 				return
 			default:
 				log.Error().Err(err).Msg("could not pop job from disk queue")
diff --git a/policy/scan/disk_queue_test.go b/policy/scan/disk_queue_test.go
new file mode 100644
index 00000000..088ebf67
--- /dev/null
+++ b/policy/scan/disk_queue_test.go
@@ -0,0 +1,71 @@
+// Copyright (c) Mondoo, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package scan
+
+import (
+	"os"
+	"testing"
+
+	"go.mondoo.com/cnquery/v9/providers-sdk/v1/inventory"
+)
+
+func TestDiskQueueClient_EnqueueDequeue(t *testing.T) {
+	tempDir, err := os.MkdirTemp("", "testdir")
+	if err != nil {
+		t.Fatalf("Failed to create temp dir: %v", err)
+	}
+	defer os.RemoveAll(tempDir) // Clean up
+
+	// Update the configuration to use the temporary directory
+	testConfig := defaultDqueConfig
+	testConfig.dir = tempDir
+
+	completionChannel := make(chan struct{}, 50) // Channel to signal job completion
+
+	handler := func(job *Job) {
+		completionChannel <- struct{}{} // Signal completion
+	}
+
+	client, err := newDqueClient(testConfig, handler)
+	if err != nil {
+		t.Fatalf("Failed to create diskQueueClient: %v", err)
+	}
+	defer client.Stop()
+
+	// Test Enqueue
+	testJob := &Job{
+		Inventory: &inventory.Inventory{
+			Spec: &inventory.InventorySpec{
+				Assets: []*inventory.Asset{
+					{
+						Connections: []*inventory.Config{
+							{
+								Type: "k8s",
+								Options: map[string]string{
+									"path": "./testdata/2pods.yaml",
+								},
+								Discover: &inventory.Discovery{
+									Targets: []string{"auto"},
+								},
+							},
+						},
+						ManagedBy: "mondoo-operator-123",
+					},
+				},
+			},
+		},
+	}
+
+	for i := 0; i < 50; i++ {
+		client.Channel() <- *testJob
+	}
+
+	for i := 0; i < 50; i++ {
+		<-completionChannel
+	}
+
+	// Verify that all jobs have been processed
+	if len(completionChannel) != 0 {
+		t.Errorf("Expected handler to be called 50 times, but was called %d times", 50+len(completionChannel))
+	}
+}
diff --git a/policy/scan/pdque/queue.go b/policy/scan/pdque/queue.go
new file mode 100644
index 00000000..94ef0afc
--- /dev/null
+++ b/policy/scan/pdque/queue.go
@@ -0,0 +1,359 @@
+// Copyright (c) Mondoo, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package pdque
+
+import (
+	"bytes"
+	"encoding/gob"
+	"errors"
+	"fmt"
+	"os"
+	"path/filepath"
+	"sort"
+	"strconv"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/spf13/afero"
+)
+
+const (
+	jobFileExt     = ".job"
+	tempFileExt    = ".tmp"
+	tempFilePrefix = "."
+)
+
+var ErrQueueClosed = errors.New("queue is closed")
+var ErrNoJobs = errors.New("no jobs in queue")
+var ErrQueueFull = errors.New("queue is full")
+
+// ErrUnableToDecode is returned when an object cannot be decoded.
+type ErrUnableToDecode struct {
+	Path string
+	Err  error
+}
+
+func (e ErrUnableToDecode) Error() string {
+	return fmt.Sprintf("object in file %s cannot be decoded: %s", e.Path, e.Err)
+}
+
+type Queue struct {
+	Name    string
+	path    string
+	mu      sync.Mutex
+	closed  bool
+	cond    *sync.Cond
+	builder func() interface{}
+	maxSize int
+}
+
+func New(name string, path string, maxSize int, builder func() interface{}) (*Queue, error) {
+	err := os.MkdirAll(path, 0o755)
+	if err != nil {
+		return nil, err
+	}
+	overlyPermissive, err := isOverlyPermissive(path)
+	if err != nil {
+		return nil, err
+	}
+	if overlyPermissive {
+		return nil, errors.New("path is overly permissive, make sure it is not writable to others or the group: " + path)
+	}
+
+	que := &Queue{
+		Name:    name,
+		path:    path,
+		builder: builder,
+		maxSize: maxSize,
+	}
+	que.cond = sync.NewCond(&que.mu)
+
+	return que, nil
+}
+
+func NewOrOpen(name string, path string, maxSize int, builder func() interface{}) (*Queue, error) {
+	var que *Queue
+	_, err := os.Stat(path)
+	if os.IsNotExist(err) {
+		que, err = New(name, path, maxSize, builder)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		que, err = Open(name, path, maxSize, builder)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	return que, nil
+}
+
+func Open(name string, path string, maxSize int, builder func() interface{}) (*Queue, error) {
+	overlyPermissive, err := isOverlyPermissive(path)
+	if err != nil {
+		return nil, err
+	}
+	if overlyPermissive {
+		return nil, errors.New("path is overly permissive, make sure it is not writable to others or the group: " + path)
+	}
+
+	que := &Queue{
+		Name:    name,
+		path:    path,
+		builder: builder,
+		maxSize: maxSize,
+	}
+	que.cond = sync.NewCond(&que.mu)
+
+	return que, nil
+}
+
+// Close safely shuts down the queue, ensuring all resources are released.
+func (q *Queue) Close() error {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	// If the queue is already closed, return ErrQueueClosed.
+	if q.closed {
+		return ErrQueueClosed
+	}
+
+	// Clean up temporary files.
+	files, err := os.ReadDir(q.path)
+	if err != nil {
+		return err
+	}
+	for _, file := range files {
+		if strings.HasPrefix(file.Name(), tempFilePrefix) {
+			err := os.Remove(filepath.Join(q.path, file.Name()))
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	// Set the queue as closed to prevent further operations.
+	q.closed = true
+
+	// Wake up all goroutines waiting on the condition variable so they can observe the closed state and return.
+	q.cond.Broadcast()
+
+	return nil
+}
+
+func (q *Queue) Enqueue(obj interface{}) error {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.closed {
+		return ErrQueueClosed
+	}
+
+	// Check if the queue size has reached the maxSize
+	if q.maxSize > 0 {
+		size, err := q.currentSize()
+		if err != nil {
+			return err
+		}
+		if size >= q.maxSize {
+			return ErrQueueFull
+		}
+	}
+
+	// Find the next available filename
+	filename, err := q.nextAvailableFilename()
+	if err != nil {
+		return err
+	}
+
+	tempPath := filepath.Join(q.path, tempFilePrefix+filename+tempFileExt)
+	finalPath := filepath.Join(q.path, filename+jobFileExt)
+
+	// Encode the struct to a byte buffer
+	var buff bytes.Buffer
+	enc := gob.NewEncoder(&buff)
+	if err := enc.Encode(obj); err != nil {
+		return err
+	}
+
+	// Write to a temporary file
+	err = os.WriteFile(tempPath, buff.Bytes(), 0o644)
+	if err != nil {
+		return err
+	}
+
+	// Rename the temporary file to its final name
+	err = os.Rename(tempPath, finalPath)
+	if err != nil {
+		return err
+	}
+
+	// After successfully enqueueing a job, wake up any goroutines waiting for a new entry.
+	q.cond.Broadcast()
+
+	return nil
+}
+
+func (q *Queue) Dequeue() (interface{}, error) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.closed {
+		return nil, ErrQueueClosed
+	}
+
+	files, err := os.ReadDir(q.path)
+	if err != nil {
+		return nil, err
+	}
+
+	// Sort job files by name (which is the timestamp)
+	sort.Slice(files, func(i, j int) bool {
+		return files[i].Name() < files[j].Name()
+	})
+
+	for _, file := range files {
+		if filepath.Ext(file.Name()) == jobFileExt {
+			jobPath := filepath.Join(q.path, file.Name())
+
+			data, err := os.ReadFile(jobPath)
+			if err != nil {
+				return nil, err
+			}
+
+			// Decode the bytes into an object
+			obj := q.builder()
+			if err := gob.NewDecoder(bytes.NewReader(data)).Decode(obj); err != nil {
+				return nil, ErrUnableToDecode{
+					Path: q.path,
+					Err:  err,
+				}
+			}
+
+			// Remove the job file
+			if err := os.Remove(jobPath); err != nil {
+				return nil, err
+			}
+
+			return obj, nil
+		}
+	}
+
+	return nil, ErrNoJobs
+}
+
+func (q *Queue) DequeueBlock() (interface{}, error) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	for {
+		if q.closed {
+			return nil, ErrQueueClosed
+		}
+
+		obj, err := q.dequeueJob()
+		if err != nil {
+			if errors.Is(err, ErrNoJobs) {
+				// No jobs in queue, wait for a new job to be enqueued
+				q.cond.Wait()
+			} else {
+				return nil, err
+			}
+		} else {
+			return obj, nil
+		}
+	}
+}
+
+// dequeueJob tries to dequeue a job from the queue without waiting.
+// It returns an ErrNoJobs error if there are no jobs to dequeue.
+func (q *Queue) dequeueJob() (interface{}, error) {
+	files, err := os.ReadDir(q.path)
+	if err != nil {
+		return nil, err
+	}
+
+	// Sort job files by name (which is the timestamp)
+	sort.Slice(files, func(i, j int) bool {
+		return files[i].Name() < files[j].Name()
+	})
+
+	for _, file := range files {
+		if filepath.Ext(file.Name()) == jobFileExt {
+			jobPath := filepath.Join(q.path, file.Name())
+
+			data, err := os.ReadFile(jobPath)
+			if err != nil {
+				return nil, err
+			}
+
+			// Decode the bytes into an object
+			obj := q.builder()
+			if err := gob.NewDecoder(bytes.NewReader(data)).Decode(obj); err != nil {
+				return nil, ErrUnableToDecode{
+					Path: q.path,
+					Err:  err,
+				}
+			}
+
+			// Remove the job file
+			if err := os.Remove(jobPath); err != nil {
+				return nil, err
+			}
+
+			return obj, nil
+		}
+	}
+
+	return nil, ErrNoJobs // Custom error to indicate no jobs are available
+}
+
+func (q *Queue) nextAvailableFilename() (string, error) {
+	timestamp := time.Now().UnixNano()
+	filename := strconv.FormatInt(timestamp, 10)
+	for {
+		_, err := os.Stat(filepath.Join(q.path, filename+jobFileExt))
+		if os.IsNotExist(err) {
+			break
+		} else if err != nil {
+			return "", err
+		} else {
+			timestamp++
+			filename = strconv.FormatInt(timestamp, 10)
+		}
+	}
+	return filename, nil
+}
+
+// currentSize returns the number of job files currently in the queue.
+// It scans the queue directory on every call, which could be made more performant.
+func (q *Queue) currentSize() (int, error) {
+	files, err := os.ReadDir(q.path)
+	if err != nil {
+		return 0, err
+	}
+
+	count := 0
+	for _, file := range files {
+		if filepath.Ext(file.Name()) == jobFileExt {
+			count++
+		}
+	}
+	return count, nil
+}
+
+func isOverlyPermissive(path string) (bool, error) {
+	fs := afero.NewOsFs()
+	stat, err := fs.Stat(path)
+	if err != nil {
+		return true, errors.New("failed to analyze " + path)
+	}
+	mode := stat.Mode()
+	if mode&0o022 != 0 {
+		return true, nil
+	}
+	return false, nil
+}
diff --git a/policy/scan/pdque/queue_test.go b/policy/scan/pdque/queue_test.go
new file mode 100644
index 00000000..ca23f6a8
--- /dev/null
+++ b/policy/scan/pdque/queue_test.go
@@ -0,0 +1,357 @@
+// Copyright (c) Mondoo, Inc.
+// SPDX-License-Identifier: BUSL-1.1
+
+package pdque
+
+import (
+	"errors"
+	"fmt"
+	"io/fs"
+	"os"
+	"path/filepath"
+	"strconv"
+	"strings"
+	"sync"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestNextAvailableFilename(t *testing.T) {
+	// Create a temporary directory for testing
+	testDir, err := os.MkdirTemp("", "diskqueue_test")
+	require.NoError(t, err)
+	defer os.RemoveAll(testDir) // Clean up
+
+	// Initialize a new Queue
+	q := &Queue{Name: "testQueue", path: testDir}
+
+	var timestamps sync.Map
+	var wg sync.WaitGroup
+	var mu sync.Mutex // Mutex to protect map operations
+
+	// Use a channel to collect errors from goroutines
+	errChan := make(chan error, 1010) // Buffer should be the number of goroutines
+
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			filename, err := q.nextAvailableFilename()
+			if err != nil {
+				errChan <- fmt.Errorf("failed to generate filename: %s", err)
+				return
+			}
+			mu.Lock()
+			if _, exists := timestamps.Load(filename); exists {
+				errChan <- fmt.Errorf("duplicate filename generated: %s", filename)
+			} else {
+				timestamps.Store(filename, struct{}{})
+			}
+			mu.Unlock()
+
+			// Create a file to simulate an existing job
+			filePath := filepath.Join(testDir, filename+jobFileExt)
+			if err := os.WriteFile(filePath, []byte("test"), 0o644); err != nil {
+				errChan <- fmt.Errorf("failed to write test file: %s", err)
+			}
+		}()
+	}
+
+	// Test that filenames do not collide in a tight loop
+	for i := 0; i < 1000; i++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			filename, err := q.nextAvailableFilename()
+			require.NoError(t, err)
+			timestamps.Store(filename, struct{}{})
+		}()
+	}
+
+	// Test if the function properly handles existing files with incremented names
+	filename, err := q.nextAvailableFilename()
+	require.NoError(t, err)
+
+	// Manually create files that would conflict to ensure our function increments properly
+	baseFilename := strings.TrimSuffix(filename, jobFileExt)
+	for i := 0; i < 3; i++ {
+		conflictFilename := baseFilename + "_" + strconv.Itoa(i) + jobFileExt
+		filePath := filepath.Join(testDir, conflictFilename)
+		err := os.WriteFile(filePath, []byte("test"), 0o644)
+		require.NoError(t, err)
+	}
+
+	// Next filename should be incremented
+	nextFilename, err := q.nextAvailableFilename()
+	require.NoError(t, err)
+
+	if nextFilename == filename {
+		t.Errorf("Expected next filename to be different, got: %s", nextFilename)
+	}
+
+	// Wait for all goroutines to finish
+	wg.Wait()
+	close(errChan)
+
+	// Check for any errors sent by goroutines
+	for err := range errChan {
+		t.Error(err)
+	}
+}
+
+func TestEnqueue(t *testing.T) {
+	// Setup: create a temporary directory to act as the queue directory.
+	testDir, err := os.MkdirTemp("", "test_queue")
+	require.NoError(t, err)
+	defer os.RemoveAll(testDir)
+
+	// Instantiate the Queue.
+	queue, err := New("testQueue", testDir, 1000, func() interface{} {
+		return nil
+	})
+	require.NoError(t, err)
+
+	testJob := struct {
+		Data string
+	}{
+		Data: "test data",
+	}
+
+	// Enqueue the job.
+	err = queue.Enqueue([]byte(fmt.Sprintf("%v", testJob)))
+	require.NoError(t, err)
+
+	// Verify that a file has been created in the queue directory.
+	files, err := os.ReadDir(testDir)
+	require.NoError(t, err)
+
+	if len(files) != 1 {
+		t.Fatalf("Expected 1 file in queue directory, found %d", len(files))
+	}
+}
+
+func TestClose(t *testing.T) {
+	testDir, err := os.MkdirTemp("", "test_queue")
+	require.NoError(t, err)
+	defer os.RemoveAll(testDir)
+
+	// Create some temporary files to simulate the state of the queue with pending jobs.
+	tempFiles := []string{".tmp1", ".tmp2", ".tmp3"}
+	for _, f := range tempFiles {
+		tmpFilePath := filepath.Join(testDir, f)
+		err := os.WriteFile(tmpFilePath, []byte("data"), 0o644)
+		require.NoError(t, err)
+	}
+
+	// Instantiate the Queue.
+	queue, err := New("testQueue", testDir, 1000, func() interface{} {
+		return nil
+	})
+	require.NoError(t, err)
+
+	// Close the queue.
+	err = queue.Close()
+	require.NoError(t, err)
+
+	// Verify that the queue is marked as closed.
+	if !queue.closed {
+		t.Errorf("Queue should be marked as closed.")
+	}
+
+	// Verify that temporary files are cleaned up.
+	files, err := os.ReadDir(testDir)
+	require.NoError(t, err)
+
+	for _, file := range files {
+		if strings.HasPrefix(file.Name(), ".") {
+			t.Errorf("Temporary file %s was not cleaned up", file.Name())
+		}
+	}
+
+	// Verify that no new actions can be performed on the queue.
+	err = queue.Enqueue([]byte("data"))
+	require.Error(t, err)
+
+	_, err = queue.Dequeue()
+	require.Error(t, err)
+}
+
+type testObj struct {
+	Name string
+	ID   int
+}
+
+func TestEnqueueDequeue(t *testing.T) {
+	testDir, err := os.MkdirTemp("", "test_enqueue_dequeue")
+	require.NoError(t, err)
+	defer os.RemoveAll(testDir)
+	// Setup: Initialize the queue
+	queue, err := NewOrOpen("testQueue", testDir, 10, func() interface{} {
+		return new(testObj)
+	})
+	require.NoError(t, err)
+	defer queue.Close()
+
+	// Test enqueue
+	testObj := &testObj{Name: "test"}
+	err = queue.Enqueue(testObj)
+	require.NoError(t, err)
+
+	// Test dequeue
+	dequeuedObj, err := queue.Dequeue()
+	require.NoError(t, err)
+
+	assert.Equal(t, testObj, dequeuedObj)
+}
+
+func TestQueueMaxSize(t *testing.T) {
+	testDir, err := os.MkdirTemp("", "test_maxSize")
+	require.NoError(t, err)
+	defer os.RemoveAll(testDir)
+	maxSize := 5
+	queue, err := NewOrOpen("testQueue", testDir, maxSize, func() interface{} {
+		return new(testObj)
+	})
+	require.NoError(t, err)
+	defer queue.Close()
+
+	// Enqueue items up to the maximum size
+	for i := 0; i < maxSize; i++ {
+		err := queue.Enqueue(&testObj{Name: fmt.Sprintf("test%d", i)})
+		require.NoError(t, err)
+	}
+
+	// Attempt to enqueue one more item, which should fail
+	err = queue.Enqueue(&testObj{Name: "overflow"})
+	require.Error(t, err)
+
+	if !errors.Is(err, ErrQueueFull) {
+		t.Errorf("Expected ErrQueueFull, but got %v", err)
+	}
+
+	// Dequeue an item
+	_, err = queue.Dequeue()
+	require.NoError(t, err)
+
+	// Attempt to enqueue again, which should now succeed
+	err = queue.Enqueue(&testObj{Name: "shouldSucceed"})
+	require.NoError(t, err)
+}
+
+// TestEnqueueDequeueMore tests enqueuing and dequeuing of a larger number of jobs
+func TestEnqueueDequeueMore(t *testing.T) {
+	testDir, err := os.MkdirTemp("", "test_enqueue_dequeue")
+	require.NoError(t, err)
+	defer os.RemoveAll(testDir)
+
+	// Create a new queue
+	q, err := NewOrOpen("testQueue", testDir, 1000, func() interface{} { return new(testObj) })
+	require.NoError(t, err)
+	defer q.Close()
+
+	// Enqueue 1000 jobs
+	for i := 0; i < 1000; i++ {
+		err := q.Enqueue(&testObj{ID: i})
+		require.NoError(t, err)
+	}
+
+	// Verify there are 1000 job files
+	jobCount, err := countJobFiles(testDir)
+	require.NoError(t, err)
+
+	require.Equal(t, 1000, jobCount)
+
+	// Dequeue and check each job
+	for i := 0; i < 1000; i++ {
+		obj, err := q.Dequeue()
+		require.NoError(t, err)
+
+		job, ok := obj.(*testObj)
+		if !ok {
+			t.Fatalf("Dequeued object is not of type *testObj")
+		}
+
+		assert.Equal(t, i, job.ID)
+	}
+
+	if obj, _ := q.Dequeue(); obj != nil {
+		t.Errorf("Expected queue to be empty, but got a job")
+	}
+}
+
+// countJobFiles counts the number of job files in the given directory
+func countJobFiles(dir string) (int, error) {
+	var count int
+	err := filepath.Walk(dir, func(path string, info fs.FileInfo, err error) error {
+		if err != nil {
+			return err
+		}
+		if filepath.Ext(info.Name()) == jobFileExt {
+			count++
+		}
+		return nil
+	})
+	return count, err
+}
+
+// TestConcurrentEnqueueDequeue tests concurrent enqueuing and dequeuing of jobs
+func TestConcurrentEnqueueDequeue(t *testing.T) {
+	testDir, err := os.MkdirTemp("", "test_enqueue_dequeue")
+	require.NoError(t, err)
+	defer os.RemoveAll(testDir)
+
+	const numJobs = 200
+
+	// Create a new queue
+	q, err := NewOrOpen("testConcurrentQueue", testDir, numJobs, func() interface{} { return &testObj{} })
+	require.NoError(t, err)
+	defer q.Close()
+
+	var wg sync.WaitGroup
+
+	// Concurrently enqueue jobs
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		for i := 0; i < (numJobs / 2); i++ {
+			err := q.Enqueue(&testObj{ID: i})
+			require.NoError(t, err)
+		}
+	}()
+	go func() {
+		defer wg.Done()
+		for i := 100; i < numJobs; i++ {
+			err := q.Enqueue(&testObj{ID: i})
+			require.NoError(t, err)
+		}
+	}()
+
+	// Concurrently dequeue jobs
+	dequeuedJobs := make(map[int]bool)
+	var mu sync.Mutex
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for i := 0; i < numJobs; i++ {
+			obj, err := q.DequeueBlock()
+			require.NoError(t, err)
+
+			job, ok := obj.(*testObj)
+			if !ok {
+				t.Errorf("Dequeued object is not of type *testObj")
+				continue
+			}
+
+			mu.Lock()
+			dequeuedJobs[job.ID] = true
+			mu.Unlock()
+		}
+	}()
+
+	wg.Wait()
+
+	// Verify all jobs were dequeued
+	assert.Equal(t, numJobs, len(dequeuedJobs))
+}
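
Usage note (not part of the diff): the sketch below shows how the pdque API introduced above fits together (NewOrOpen, Enqueue, DequeueBlock, Close). The examplePayload type, queue name, and the /tmp path are illustrative assumptions rather than code from this change.

// example_usage.go (hypothetical, for illustration only)
package main

import (
	"fmt"

	"go.mondoo.com/cnspec/v9/policy/scan/pdque"
)

// examplePayload is an assumed payload type; gob needs exported fields.
type examplePayload struct {
	ID   int
	Name string
}

func main() {
	// Open (or create) a queue directory that holds at most 100 pending jobs.
	q, err := pdque.NewOrOpen("example-queue", "/tmp/pdque-example", 100, func() interface{} {
		return new(examplePayload)
	})
	if err != nil {
		panic(err)
	}
	defer q.Close()

	// Enqueue gob-encodes the payload into a temp file and renames it to a .job file.
	if err := q.Enqueue(&examplePayload{ID: 1, Name: "scan"}); err != nil {
		panic(err)
	}

	// DequeueBlock waits until a job file exists (or the queue is closed),
	// decodes the oldest one via the builder, and removes it from disk.
	obj, err := q.DequeueBlock()
	if err != nil {
		panic(err)
	}
	fmt.Printf("dequeued: %+v\n", obj.(*examplePayload))
}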