diff --git a/changelog/27927.txt b/changelog/27927.txt new file mode 100644 index 000000000000..afc37a7acbd3 --- /dev/null +++ b/changelog/27927.txt @@ -0,0 +1,6 @@ +```release-note:improvement +storage/s3: Pass context to AWS SDK calls +``` +```release-note:improvement +storage/dynamodb: Pass context to AWS SDK calls +``` diff --git a/physical/dynamodb/dynamodb.go b/physical/dynamodb/dynamodb.go index c4484d20d446..bc27def0c987 100644 --- a/physical/dynamodb/dynamodb.go +++ b/physical/dynamodb/dynamodb.go @@ -294,7 +294,7 @@ func (d *DynamoDBBackend) Put(ctx context.Context, entry *physical.Entry) error }) } - return d.batchWriteRequests(requests) + return d.batchWriteRequests(ctx, requests) } // Get is used to fetch an entry @@ -304,7 +304,7 @@ func (d *DynamoDBBackend) Get(ctx context.Context, key string) (*physical.Entry, d.permitPool.Acquire() defer d.permitPool.Release() - resp, err := d.client.GetItem(&dynamodb.GetItemInput{ + resp, err := d.client.GetItemWithContext(ctx, &dynamodb.GetItemInput{ TableName: aws.String(d.table), ConsistentRead: aws.Bool(true), Key: map[string]*dynamodb.AttributeValue{ @@ -363,7 +363,7 @@ func (d *DynamoDBBackend) Delete(ctx context.Context, key string) error { excluded = append(excluded, recordKeyForVaultKey(prefixes[index-1])) } - hasChildren, err := d.hasChildren(prefix, excluded) + hasChildren, err := d.hasChildren(ctx, prefix, excluded) if err != nil { return err } @@ -387,7 +387,7 @@ func (d *DynamoDBBackend) Delete(ctx context.Context, key string) error { } } - return d.batchWriteRequests(requests) + return d.batchWriteRequests(ctx, requests) } // List is used to list all the keys under a given @@ -420,7 +420,7 @@ func (d *DynamoDBBackend) List(ctx context.Context, prefix string) ([]string, er d.permitPool.Acquire() defer d.permitPool.Release() - err := d.client.QueryPages(queryInput, func(out *dynamodb.QueryOutput, lastPage bool) bool { + err := d.client.QueryPagesWithContext(ctx, queryInput, func(out *dynamodb.QueryOutput, lastPage bool) bool { var record DynamoDBRecord for _, item := range out.Items { dynamodbattribute.UnmarshalMap(item, &record) @@ -443,7 +443,7 @@ func (d *DynamoDBBackend) List(ctx context.Context, prefix string) ([]string, er // before any deletes take place. To account for that hasChildren accepts a slice of // strings representing values we expect to find that should NOT be counted as children // because they are going to be deleted. -func (d *DynamoDBBackend) hasChildren(prefix string, exclude []string) (bool, error) { +func (d *DynamoDBBackend) hasChildren(ctx context.Context, prefix string, exclude []string) (bool, error) { prefix = strings.TrimSuffix(prefix, "/") prefix = escapeEmptyPath(prefix) @@ -473,7 +473,7 @@ func (d *DynamoDBBackend) hasChildren(prefix string, exclude []string) (bool, er d.permitPool.Acquire() defer d.permitPool.Release() - out, err := d.client.Query(queryInput) + out, err := d.client.QueryWithContext(ctx, queryInput) if err != nil { return false, err } @@ -519,7 +519,7 @@ func (d *DynamoDBBackend) HAEnabled() bool { // batchWriteRequests takes a list of write requests and executes them in badges // with a maximum size of 25 (which is the limit of BatchWriteItem requests). -func (d *DynamoDBBackend) batchWriteRequests(requests []*dynamodb.WriteRequest) error { +func (d *DynamoDBBackend) batchWriteRequests(ctx context.Context, requests []*dynamodb.WriteRequest) error { for len(requests) > 0 { batchSize := int(math.Min(float64(len(requests)), 25)) batch := map[string][]*dynamodb.WriteRequest{d.table: requests[:batchSize]} @@ -534,7 +534,7 @@ func (d *DynamoDBBackend) batchWriteRequests(requests []*dynamodb.WriteRequest) for len(batch) > 0 { var output *dynamodb.BatchWriteItemOutput - output, err = d.client.BatchWriteItem(&dynamodb.BatchWriteItemInput{ + output, err = d.client.BatchWriteItemWithContext(ctx, &dynamodb.BatchWriteItemInput{ RequestItems: batch, }) if err != nil { diff --git a/physical/s3/s3.go b/physical/s3/s3.go index da82acccd3ca..b1687a91622e 100644 --- a/physical/s3/s3.go +++ b/physical/s3/s3.go @@ -183,7 +183,7 @@ func (s *S3Backend) Put(ctx context.Context, entry *physical.Entry) error { putObjectInput.SSEKMSKeyId = aws.String(s.kmsKeyId) } - _, err := s.client.PutObject(putObjectInput) + _, err := s.client.PutObjectWithContext(ctx, putObjectInput) if err != nil { return err } @@ -201,7 +201,7 @@ func (s *S3Backend) Get(ctx context.Context, key string) (*physical.Entry, error // Setup key key = path.Join(s.path, key) - resp, err := s.client.GetObject(&s3.GetObjectInput{ + resp, err := s.client.GetObjectWithContext(ctx, &s3.GetObjectInput{ Bucket: aws.String(s.bucket), Key: aws.String(key), }) @@ -254,7 +254,7 @@ func (s *S3Backend) Delete(ctx context.Context, key string) error { // Setup key key = path.Join(s.path, key) - _, err := s.client.DeleteObject(&s3.DeleteObjectInput{ + _, err := s.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ Bucket: aws.String(s.bucket), Key: aws.String(key), }) @@ -289,7 +289,7 @@ func (s *S3Backend) List(ctx context.Context, prefix string) ([]string, error) { keys := []string{} - err := s.client.ListObjectsV2Pages(params, + err := s.client.ListObjectsV2PagesWithContext(ctx, params, func(page *s3.ListObjectsV2Output, lastPage bool) bool { if page != nil { // Add truncated 'folder' paths