From cb8c8372c00be53f7e872b2a830dcb4525ec3a3c Mon Sep 17 00:00:00 2001 From: Bill Havanki Date: Tue, 14 May 2024 12:25:50 -0400 Subject: [PATCH] chore: Upgrade to aws-sdk-go-v2 (#481) This is a major update to chamber's support for the S3, SSM, and Secrets Manager store implementations. Every effort was made to preserve functionality, but there is one gap. The v2 SDK does not expose a retryer field for a minimum throttle delay, so that argument is currently ignored when constructing new SSM stores. Support for the delay will be addressed later. The v2 SDK does not offer "iface" interfaces for the various clients, so instead interfaces tailored to what chamber uses are defined. For testing, these new interfaces are mocked, and mock types are generated using github.com/matryer/moq. You don't need moq to use chamber or even to build it, but only if you are developing chamber and make a change to an API interface. Also, old code in the SSM store implementation that allowed it to work without IAM permissions for ssm:GetParametersByPath has been eliminated. The permissions have been expected for a long time now. Co-authored-by: Ryan McKern <344926+mckern@users.noreply.github.com> --- .gitattributes | 2 + Makefile | 15 +- go.mod | 21 +- go.sum | 44 +- store/awsapi.go | 47 ++ store/awsapi_mock.go | 997 ++++++++++++++++++++++++++++++ store/s3store.go | 70 +-- store/s3storeKMS.go | 91 ++- store/secretsmanagerstore.go | 61 +- store/secretsmanagerstore_test.go | 185 +++--- store/shared.go | 52 +- store/ssmstore.go | 189 +++--- store/ssmstore_test.go | 301 +++++---- 13 files changed, 1575 insertions(+), 500 deletions(-) create mode 100644 .gitattributes create mode 100644 store/awsapi.go create mode 100644 store/awsapi_mock.go diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..f471523c --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +/go.sum linguist-generated=true +/store/awsapi_mock.go linguist-generated=true diff --git a/Makefile b/Makefile index d58f9542..744a84ac 100644 --- a/Makefile +++ b/Makefile @@ -18,10 +18,21 @@ VERSION_MAJOR_MINOR := $(shell echo "$(VERSION)" | sed 's/^v\([0-9]*.[0-9]*\).*/ VERSION_MAJOR := $(shell echo "$(VERSION)" | sed 's/^v\([0-9]*\).*/\1/') ANALYTICS_WRITE_KEY ?= LDFLAGS := -ldflags='-X "main.Version=$(VERSION)" -X "main.AnalyticsWriteKey=$(ANALYTICS_WRITE_KEY)"' +MOQ := $(shell command -v moq 2> /dev/null) +SRC := $(shell find . -name '*.go') -test: +test: store/awsapi_mock.go go test -v ./... +store/awsapi_mock.go: store/awsapi.go +ifdef MOQ + rm -f $@ + go generate ./... +else + @echo "Unable to generate mocks" + @echo "Please install moq: go install github.com/matryer/moq@latest" +endif + all: dist/chamber-$(VERSION)-darwin-amd64 dist/chamber-$(VERSION)-linux-amd64 dist/chamber-$(VERSION)-windows-amd64.exe clean: @@ -32,7 +43,7 @@ dist/: build: chamber -chamber: +chamber: $(SRC) CGO_ENABLED=0 go build -trimpath $(LDFLAGS) -o $@ dist/chamber-$(VERSION)-darwin-amd64: | dist/ diff --git a/go.mod b/go.mod index 5ee2f808..8198ff7b 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,14 @@ go 1.20 require ( github.com/alessio/shellescape v1.4.2 - github.com/aws/aws-sdk-go v1.51.21 + github.com/aws/aws-sdk-go-v2 v1.26.1 + github.com/aws/aws-sdk-go-v2/config v1.27.11 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 + github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 + github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.6 + github.com/aws/aws-sdk-go-v2/service/ssm v1.49.5 + github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 + github.com/aws/smithy-go v1.20.2 github.com/magiconair/properties v1.8.7 github.com/segmentio/analytics-go/v3 v3.3.0 github.com/spf13/cobra v1.8.0 @@ -14,6 +21,18 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 // indirect github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/google/uuid v1.3.1 // indirect diff --git a/go.sum b/go.sum index fd9c0e29..a005b535 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,45 @@ github.com/alessio/shellescape v1.4.2 h1:MHPfaU+ddJ0/bYWpgIeUnQUqKrlJ1S7BfEYPM4uEoM0= github.com/alessio/shellescape v1.4.2/go.mod h1:PZAiSCk0LJaZkiCSkPv8qIobYglO3FPpyFjDCtHLS30= -github.com/aws/aws-sdk-go v1.51.21 h1:UrT6JC9R9PkYYXDZBV0qDKTualMr+bfK2eboTknMgbs= -github.com/aws/aws-sdk-go v1.51.21/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= +github.com/aws/aws-sdk-go-v2 v1.26.1 h1:5554eUqIYVWpU0YmeeYZ0wU64H2VLBs8TlhRB2L+EkA= +github.com/aws/aws-sdk-go-v2 v1.26.1/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 h1:x6xsQXGSmW6frevwDA+vi/wqhp1ct18mVXYN08/93to= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2/go.mod h1:lPprDr1e6cJdyYeGXnRaJoP4Md+cDBvi2eOj00BlGmg= +github.com/aws/aws-sdk-go-v2/config v1.27.11 h1:f47rANd2LQEYHda2ddSCKYId18/8BhSRM4BULGmfgNA= +github.com/aws/aws-sdk-go-v2/config v1.27.11/go.mod h1:SMsV78RIOYdve1vf36z8LmnszlRWkwMQtomCAI0/mIE= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11 h1:YuIB1dJNf1Re822rriUOTxopaHHvIq0l/pX3fwO+Tzs= +github.com/aws/aws-sdk-go-v2/credentials v1.17.11/go.mod h1:AQtFPsDH9bI2O+71anW6EKL+NcD7LG3dpKGMV4SShgo= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1 h1:FVJ0r5XTHSmIHJV6KuDmdYhEpvlHpiSd38RQWhut5J4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.1/go.mod h1:zusuAeqezXzAB24LGuzuekqMAEgWkVYukBec3kr3jUg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 h1:aw39xVGeRWlWx9EzGVnhOR4yOjQDHPQ6o6NmBlscyQg= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5/go.mod h1:FSaRudD0dXiMPK2UjknVwwTYyZMRsHv3TtkabsZih5I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 h1:PG1F3OD1szkuQPzDw3CIQsRIrtTlUC3lP84taWzHlq0= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5/go.mod h1:jU1li6RFryMz+so64PpKtudI+QzbKoIEivqdf6LNpOc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0 h1:hT8rVHwugYE2lEfdFE0QWVo81lF7jMrYJVDWI+f+VxU= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.0/go.mod h1:8tu/lYfQfFe6IGnaOdrpVgEL2IrrDOf6/m9RQum4NkY= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5 h1:81KE7vaZzrl7yHBYHVEzYB8sypz11NMOZ40YlWvPxsU= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.5/go.mod h1:LIt2rg7Mcgn09Ygbdh/RdIm0rQ+3BNkbP1gyVMFtRK0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2 h1:Ji0DY1xUsUr3I8cHps0G+XM3WWU16lP6yG8qu1GAZAs= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.2/go.mod h1:5CsjAbs3NlGQyZNFACh+zztPDI7fU6eW9QsxjfnuBKg= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7 h1:ZMeFZ5yk+Ek+jNr1+uwCd2tG89t6oTS5yVWpa6yy2es= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.7/go.mod h1:mxV05U+4JiHqIpGqqYXOHLPKUC6bDXC44bsUhNjOEwY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7 h1:ogRAwT1/gxJBcSWDMZlgyFUM962F51A5CRhDLbxLdmo= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.7/go.mod h1:YCsIZhXfRPLFFCl5xxY+1T9RKzOKjCut+28JSX2DnAk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5 h1:f9RyWNtS8oH7cZlbn+/JNPpjUk5+5fLd5lM9M0i49Ys= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.5/go.mod h1:h5CoMZV2VF297/VLhRhO1WF+XYWOzXo+4HsObA4HjBQ= +github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1 h1:6cnno47Me9bRykw9AEv9zkXE+5or7jz8TsskTTccbgc= +github.com/aws/aws-sdk-go-v2/service/s3 v1.53.1/go.mod h1:qmdkIIAC+GCLASF7R2whgNrJADz0QZPX+Seiw/i4S3o= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.6 h1:TIOEjw0i2yyhmhRry3Oeu9YtiiHWISZ6j/irS1W3gX4= +github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.28.6/go.mod h1:3Ba++UwWd154xtP4FRX5pUK3Gt4up5sDHCve6kVfE+g= +github.com/aws/aws-sdk-go-v2/service/ssm v1.49.5 h1:KBwyHzP2QG8J//hoGuPyHWZ5tgL1BzaoMURUkecpI4g= +github.com/aws/aws-sdk-go-v2/service/ssm v1.49.5/go.mod h1:Ebk/HZmGhxWKDVxM4+pwbxGjm3RQOQLMjAEosI3ss9Q= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5 h1:vN8hEbpRnL7+Hopy9dzmRle1xmDc7o8tmY0klsr175w= +github.com/aws/aws-sdk-go-v2/service/sso v1.20.5/go.mod h1:qGzynb/msuZIE8I75DVRCUXw3o3ZyBmUvMwQ2t/BrGM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4 h1:Jux+gDDyi1Lruk+KHF91tK2KCuY61kzoCpvtvJJBtOE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.23.4/go.mod h1:mUYPBhaF2lGiukDEjJX2BLRRKTmoUSitGDUgM4tRxak= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6 h1:cwIxeBttqPN3qkaAjcEcsh8NYr8n2HZPkcKgPAi1phU= +github.com/aws/aws-sdk-go-v2/service/sts v1.28.6/go.mod h1:FZf1/nKNEkHdGGJP/cI2MoIMquumuRK6ol3QQJNDxmw= +github.com/aws/smithy-go v1.20.2 h1:tbp628ireGtzcHDDmLT/6ADHidqnwgF57XOXZe6tp4Q= +github.com/aws/smithy-go v1.20.2/go.mod h1:krry+ya/rV9RDcV/Q16kpu6ypI4K2czasz0NC3qS14E= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= @@ -40,10 +78,8 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= diff --git a/store/awsapi.go b/store/awsapi.go new file mode 100644 index 00000000..d1ee8608 --- /dev/null +++ b/store/awsapi.go @@ -0,0 +1,47 @@ +package store + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/sts" +) + +// The interfaces defined here collect together all of the SDK functions used +// throughout chamber. Code that works with AWS does so through these interfaces. +// The "real" AWS SDK client objects implement these interfaces, since they +// contain all of the methods (and more). Mock versions of these interfaces are +// generated using the moq utility for substitution in unit tests. For more, see +// https://aws.github.io/aws-sdk-go-v2/docs/unit-testing/ . + +//go:generate moq -out awsapi_mock.go . apiS3 apiSSM apiSTS apiSecretsManager + +type apiS3 interface { + DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) + GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) + ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) + PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) +} + +type apiSSM interface { + DeleteParameter(ctx context.Context, params *ssm.DeleteParameterInput, optFns ...func(*ssm.Options)) (*ssm.DeleteParameterOutput, error) + DescribeParameters(ctx context.Context, params *ssm.DescribeParametersInput, optFns ...func(*ssm.Options)) (*ssm.DescribeParametersOutput, error) + GetParameterHistory(ctx context.Context, params *ssm.GetParameterHistoryInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterHistoryOutput, error) + GetParameters(ctx context.Context, params *ssm.GetParametersInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersOutput, error) + GetParametersByPath(ctx context.Context, params *ssm.GetParametersByPathInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersByPathOutput, error) + PutParameter(ctx context.Context, params *ssm.PutParameterInput, optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) +} + +type apiSTS interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) +} + +type apiSecretsManager interface { + CreateSecret(ctx context.Context, params *secretsmanager.CreateSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.CreateSecretOutput, error) + DescribeSecret(ctx context.Context, params *secretsmanager.DescribeSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.DescribeSecretOutput, error) + GetSecretValue(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) + ListSecretVersionIds(ctx context.Context, params *secretsmanager.ListSecretVersionIdsInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.ListSecretVersionIdsOutput, error) + PutSecretValue(ctx context.Context, params *secretsmanager.PutSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.PutSecretValueOutput, error) +} diff --git a/store/awsapi_mock.go b/store/awsapi_mock.go new file mode 100644 index 00000000..f01924e8 --- /dev/null +++ b/store/awsapi_mock.go @@ -0,0 +1,997 @@ +// Code generated by moq; DO NOT EDIT. +// github.com/matryer/moq + +package store + +import ( + "context" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/sts" + "sync" +) + +// Ensure, that apiS3Mock does implement apiS3. +// If this is not the case, regenerate this file with moq. +var _ apiS3 = &apiS3Mock{} + +// apiS3Mock is a mock implementation of apiS3. +// +// func TestSomethingThatUsesapiS3(t *testing.T) { +// +// // make and configure a mocked apiS3 +// mockedapiS3 := &apiS3Mock{ +// DeleteObjectFunc: func(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) { +// panic("mock out the DeleteObject method") +// }, +// GetObjectFunc: func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { +// panic("mock out the GetObject method") +// }, +// ListObjectsV2Func: func(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { +// panic("mock out the ListObjectsV2 method") +// }, +// PutObjectFunc: func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { +// panic("mock out the PutObject method") +// }, +// } +// +// // use mockedapiS3 in code that requires apiS3 +// // and then make assertions. +// +// } +type apiS3Mock struct { + // DeleteObjectFunc mocks the DeleteObject method. + DeleteObjectFunc func(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) + + // GetObjectFunc mocks the GetObject method. + GetObjectFunc func(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) + + // ListObjectsV2Func mocks the ListObjectsV2 method. + ListObjectsV2Func func(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) + + // PutObjectFunc mocks the PutObject method. + PutObjectFunc func(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) + + // calls tracks calls to the methods. + calls struct { + // DeleteObject holds details about calls to the DeleteObject method. + DeleteObject []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *s3.DeleteObjectInput + // OptFns is the optFns argument value. + OptFns []func(*s3.Options) + } + // GetObject holds details about calls to the GetObject method. + GetObject []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *s3.GetObjectInput + // OptFns is the optFns argument value. + OptFns []func(*s3.Options) + } + // ListObjectsV2 holds details about calls to the ListObjectsV2 method. + ListObjectsV2 []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *s3.ListObjectsV2Input + // OptFns is the optFns argument value. + OptFns []func(*s3.Options) + } + // PutObject holds details about calls to the PutObject method. + PutObject []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *s3.PutObjectInput + // OptFns is the optFns argument value. + OptFns []func(*s3.Options) + } + } + lockDeleteObject sync.RWMutex + lockGetObject sync.RWMutex + lockListObjectsV2 sync.RWMutex + lockPutObject sync.RWMutex +} + +// DeleteObject calls DeleteObjectFunc. +func (mock *apiS3Mock) DeleteObject(ctx context.Context, params *s3.DeleteObjectInput, optFns ...func(*s3.Options)) (*s3.DeleteObjectOutput, error) { + if mock.DeleteObjectFunc == nil { + panic("apiS3Mock.DeleteObjectFunc: method is nil but apiS3.DeleteObject was just called") + } + callInfo := struct { + Ctx context.Context + Params *s3.DeleteObjectInput + OptFns []func(*s3.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockDeleteObject.Lock() + mock.calls.DeleteObject = append(mock.calls.DeleteObject, callInfo) + mock.lockDeleteObject.Unlock() + return mock.DeleteObjectFunc(ctx, params, optFns...) +} + +// DeleteObjectCalls gets all the calls that were made to DeleteObject. +// Check the length with: +// +// len(mockedapiS3.DeleteObjectCalls()) +func (mock *apiS3Mock) DeleteObjectCalls() []struct { + Ctx context.Context + Params *s3.DeleteObjectInput + OptFns []func(*s3.Options) +} { + var calls []struct { + Ctx context.Context + Params *s3.DeleteObjectInput + OptFns []func(*s3.Options) + } + mock.lockDeleteObject.RLock() + calls = mock.calls.DeleteObject + mock.lockDeleteObject.RUnlock() + return calls +} + +// GetObject calls GetObjectFunc. +func (mock *apiS3Mock) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + if mock.GetObjectFunc == nil { + panic("apiS3Mock.GetObjectFunc: method is nil but apiS3.GetObject was just called") + } + callInfo := struct { + Ctx context.Context + Params *s3.GetObjectInput + OptFns []func(*s3.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockGetObject.Lock() + mock.calls.GetObject = append(mock.calls.GetObject, callInfo) + mock.lockGetObject.Unlock() + return mock.GetObjectFunc(ctx, params, optFns...) +} + +// GetObjectCalls gets all the calls that were made to GetObject. +// Check the length with: +// +// len(mockedapiS3.GetObjectCalls()) +func (mock *apiS3Mock) GetObjectCalls() []struct { + Ctx context.Context + Params *s3.GetObjectInput + OptFns []func(*s3.Options) +} { + var calls []struct { + Ctx context.Context + Params *s3.GetObjectInput + OptFns []func(*s3.Options) + } + mock.lockGetObject.RLock() + calls = mock.calls.GetObject + mock.lockGetObject.RUnlock() + return calls +} + +// ListObjectsV2 calls ListObjectsV2Func. +func (mock *apiS3Mock) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + if mock.ListObjectsV2Func == nil { + panic("apiS3Mock.ListObjectsV2Func: method is nil but apiS3.ListObjectsV2 was just called") + } + callInfo := struct { + Ctx context.Context + Params *s3.ListObjectsV2Input + OptFns []func(*s3.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockListObjectsV2.Lock() + mock.calls.ListObjectsV2 = append(mock.calls.ListObjectsV2, callInfo) + mock.lockListObjectsV2.Unlock() + return mock.ListObjectsV2Func(ctx, params, optFns...) +} + +// ListObjectsV2Calls gets all the calls that were made to ListObjectsV2. +// Check the length with: +// +// len(mockedapiS3.ListObjectsV2Calls()) +func (mock *apiS3Mock) ListObjectsV2Calls() []struct { + Ctx context.Context + Params *s3.ListObjectsV2Input + OptFns []func(*s3.Options) +} { + var calls []struct { + Ctx context.Context + Params *s3.ListObjectsV2Input + OptFns []func(*s3.Options) + } + mock.lockListObjectsV2.RLock() + calls = mock.calls.ListObjectsV2 + mock.lockListObjectsV2.RUnlock() + return calls +} + +// PutObject calls PutObjectFunc. +func (mock *apiS3Mock) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + if mock.PutObjectFunc == nil { + panic("apiS3Mock.PutObjectFunc: method is nil but apiS3.PutObject was just called") + } + callInfo := struct { + Ctx context.Context + Params *s3.PutObjectInput + OptFns []func(*s3.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockPutObject.Lock() + mock.calls.PutObject = append(mock.calls.PutObject, callInfo) + mock.lockPutObject.Unlock() + return mock.PutObjectFunc(ctx, params, optFns...) +} + +// PutObjectCalls gets all the calls that were made to PutObject. +// Check the length with: +// +// len(mockedapiS3.PutObjectCalls()) +func (mock *apiS3Mock) PutObjectCalls() []struct { + Ctx context.Context + Params *s3.PutObjectInput + OptFns []func(*s3.Options) +} { + var calls []struct { + Ctx context.Context + Params *s3.PutObjectInput + OptFns []func(*s3.Options) + } + mock.lockPutObject.RLock() + calls = mock.calls.PutObject + mock.lockPutObject.RUnlock() + return calls +} + +// Ensure, that apiSSMMock does implement apiSSM. +// If this is not the case, regenerate this file with moq. +var _ apiSSM = &apiSSMMock{} + +// apiSSMMock is a mock implementation of apiSSM. +// +// func TestSomethingThatUsesapiSSM(t *testing.T) { +// +// // make and configure a mocked apiSSM +// mockedapiSSM := &apiSSMMock{ +// DeleteParameterFunc: func(ctx context.Context, params *ssm.DeleteParameterInput, optFns ...func(*ssm.Options)) (*ssm.DeleteParameterOutput, error) { +// panic("mock out the DeleteParameter method") +// }, +// DescribeParametersFunc: func(ctx context.Context, params *ssm.DescribeParametersInput, optFns ...func(*ssm.Options)) (*ssm.DescribeParametersOutput, error) { +// panic("mock out the DescribeParameters method") +// }, +// GetParameterHistoryFunc: func(ctx context.Context, params *ssm.GetParameterHistoryInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterHistoryOutput, error) { +// panic("mock out the GetParameterHistory method") +// }, +// GetParametersFunc: func(ctx context.Context, params *ssm.GetParametersInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersOutput, error) { +// panic("mock out the GetParameters method") +// }, +// GetParametersByPathFunc: func(ctx context.Context, params *ssm.GetParametersByPathInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersByPathOutput, error) { +// panic("mock out the GetParametersByPath method") +// }, +// PutParameterFunc: func(ctx context.Context, params *ssm.PutParameterInput, optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { +// panic("mock out the PutParameter method") +// }, +// } +// +// // use mockedapiSSM in code that requires apiSSM +// // and then make assertions. +// +// } +type apiSSMMock struct { + // DeleteParameterFunc mocks the DeleteParameter method. + DeleteParameterFunc func(ctx context.Context, params *ssm.DeleteParameterInput, optFns ...func(*ssm.Options)) (*ssm.DeleteParameterOutput, error) + + // DescribeParametersFunc mocks the DescribeParameters method. + DescribeParametersFunc func(ctx context.Context, params *ssm.DescribeParametersInput, optFns ...func(*ssm.Options)) (*ssm.DescribeParametersOutput, error) + + // GetParameterHistoryFunc mocks the GetParameterHistory method. + GetParameterHistoryFunc func(ctx context.Context, params *ssm.GetParameterHistoryInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterHistoryOutput, error) + + // GetParametersFunc mocks the GetParameters method. + GetParametersFunc func(ctx context.Context, params *ssm.GetParametersInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersOutput, error) + + // GetParametersByPathFunc mocks the GetParametersByPath method. + GetParametersByPathFunc func(ctx context.Context, params *ssm.GetParametersByPathInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersByPathOutput, error) + + // PutParameterFunc mocks the PutParameter method. + PutParameterFunc func(ctx context.Context, params *ssm.PutParameterInput, optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) + + // calls tracks calls to the methods. + calls struct { + // DeleteParameter holds details about calls to the DeleteParameter method. + DeleteParameter []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *ssm.DeleteParameterInput + // OptFns is the optFns argument value. + OptFns []func(*ssm.Options) + } + // DescribeParameters holds details about calls to the DescribeParameters method. + DescribeParameters []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *ssm.DescribeParametersInput + // OptFns is the optFns argument value. + OptFns []func(*ssm.Options) + } + // GetParameterHistory holds details about calls to the GetParameterHistory method. + GetParameterHistory []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *ssm.GetParameterHistoryInput + // OptFns is the optFns argument value. + OptFns []func(*ssm.Options) + } + // GetParameters holds details about calls to the GetParameters method. + GetParameters []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *ssm.GetParametersInput + // OptFns is the optFns argument value. + OptFns []func(*ssm.Options) + } + // GetParametersByPath holds details about calls to the GetParametersByPath method. + GetParametersByPath []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *ssm.GetParametersByPathInput + // OptFns is the optFns argument value. + OptFns []func(*ssm.Options) + } + // PutParameter holds details about calls to the PutParameter method. + PutParameter []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *ssm.PutParameterInput + // OptFns is the optFns argument value. + OptFns []func(*ssm.Options) + } + } + lockDeleteParameter sync.RWMutex + lockDescribeParameters sync.RWMutex + lockGetParameterHistory sync.RWMutex + lockGetParameters sync.RWMutex + lockGetParametersByPath sync.RWMutex + lockPutParameter sync.RWMutex +} + +// DeleteParameter calls DeleteParameterFunc. +func (mock *apiSSMMock) DeleteParameter(ctx context.Context, params *ssm.DeleteParameterInput, optFns ...func(*ssm.Options)) (*ssm.DeleteParameterOutput, error) { + if mock.DeleteParameterFunc == nil { + panic("apiSSMMock.DeleteParameterFunc: method is nil but apiSSM.DeleteParameter was just called") + } + callInfo := struct { + Ctx context.Context + Params *ssm.DeleteParameterInput + OptFns []func(*ssm.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockDeleteParameter.Lock() + mock.calls.DeleteParameter = append(mock.calls.DeleteParameter, callInfo) + mock.lockDeleteParameter.Unlock() + return mock.DeleteParameterFunc(ctx, params, optFns...) +} + +// DeleteParameterCalls gets all the calls that were made to DeleteParameter. +// Check the length with: +// +// len(mockedapiSSM.DeleteParameterCalls()) +func (mock *apiSSMMock) DeleteParameterCalls() []struct { + Ctx context.Context + Params *ssm.DeleteParameterInput + OptFns []func(*ssm.Options) +} { + var calls []struct { + Ctx context.Context + Params *ssm.DeleteParameterInput + OptFns []func(*ssm.Options) + } + mock.lockDeleteParameter.RLock() + calls = mock.calls.DeleteParameter + mock.lockDeleteParameter.RUnlock() + return calls +} + +// DescribeParameters calls DescribeParametersFunc. +func (mock *apiSSMMock) DescribeParameters(ctx context.Context, params *ssm.DescribeParametersInput, optFns ...func(*ssm.Options)) (*ssm.DescribeParametersOutput, error) { + if mock.DescribeParametersFunc == nil { + panic("apiSSMMock.DescribeParametersFunc: method is nil but apiSSM.DescribeParameters was just called") + } + callInfo := struct { + Ctx context.Context + Params *ssm.DescribeParametersInput + OptFns []func(*ssm.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockDescribeParameters.Lock() + mock.calls.DescribeParameters = append(mock.calls.DescribeParameters, callInfo) + mock.lockDescribeParameters.Unlock() + return mock.DescribeParametersFunc(ctx, params, optFns...) +} + +// DescribeParametersCalls gets all the calls that were made to DescribeParameters. +// Check the length with: +// +// len(mockedapiSSM.DescribeParametersCalls()) +func (mock *apiSSMMock) DescribeParametersCalls() []struct { + Ctx context.Context + Params *ssm.DescribeParametersInput + OptFns []func(*ssm.Options) +} { + var calls []struct { + Ctx context.Context + Params *ssm.DescribeParametersInput + OptFns []func(*ssm.Options) + } + mock.lockDescribeParameters.RLock() + calls = mock.calls.DescribeParameters + mock.lockDescribeParameters.RUnlock() + return calls +} + +// GetParameterHistory calls GetParameterHistoryFunc. +func (mock *apiSSMMock) GetParameterHistory(ctx context.Context, params *ssm.GetParameterHistoryInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterHistoryOutput, error) { + if mock.GetParameterHistoryFunc == nil { + panic("apiSSMMock.GetParameterHistoryFunc: method is nil but apiSSM.GetParameterHistory was just called") + } + callInfo := struct { + Ctx context.Context + Params *ssm.GetParameterHistoryInput + OptFns []func(*ssm.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockGetParameterHistory.Lock() + mock.calls.GetParameterHistory = append(mock.calls.GetParameterHistory, callInfo) + mock.lockGetParameterHistory.Unlock() + return mock.GetParameterHistoryFunc(ctx, params, optFns...) +} + +// GetParameterHistoryCalls gets all the calls that were made to GetParameterHistory. +// Check the length with: +// +// len(mockedapiSSM.GetParameterHistoryCalls()) +func (mock *apiSSMMock) GetParameterHistoryCalls() []struct { + Ctx context.Context + Params *ssm.GetParameterHistoryInput + OptFns []func(*ssm.Options) +} { + var calls []struct { + Ctx context.Context + Params *ssm.GetParameterHistoryInput + OptFns []func(*ssm.Options) + } + mock.lockGetParameterHistory.RLock() + calls = mock.calls.GetParameterHistory + mock.lockGetParameterHistory.RUnlock() + return calls +} + +// GetParameters calls GetParametersFunc. +func (mock *apiSSMMock) GetParameters(ctx context.Context, params *ssm.GetParametersInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersOutput, error) { + if mock.GetParametersFunc == nil { + panic("apiSSMMock.GetParametersFunc: method is nil but apiSSM.GetParameters was just called") + } + callInfo := struct { + Ctx context.Context + Params *ssm.GetParametersInput + OptFns []func(*ssm.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockGetParameters.Lock() + mock.calls.GetParameters = append(mock.calls.GetParameters, callInfo) + mock.lockGetParameters.Unlock() + return mock.GetParametersFunc(ctx, params, optFns...) +} + +// GetParametersCalls gets all the calls that were made to GetParameters. +// Check the length with: +// +// len(mockedapiSSM.GetParametersCalls()) +func (mock *apiSSMMock) GetParametersCalls() []struct { + Ctx context.Context + Params *ssm.GetParametersInput + OptFns []func(*ssm.Options) +} { + var calls []struct { + Ctx context.Context + Params *ssm.GetParametersInput + OptFns []func(*ssm.Options) + } + mock.lockGetParameters.RLock() + calls = mock.calls.GetParameters + mock.lockGetParameters.RUnlock() + return calls +} + +// GetParametersByPath calls GetParametersByPathFunc. +func (mock *apiSSMMock) GetParametersByPath(ctx context.Context, params *ssm.GetParametersByPathInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersByPathOutput, error) { + if mock.GetParametersByPathFunc == nil { + panic("apiSSMMock.GetParametersByPathFunc: method is nil but apiSSM.GetParametersByPath was just called") + } + callInfo := struct { + Ctx context.Context + Params *ssm.GetParametersByPathInput + OptFns []func(*ssm.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockGetParametersByPath.Lock() + mock.calls.GetParametersByPath = append(mock.calls.GetParametersByPath, callInfo) + mock.lockGetParametersByPath.Unlock() + return mock.GetParametersByPathFunc(ctx, params, optFns...) +} + +// GetParametersByPathCalls gets all the calls that were made to GetParametersByPath. +// Check the length with: +// +// len(mockedapiSSM.GetParametersByPathCalls()) +func (mock *apiSSMMock) GetParametersByPathCalls() []struct { + Ctx context.Context + Params *ssm.GetParametersByPathInput + OptFns []func(*ssm.Options) +} { + var calls []struct { + Ctx context.Context + Params *ssm.GetParametersByPathInput + OptFns []func(*ssm.Options) + } + mock.lockGetParametersByPath.RLock() + calls = mock.calls.GetParametersByPath + mock.lockGetParametersByPath.RUnlock() + return calls +} + +// PutParameter calls PutParameterFunc. +func (mock *apiSSMMock) PutParameter(ctx context.Context, params *ssm.PutParameterInput, optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { + if mock.PutParameterFunc == nil { + panic("apiSSMMock.PutParameterFunc: method is nil but apiSSM.PutParameter was just called") + } + callInfo := struct { + Ctx context.Context + Params *ssm.PutParameterInput + OptFns []func(*ssm.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockPutParameter.Lock() + mock.calls.PutParameter = append(mock.calls.PutParameter, callInfo) + mock.lockPutParameter.Unlock() + return mock.PutParameterFunc(ctx, params, optFns...) +} + +// PutParameterCalls gets all the calls that were made to PutParameter. +// Check the length with: +// +// len(mockedapiSSM.PutParameterCalls()) +func (mock *apiSSMMock) PutParameterCalls() []struct { + Ctx context.Context + Params *ssm.PutParameterInput + OptFns []func(*ssm.Options) +} { + var calls []struct { + Ctx context.Context + Params *ssm.PutParameterInput + OptFns []func(*ssm.Options) + } + mock.lockPutParameter.RLock() + calls = mock.calls.PutParameter + mock.lockPutParameter.RUnlock() + return calls +} + +// Ensure, that apiSTSMock does implement apiSTS. +// If this is not the case, regenerate this file with moq. +var _ apiSTS = &apiSTSMock{} + +// apiSTSMock is a mock implementation of apiSTS. +// +// func TestSomethingThatUsesapiSTS(t *testing.T) { +// +// // make and configure a mocked apiSTS +// mockedapiSTS := &apiSTSMock{ +// GetCallerIdentityFunc: func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { +// panic("mock out the GetCallerIdentity method") +// }, +// } +// +// // use mockedapiSTS in code that requires apiSTS +// // and then make assertions. +// +// } +type apiSTSMock struct { + // GetCallerIdentityFunc mocks the GetCallerIdentity method. + GetCallerIdentityFunc func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) + + // calls tracks calls to the methods. + calls struct { + // GetCallerIdentity holds details about calls to the GetCallerIdentity method. + GetCallerIdentity []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *sts.GetCallerIdentityInput + // OptFns is the optFns argument value. + OptFns []func(*sts.Options) + } + } + lockGetCallerIdentity sync.RWMutex +} + +// GetCallerIdentity calls GetCallerIdentityFunc. +func (mock *apiSTSMock) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + if mock.GetCallerIdentityFunc == nil { + panic("apiSTSMock.GetCallerIdentityFunc: method is nil but apiSTS.GetCallerIdentity was just called") + } + callInfo := struct { + Ctx context.Context + Params *sts.GetCallerIdentityInput + OptFns []func(*sts.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockGetCallerIdentity.Lock() + mock.calls.GetCallerIdentity = append(mock.calls.GetCallerIdentity, callInfo) + mock.lockGetCallerIdentity.Unlock() + return mock.GetCallerIdentityFunc(ctx, params, optFns...) +} + +// GetCallerIdentityCalls gets all the calls that were made to GetCallerIdentity. +// Check the length with: +// +// len(mockedapiSTS.GetCallerIdentityCalls()) +func (mock *apiSTSMock) GetCallerIdentityCalls() []struct { + Ctx context.Context + Params *sts.GetCallerIdentityInput + OptFns []func(*sts.Options) +} { + var calls []struct { + Ctx context.Context + Params *sts.GetCallerIdentityInput + OptFns []func(*sts.Options) + } + mock.lockGetCallerIdentity.RLock() + calls = mock.calls.GetCallerIdentity + mock.lockGetCallerIdentity.RUnlock() + return calls +} + +// Ensure, that apiSecretsManagerMock does implement apiSecretsManager. +// If this is not the case, regenerate this file with moq. +var _ apiSecretsManager = &apiSecretsManagerMock{} + +// apiSecretsManagerMock is a mock implementation of apiSecretsManager. +// +// func TestSomethingThatUsesapiSecretsManager(t *testing.T) { +// +// // make and configure a mocked apiSecretsManager +// mockedapiSecretsManager := &apiSecretsManagerMock{ +// CreateSecretFunc: func(ctx context.Context, params *secretsmanager.CreateSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.CreateSecretOutput, error) { +// panic("mock out the CreateSecret method") +// }, +// DescribeSecretFunc: func(ctx context.Context, params *secretsmanager.DescribeSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.DescribeSecretOutput, error) { +// panic("mock out the DescribeSecret method") +// }, +// GetSecretValueFunc: func(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { +// panic("mock out the GetSecretValue method") +// }, +// ListSecretVersionIdsFunc: func(ctx context.Context, params *secretsmanager.ListSecretVersionIdsInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.ListSecretVersionIdsOutput, error) { +// panic("mock out the ListSecretVersionIds method") +// }, +// PutSecretValueFunc: func(ctx context.Context, params *secretsmanager.PutSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.PutSecretValueOutput, error) { +// panic("mock out the PutSecretValue method") +// }, +// } +// +// // use mockedapiSecretsManager in code that requires apiSecretsManager +// // and then make assertions. +// +// } +type apiSecretsManagerMock struct { + // CreateSecretFunc mocks the CreateSecret method. + CreateSecretFunc func(ctx context.Context, params *secretsmanager.CreateSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.CreateSecretOutput, error) + + // DescribeSecretFunc mocks the DescribeSecret method. + DescribeSecretFunc func(ctx context.Context, params *secretsmanager.DescribeSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.DescribeSecretOutput, error) + + // GetSecretValueFunc mocks the GetSecretValue method. + GetSecretValueFunc func(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) + + // ListSecretVersionIdsFunc mocks the ListSecretVersionIds method. + ListSecretVersionIdsFunc func(ctx context.Context, params *secretsmanager.ListSecretVersionIdsInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.ListSecretVersionIdsOutput, error) + + // PutSecretValueFunc mocks the PutSecretValue method. + PutSecretValueFunc func(ctx context.Context, params *secretsmanager.PutSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.PutSecretValueOutput, error) + + // calls tracks calls to the methods. + calls struct { + // CreateSecret holds details about calls to the CreateSecret method. + CreateSecret []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *secretsmanager.CreateSecretInput + // OptFns is the optFns argument value. + OptFns []func(*secretsmanager.Options) + } + // DescribeSecret holds details about calls to the DescribeSecret method. + DescribeSecret []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *secretsmanager.DescribeSecretInput + // OptFns is the optFns argument value. + OptFns []func(*secretsmanager.Options) + } + // GetSecretValue holds details about calls to the GetSecretValue method. + GetSecretValue []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *secretsmanager.GetSecretValueInput + // OptFns is the optFns argument value. + OptFns []func(*secretsmanager.Options) + } + // ListSecretVersionIds holds details about calls to the ListSecretVersionIds method. + ListSecretVersionIds []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *secretsmanager.ListSecretVersionIdsInput + // OptFns is the optFns argument value. + OptFns []func(*secretsmanager.Options) + } + // PutSecretValue holds details about calls to the PutSecretValue method. + PutSecretValue []struct { + // Ctx is the ctx argument value. + Ctx context.Context + // Params is the params argument value. + Params *secretsmanager.PutSecretValueInput + // OptFns is the optFns argument value. + OptFns []func(*secretsmanager.Options) + } + } + lockCreateSecret sync.RWMutex + lockDescribeSecret sync.RWMutex + lockGetSecretValue sync.RWMutex + lockListSecretVersionIds sync.RWMutex + lockPutSecretValue sync.RWMutex +} + +// CreateSecret calls CreateSecretFunc. +func (mock *apiSecretsManagerMock) CreateSecret(ctx context.Context, params *secretsmanager.CreateSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.CreateSecretOutput, error) { + if mock.CreateSecretFunc == nil { + panic("apiSecretsManagerMock.CreateSecretFunc: method is nil but apiSecretsManager.CreateSecret was just called") + } + callInfo := struct { + Ctx context.Context + Params *secretsmanager.CreateSecretInput + OptFns []func(*secretsmanager.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockCreateSecret.Lock() + mock.calls.CreateSecret = append(mock.calls.CreateSecret, callInfo) + mock.lockCreateSecret.Unlock() + return mock.CreateSecretFunc(ctx, params, optFns...) +} + +// CreateSecretCalls gets all the calls that were made to CreateSecret. +// Check the length with: +// +// len(mockedapiSecretsManager.CreateSecretCalls()) +func (mock *apiSecretsManagerMock) CreateSecretCalls() []struct { + Ctx context.Context + Params *secretsmanager.CreateSecretInput + OptFns []func(*secretsmanager.Options) +} { + var calls []struct { + Ctx context.Context + Params *secretsmanager.CreateSecretInput + OptFns []func(*secretsmanager.Options) + } + mock.lockCreateSecret.RLock() + calls = mock.calls.CreateSecret + mock.lockCreateSecret.RUnlock() + return calls +} + +// DescribeSecret calls DescribeSecretFunc. +func (mock *apiSecretsManagerMock) DescribeSecret(ctx context.Context, params *secretsmanager.DescribeSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.DescribeSecretOutput, error) { + if mock.DescribeSecretFunc == nil { + panic("apiSecretsManagerMock.DescribeSecretFunc: method is nil but apiSecretsManager.DescribeSecret was just called") + } + callInfo := struct { + Ctx context.Context + Params *secretsmanager.DescribeSecretInput + OptFns []func(*secretsmanager.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockDescribeSecret.Lock() + mock.calls.DescribeSecret = append(mock.calls.DescribeSecret, callInfo) + mock.lockDescribeSecret.Unlock() + return mock.DescribeSecretFunc(ctx, params, optFns...) +} + +// DescribeSecretCalls gets all the calls that were made to DescribeSecret. +// Check the length with: +// +// len(mockedapiSecretsManager.DescribeSecretCalls()) +func (mock *apiSecretsManagerMock) DescribeSecretCalls() []struct { + Ctx context.Context + Params *secretsmanager.DescribeSecretInput + OptFns []func(*secretsmanager.Options) +} { + var calls []struct { + Ctx context.Context + Params *secretsmanager.DescribeSecretInput + OptFns []func(*secretsmanager.Options) + } + mock.lockDescribeSecret.RLock() + calls = mock.calls.DescribeSecret + mock.lockDescribeSecret.RUnlock() + return calls +} + +// GetSecretValue calls GetSecretValueFunc. +func (mock *apiSecretsManagerMock) GetSecretValue(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { + if mock.GetSecretValueFunc == nil { + panic("apiSecretsManagerMock.GetSecretValueFunc: method is nil but apiSecretsManager.GetSecretValue was just called") + } + callInfo := struct { + Ctx context.Context + Params *secretsmanager.GetSecretValueInput + OptFns []func(*secretsmanager.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockGetSecretValue.Lock() + mock.calls.GetSecretValue = append(mock.calls.GetSecretValue, callInfo) + mock.lockGetSecretValue.Unlock() + return mock.GetSecretValueFunc(ctx, params, optFns...) +} + +// GetSecretValueCalls gets all the calls that were made to GetSecretValue. +// Check the length with: +// +// len(mockedapiSecretsManager.GetSecretValueCalls()) +func (mock *apiSecretsManagerMock) GetSecretValueCalls() []struct { + Ctx context.Context + Params *secretsmanager.GetSecretValueInput + OptFns []func(*secretsmanager.Options) +} { + var calls []struct { + Ctx context.Context + Params *secretsmanager.GetSecretValueInput + OptFns []func(*secretsmanager.Options) + } + mock.lockGetSecretValue.RLock() + calls = mock.calls.GetSecretValue + mock.lockGetSecretValue.RUnlock() + return calls +} + +// ListSecretVersionIds calls ListSecretVersionIdsFunc. +func (mock *apiSecretsManagerMock) ListSecretVersionIds(ctx context.Context, params *secretsmanager.ListSecretVersionIdsInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.ListSecretVersionIdsOutput, error) { + if mock.ListSecretVersionIdsFunc == nil { + panic("apiSecretsManagerMock.ListSecretVersionIdsFunc: method is nil but apiSecretsManager.ListSecretVersionIds was just called") + } + callInfo := struct { + Ctx context.Context + Params *secretsmanager.ListSecretVersionIdsInput + OptFns []func(*secretsmanager.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockListSecretVersionIds.Lock() + mock.calls.ListSecretVersionIds = append(mock.calls.ListSecretVersionIds, callInfo) + mock.lockListSecretVersionIds.Unlock() + return mock.ListSecretVersionIdsFunc(ctx, params, optFns...) +} + +// ListSecretVersionIdsCalls gets all the calls that were made to ListSecretVersionIds. +// Check the length with: +// +// len(mockedapiSecretsManager.ListSecretVersionIdsCalls()) +func (mock *apiSecretsManagerMock) ListSecretVersionIdsCalls() []struct { + Ctx context.Context + Params *secretsmanager.ListSecretVersionIdsInput + OptFns []func(*secretsmanager.Options) +} { + var calls []struct { + Ctx context.Context + Params *secretsmanager.ListSecretVersionIdsInput + OptFns []func(*secretsmanager.Options) + } + mock.lockListSecretVersionIds.RLock() + calls = mock.calls.ListSecretVersionIds + mock.lockListSecretVersionIds.RUnlock() + return calls +} + +// PutSecretValue calls PutSecretValueFunc. +func (mock *apiSecretsManagerMock) PutSecretValue(ctx context.Context, params *secretsmanager.PutSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.PutSecretValueOutput, error) { + if mock.PutSecretValueFunc == nil { + panic("apiSecretsManagerMock.PutSecretValueFunc: method is nil but apiSecretsManager.PutSecretValue was just called") + } + callInfo := struct { + Ctx context.Context + Params *secretsmanager.PutSecretValueInput + OptFns []func(*secretsmanager.Options) + }{ + Ctx: ctx, + Params: params, + OptFns: optFns, + } + mock.lockPutSecretValue.Lock() + mock.calls.PutSecretValue = append(mock.calls.PutSecretValue, callInfo) + mock.lockPutSecretValue.Unlock() + return mock.PutSecretValueFunc(ctx, params, optFns...) +} + +// PutSecretValueCalls gets all the calls that were made to PutSecretValue. +// Check the length with: +// +// len(mockedapiSecretsManager.PutSecretValueCalls()) +func (mock *apiSecretsManagerMock) PutSecretValueCalls() []struct { + Ctx context.Context + Params *secretsmanager.PutSecretValueInput + OptFns []func(*secretsmanager.Options) +} { + var calls []struct { + Ctx context.Context + Params *secretsmanager.PutSecretValueInput + OptFns []func(*secretsmanager.Options) + } + mock.lockPutSecretValue.RLock() + calls = mock.calls.PutSecretValue + mock.lockPutSecretValue.RUnlock() + return calls +} diff --git a/store/s3store.go b/store/s3store.go index 8b093975..e56b1a86 100644 --- a/store/s3store.go +++ b/store/s3store.go @@ -2,18 +2,19 @@ package store import ( "bytes" + "context" "encoding/json" + "errors" "fmt" "io/ioutil" "os" "sort" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3iface" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/aws-sdk-go-v2/service/sts" ) const ( @@ -52,8 +53,8 @@ type latest struct { var _ Store = &S3Store{} type S3Store struct { - svc s3iface.S3API - stsSvc *sts.STS + svc apiS3 + stsSvc apiSTS bucket string } @@ -68,20 +69,14 @@ func NewS3Store(numRetries int) (*S3Store, error) { } func NewS3StoreWithBucket(numRetries int, bucket string) (*S3Store, error) { - session, region, err := getSession(numRetries) + config, _, err := getConfig(numRetries) if err != nil { return nil, err } - svc := s3.New(session, &aws.Config{ - MaxRetries: aws.Int(numRetries), - Region: region, - }) + svc := s3.NewFromConfig(config) - stsSvc := sts.New(session, &aws.Config{ - MaxRetries: aws.Int(numRetries), - Region: region, - }) + stsSvc := sts.NewFromConfig(config) return &S3Store{ svc: svc, @@ -134,12 +129,12 @@ func (s *S3Store) Write(id SecretId, value string) error { putObjectInput := &s3.PutObjectInput{ Bucket: aws.String(s.bucket), - ServerSideEncryption: aws.String(s3.ServerSideEncryptionAes256), + ServerSideEncryption: types.ServerSideEncryptionAes256, Key: aws.String(objPath), Body: bytes.NewReader(contents), } - _, err = s.svc.PutObject(putObjectInput) + _, err = s.svc.PutObject(context.TODO(), putObjectInput) if err != nil { // TODO: catch specific awserr return err @@ -291,7 +286,7 @@ func (s *S3Store) Delete(id SecretId) error { // so that secret value changes can be correctly attributed to the right // aws user/role func (s *S3Store) getCurrentUser() (string, error) { - resp, err := s.stsSvc.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + resp, err := s.stsSvc.GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}) if err != nil { return "", err } @@ -310,7 +305,7 @@ func (s *S3Store) deleteObject(path string) error { Key: aws.String(path), } - _, err := s.svc.DeleteObject(deleteObjectInput) + _, err := s.svc.DeleteObject(context.TODO(), deleteObjectInput) return err } @@ -320,18 +315,16 @@ func (s *S3Store) readObject(path string) (secretObject, bool, error) { Key: aws.String(path), } - resp, err := s.svc.GetObject(getObjectInput) + resp, err := s.svc.GetObject(context.TODO(), getObjectInput) if err != nil { - // handle aws errors - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case s3.ErrCodeNoSuchBucket: - return secretObject{}, false, err - case s3.ErrCodeNoSuchKey: - return secretObject{}, false, nil - default: - return secretObject{}, false, err - } + // handle specific AWS errors + var nsb *types.NoSuchBucket + if errors.As(err, &nsb) { + return secretObject{}, false, err + } + var nsk *types.NoSuchKey + if errors.As(err, &nsk) { + return secretObject{}, false, nil } // generic errors return secretObject{}, false, err @@ -359,12 +352,12 @@ func (s *S3Store) readObjectById(id SecretId) (secretObject, bool, error) { func (s *S3Store) puts3raw(path string, contents []byte) error { putObjectInput := &s3.PutObjectInput{ Bucket: aws.String(s.bucket), - ServerSideEncryption: aws.String(s3.ServerSideEncryptionAes256), + ServerSideEncryption: types.ServerSideEncryptionAes256, Key: aws.String(path), Body: bytes.NewReader(contents), } - _, err := s.svc.PutObject(putObjectInput) + _, err := s.svc.PutObject(context.TODO(), putObjectInput) return err } @@ -376,13 +369,12 @@ func (s *S3Store) readLatest(service string) (latest, error) { Key: aws.String(path), } - resp, err := s.svc.GetObject(getObjectInput) + resp, err := s.svc.GetObject(context.TODO(), getObjectInput) if err != nil { - if aerr, ok := err.(awserr.Error); ok { - if aerr.Code() == s3.ErrCodeNoSuchKey { - // Index doesn't exist yet, return an empty index - return latest{Latest: map[string]string{}}, nil - } + var nsk *types.NoSuchKey + if errors.As(err, &nsk) { + // Index doesn't exist yet, return an empty index + return latest{Latest: map[string]string{}}, nil } return latest{}, err } diff --git a/store/s3storeKMS.go b/store/s3storeKMS.go index 10891d65..416e8d76 100644 --- a/store/s3storeKMS.go +++ b/store/s3storeKMS.go @@ -2,6 +2,7 @@ package store import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -9,11 +10,11 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/s3" - "github.com/aws/aws-sdk-go/service/s3/s3iface" - "github.com/aws/aws-sdk-go/service/sts" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/smithy-go" ) // latest is used to keep a single object in s3 with all of the @@ -34,27 +35,21 @@ var _ Store = &S3KMSStore{} type S3KMSStore struct { S3Store - svc s3iface.S3API - stsSvc *sts.STS + svc apiS3 + stsSvc apiSTS bucket string kmsKeyAlias string } func NewS3KMSStore(numRetries int, bucket string, kmsKeyAlias string) (*S3KMSStore, error) { - session, region, err := getSession(numRetries) + config, _, err := getConfig(numRetries) if err != nil { return nil, err } - svc := s3.New(session, &aws.Config{ - MaxRetries: aws.Int(numRetries), - Region: region, - }) + svc := s3.NewFromConfig(config) - stsSvc := sts.New(session, &aws.Config{ - MaxRetries: aws.Int(numRetries), - Region: region, - }) + stsSvc := sts.NewFromConfig(config) if kmsKeyAlias == "" { kmsKeyAlias = DefaultKeyID @@ -123,13 +118,13 @@ func (s *S3KMSStore) Write(id SecretId, value string) error { putObjectInput := &s3.PutObjectInput{ Bucket: aws.String(s.bucket), - ServerSideEncryption: aws.String(s3.ServerSideEncryptionAwsKms), + ServerSideEncryption: types.ServerSideEncryptionAwsKms, SSEKMSKeyId: aws.String(s.kmsKeyAlias), Key: aws.String(objPath), Body: bytes.NewReader(contents), } - _, err = s.svc.PutObject(putObjectInput) + _, err = s.svc.PutObject(context.TODO(), putObjectInput) if err != nil { // TODO: catch specific awserr return err @@ -237,18 +232,16 @@ func (s *S3KMSStore) readObject(path string) (secretObject, bool, error) { Key: aws.String(path), } - resp, err := s.svc.GetObject(getObjectInput) + resp, err := s.svc.GetObject(context.TODO(), getObjectInput) if err != nil { - // handle aws errors - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case s3.ErrCodeNoSuchBucket: - return secretObject{}, false, err - case s3.ErrCodeNoSuchKey: - return secretObject{}, false, nil - default: - return secretObject{}, false, err - } + // handle specific AWS errors + var nsb *types.NoSuchBucket + if errors.As(err, &nsb) { + return secretObject{}, false, err + } + var nsk *types.NoSuchKey + if errors.As(err, &nsk) { + return secretObject{}, false, nil } // generic errors return secretObject{}, false, err @@ -271,13 +264,13 @@ func (s *S3KMSStore) readObject(path string) (secretObject, bool, error) { func (s *S3KMSStore) puts3raw(path string, contents []byte) error { putObjectInput := &s3.PutObjectInput{ Bucket: aws.String(s.bucket), - ServerSideEncryption: aws.String(s3.ServerSideEncryptionAwsKms), + ServerSideEncryption: types.ServerSideEncryptionAwsKms, SSEKMSKeyId: aws.String(s.kmsKeyAlias), Key: aws.String(path), Body: bytes.NewReader(contents), } - _, err := s.svc.PutObject(putObjectInput) + _, err := s.svc.PutObject(context.TODO(), putObjectInput) return err } @@ -287,20 +280,21 @@ func (s *S3KMSStore) readLatestFile(path string) (LatestIndexFile, error) { Key: aws.String(path), } - resp, err := s.svc.GetObject(getObjectInput) + resp, err := s.svc.GetObject(context.TODO(), getObjectInput) if err != nil { - if aerr, ok := err.(awserr.Error); ok { - if aerr.Code() == "AccessDenied" { + var nsk *types.NoSuchKey + if errors.As(err, &nsk) { + // Index doesn't exist yet, return an empty index + return LatestIndexFile{Latest: map[string]LatestValue{}}, nil + } + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + if apiErr.ErrorCode() == "AccessDenied" { // If we're not able to read the latest index for a KMS Key then proceed like it doesn't exist. // We do this because in a chamber secret folder there might be other secrets written with a KMS Key that you don't have access to. return LatestIndexFile{Latest: map[string]LatestValue{}}, nil } - - if aerr.Code() == s3.ErrCodeNoSuchKey { - // Index doesn't exist yet, return an empty index - return LatestIndexFile{Latest: map[string]LatestValue{}}, nil - } } return LatestIndexFile{}, err } @@ -323,21 +317,26 @@ func (s *S3KMSStore) readLatest(service string) (LatestIndexFile, error) { latestResult := LatestIndexFile{Latest: map[string]LatestValue{}} // List all the files that are prefixed with kms and use them as latest.json files for that KMS Key. - params := &s3.ListObjectsInput{ + params := &s3.ListObjectsV2Input{ Bucket: aws.String(s.bucket), Prefix: aws.String(fmt.Sprintf("%s/__kms", service)), } var paginationError error + paginator := s3.NewListObjectsV2Paginator(s.svc, params) + for paginator.HasMorePages() { + page, err := paginator.NextPage(context.TODO()) + if err != nil { + return latestResult, err + } - err := s.svc.ListObjectsPages(params, func(page *s3.ListObjectsOutput, lastPage bool) bool { for index := range page.Contents { key_name := *page.Contents[index].Key result, err := s.readLatestFile(key_name) if err != nil { paginationError = errors.New(fmt.Sprintf("Error reading latest index for KMS Key (%s): %s", key_name, err)) - return false + break } // Check if the chamber key already exists in the index.Latest map. @@ -354,18 +353,12 @@ func (s *S3KMSStore) readLatest(service string) (LatestIndexFile, error) { } } } - - return !lastPage - }) + } if paginationError != nil { return latestResult, paginationError } - if err != nil { - return latestResult, err - } - return latestResult, nil } diff --git a/store/secretsmanagerstore.go b/store/secretsmanagerstore.go index dc53073a..51488196 100644 --- a/store/secretsmanagerstore.go +++ b/store/secretsmanagerstore.go @@ -2,19 +2,19 @@ package store import ( + "context" "encoding/json" + "errors" "fmt" "reflect" "sort" "strconv" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" + "github.com/aws/aws-sdk-go-v2/service/sts" ) // We store all Chamber metadata in a stringified JSON format, @@ -74,30 +74,26 @@ var _ Store = &SecretsManagerStore{} // SecretsManagerStore implements the Store interface for storing secrets in SSM Parameter // Store type SecretsManagerStore struct { - svc secretsmanageriface.SecretsManagerAPI - stsSvc stsiface.STSAPI + svc apiSecretsManager + stsSvc apiSTS + config aws.Config } // NewSecretsManagerStore creates a new SecretsManagerStore func NewSecretsManagerStore(numRetries int) (*SecretsManagerStore, error) { - session, region, err := getSession(numRetries) + cfg, _, err := getConfig(numRetries) if err != nil { return nil, err } - svc := secretsmanager.New(session, &aws.Config{ - MaxRetries: aws.Int(numRetries), - Region: region, - }) + svc := secretsmanager.NewFromConfig(cfg) - stsSvc := sts.New(session, &aws.Config{ - MaxRetries: aws.Int(numRetries), - Region: region, - }) + stsSvc := sts.NewFromConfig(cfg) return &SecretsManagerStore{ svc: svc, stsSvc: stsSvc, + config: cfg, }, nil } @@ -121,12 +117,9 @@ func (s *SecretsManagerStore) Write(id SecretId, value string) error { return err } if err != ErrSecretNotFound { - if awsErr, ok := err.(awserr.Error); ok { - if awsErr.Code() == secretsmanager.ErrCodeResourceNotFoundException { - mustCreate = true - } else { - return err - } + var rnfe *types.ResourceNotFoundException + if errors.As(err, &rnfe) { + mustCreate = true } else { return err } @@ -192,7 +185,7 @@ func (s *SecretsManagerStore) Write(id SecretId, value string) error { Name: aws.String(id.Service), SecretString: aws.String(string(contents)), } - _, err = s.svc.CreateSecret(createSecretValueInput) + _, err = s.svc.CreateSecret(context.TODO(), createSecretValueInput) if err != nil { return err } @@ -202,20 +195,20 @@ func (s *SecretsManagerStore) Write(id SecretId, value string) error { describeSecretInput := &secretsmanager.DescribeSecretInput{ SecretId: aws.String(id.Service), } - details, err := s.svc.DescribeSecret(describeSecretInput) + details, err := s.svc.DescribeSecret(context.TODO(), describeSecretInput) if err != nil { return err } - if aws.BoolValue(details.RotationEnabled) { + if details.RotationEnabled != nil && *details.RotationEnabled { return fmt.Errorf("Cannot write to a secret with rotation enabled") } putSecretValueInput := &secretsmanager.PutSecretValueInput{ SecretId: aws.String(id.Service), SecretString: aws.String(string(contents)), - VersionStages: []*string{aws.String("AWSCURRENT"), aws.String("CHAMBER" + fmt.Sprint(version))}, + VersionStages: []string{"AWSCURRENT", "CHAMBER" + fmt.Sprint(version)}, } - _, err = s.svc.PutSecretValue(putSecretValueInput) + _, err = s.svc.PutSecretValue(context.TODO(), putSecretValueInput) if err != nil { return err } @@ -270,7 +263,7 @@ func (s *SecretsManagerStore) readVersion(id SecretId, version int) (Secret, err } var result Secret - resp, err := s.svc.ListSecretVersionIds(listSecretVersionIdsInput) + resp, err := s.svc.ListSecretVersionIds(context.TODO(), listSecretVersionIdsInput) if err != nil { return Secret{}, err } @@ -284,7 +277,7 @@ func (s *SecretsManagerStore) readVersion(id SecretId, version int) (Secret, err VersionId: h.VersionId, } - resp, err := s.svc.GetSecretValue(getSecretValueInput) + resp, err := s.svc.GetSecretValue(context.TODO(), getSecretValueInput) if err != nil { return Secret{}, err @@ -336,7 +329,7 @@ func (s *SecretsManagerStore) readLatest(service string) (secretValueObject, err SecretId: aws.String(service), } - resp, err := s.svc.GetSecretValue(getSecretValueInput) + resp, err := s.svc.GetSecretValue(context.TODO(), getSecretValueInput) if err != nil { return secretValueObject{}, err @@ -435,7 +428,7 @@ func (s *SecretsManagerStore) History(id SecretId) ([]ChangeEvent, error) { IncludeDeprecated: aws.Bool(false), } - resp, err := s.svc.ListSecretVersionIds(listSecretVersionIdsInput) + resp, err := s.svc.ListSecretVersionIds(context.TODO(), listSecretVersionIdsInput) if err != nil { return events, err } @@ -452,7 +445,7 @@ func (s *SecretsManagerStore) History(id SecretId) ([]ChangeEvent, error) { VersionId: h.VersionId, } - resp, err := s.svc.GetSecretValue(getSecretValueInput) + resp, err := s.svc.GetSecretValue(context.TODO(), getSecretValueInput) if err != nil { return events, err @@ -506,7 +499,7 @@ func (s *SecretsManagerStore) History(id SecretId) ([]ChangeEvent, error) { } func (s *SecretsManagerStore) getCurrentUser() (string, error) { - resp, err := s.stsSvc.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + resp, err := s.stsSvc.GetCallerIdentity(context.TODO(), &sts.GetCallerIdentityInput{}) if err != nil { return "", err } diff --git a/store/secretsmanagerstore_test.go b/store/secretsmanagerstore_test.go index 53777821..2d3d32ef 100644 --- a/store/secretsmanagerstore_test.go +++ b/store/secretsmanagerstore_test.go @@ -1,6 +1,7 @@ package store import ( + "context" "crypto/rand" "encoding/json" "fmt" @@ -9,29 +10,20 @@ import ( "sort" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/service/secretsmanager" - "github.com/aws/aws-sdk-go/service/secretsmanager/secretsmanageriface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager" + "github.com/aws/aws-sdk-go-v2/service/secretsmanager/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/stretchr/testify/assert" ) -type mockSecretsManagerClient struct { - secretsmanageriface.SecretsManagerAPI - secrets map[string]mockSecret - outputs map[string]secretsmanager.DescribeSecretOutput -} - type mockSecret struct { currentSecret *secretValueObject history map[string]*secretValueObject } -func (m *mockSecretsManagerClient) PutSecretValue(i *secretsmanager.PutSecretValueInput) (*secretsmanager.PutSecretValueOutput, error) { - current, ok := m.secrets[*i.SecretId] +func mockPutSecretValue(i *secretsmanager.PutSecretValueInput, secrets map[string]mockSecret) (*secretsmanager.PutSecretValueOutput, error) { + current, ok := secrets[*i.SecretId] if !ok { return &secretsmanager.PutSecretValueOutput{}, ErrSecretNotFound } @@ -44,12 +36,12 @@ func (m *mockSecretsManagerClient) PutSecretValue(i *secretsmanager.PutSecretVal current.currentSecret = &secret current.history[uniqueID()] = &secret - m.secrets[*i.SecretId] = current + secrets[*i.SecretId] = current return &secretsmanager.PutSecretValueOutput{}, nil } -func (m *mockSecretsManagerClient) CreateSecret(i *secretsmanager.CreateSecretInput) (*secretsmanager.CreateSecretOutput, error) { +func mockCreateSecret(i *secretsmanager.CreateSecretInput, secrets map[string]mockSecret) (*secretsmanager.CreateSecretOutput, error) { secret, err := jsonToSecretValueObject(*i.SecretString) if err != nil { return &secretsmanager.CreateSecretOutput{}, err @@ -61,24 +53,30 @@ func (m *mockSecretsManagerClient) CreateSecret(i *secretsmanager.CreateSecretIn } current.history[uniqueID()] = &secret - m.secrets[*i.Name] = current + secrets[*i.Name] = current return &secretsmanager.CreateSecretOutput{}, nil } -func (m *mockSecretsManagerClient) GetSecretValue(i *secretsmanager.GetSecretValueInput) (*secretsmanager.GetSecretValueOutput, error) { +func mockGetSecretValue(i *secretsmanager.GetSecretValueInput, secrets map[string]mockSecret) (*secretsmanager.GetSecretValueOutput, error) { var version *secretValueObject if i.VersionId != nil { - historyItem, ok := m.secrets[*i.SecretId].history[*i.VersionId] + historyItem, ok := secrets[*i.SecretId].history[*i.VersionId] if !ok { - return &secretsmanager.GetSecretValueOutput{}, awserr.New(secretsmanager.ErrCodeResourceNotFoundException, secretsmanager.ErrCodeResourceNotFoundException, ErrSecretNotFound) + return &secretsmanager.GetSecretValueOutput{}, + &types.ResourceNotFoundException{ + Message: aws.String("ResourceNotFoundException"), + } } version = historyItem } else { - current, ok := m.secrets[*i.SecretId] + current, ok := secrets[*i.SecretId] if !ok { - return &secretsmanager.GetSecretValueOutput{}, awserr.New(secretsmanager.ErrCodeResourceNotFoundException, secretsmanager.ErrCodeResourceNotFoundException, ErrSecretNotFound) + return &secretsmanager.GetSecretValueOutput{}, + &types.ResourceNotFoundException{ + Message: aws.String("ResourceNotFoundException"), + } } version = current.currentSecret } @@ -93,43 +91,58 @@ func (m *mockSecretsManagerClient) GetSecretValue(i *secretsmanager.GetSecretVal }, nil } -func (m *mockSecretsManagerClient) ListSecretVersionIds(i *secretsmanager.ListSecretVersionIdsInput) (*secretsmanager.ListSecretVersionIdsOutput, error) { - service, ok := m.secrets[*i.SecretId] +func mockListSecretVersionIds(i *secretsmanager.ListSecretVersionIdsInput, secrets map[string]mockSecret) (*secretsmanager.ListSecretVersionIdsOutput, error) { + service, ok := secrets[*i.SecretId] if !ok || len(service.history) == 0 { return &secretsmanager.ListSecretVersionIdsOutput{}, ErrSecretNotFound } - Versions := make([]*secretsmanager.SecretVersionsListEntry, 0) + versions := make([]types.SecretVersionsListEntry, 0) for v := range service.history { - Versions = append(Versions, &secretsmanager.SecretVersionsListEntry{VersionId: aws.String(v)}) + versions = append(versions, types.SecretVersionsListEntry{VersionId: aws.String(v)}) } - return &secretsmanager.ListSecretVersionIdsOutput{Versions: Versions}, nil + return &secretsmanager.ListSecretVersionIdsOutput{Versions: versions}, nil } -func (m *mockSecretsManagerClient) DescribeSecret(i *secretsmanager.DescribeSecretInput) (*secretsmanager.DescribeSecretOutput, error) { - output, ok := m.outputs[*i.SecretId] +func mockDescribeSecret(i *secretsmanager.DescribeSecretInput, outputs map[string]secretsmanager.DescribeSecretOutput) (*secretsmanager.DescribeSecretOutput, error) { + output, ok := outputs[*i.SecretId] if !ok { return &secretsmanager.DescribeSecretOutput{RotationEnabled: aws.Bool(false)}, nil } return &output, nil } -type mockSTSClient struct { - stsiface.STSAPI -} - -func (s *mockSTSClient) GetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { +func mockGetCallerIdentity(_ *sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { return &sts.GetCallerIdentityOutput{ Arn: aws.String("currentuser"), }, nil } -func NewTestSecretsManagerStore(mock secretsmanageriface.SecretsManagerAPI) *SecretsManagerStore { - stsSvc := &mockSTSClient{} +func NewTestSecretsManagerStore(secrets map[string]mockSecret, outputs map[string]secretsmanager.DescribeSecretOutput) *SecretsManagerStore { return &SecretsManagerStore{ - svc: mock, - stsSvc: stsSvc, + svc: &apiSecretsManagerMock{ + CreateSecretFunc: func(ctx context.Context, params *secretsmanager.CreateSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.CreateSecretOutput, error) { + return mockCreateSecret(params, secrets) + }, + DescribeSecretFunc: func(ctx context.Context, params *secretsmanager.DescribeSecretInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.DescribeSecretOutput, error) { + return mockDescribeSecret(params, outputs) + }, + GetSecretValueFunc: func(ctx context.Context, params *secretsmanager.GetSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.GetSecretValueOutput, error) { + return mockGetSecretValue(params, secrets) + }, + ListSecretVersionIdsFunc: func(ctx context.Context, params *secretsmanager.ListSecretVersionIdsInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.ListSecretVersionIdsOutput, error) { + return mockListSecretVersionIds(params, secrets) + }, + PutSecretValueFunc: func(ctx context.Context, params *secretsmanager.PutSecretValueInput, optFns ...func(*secretsmanager.Options)) (*secretsmanager.PutSecretValueOutput, error) { + return mockPutSecretValue(params, secrets) + }, + }, + stsSvc: &apiSTSMock{ + GetCallerIdentityFunc: func(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { + return mockGetCallerIdentity(params) + }, + }, } } @@ -161,63 +174,62 @@ func TestSecretValueObjectUnmarshalling(t *testing.T) { func TestNewSecretsManagerStore(t *testing.T) { t.Run("Using region override should take precedence over other settings", func(t *testing.T) { os.Setenv("CHAMBER_AWS_REGION", "us-east-1") + defer os.Unsetenv("CHAMBER_AWS_REGION") os.Setenv("AWS_REGION", "us-west-1") + defer os.Unsetenv("AWS_REGION") os.Setenv("AWS_DEFAULT_REGION", "us-west-2") + defer os.Unsetenv("AWS_DEFAULT_REGION") s, err := NewSecretsManagerStore(1) assert.Nil(t, err) - assert.Equal(t, "us-east-1", aws.StringValue(s.svc.(*secretsmanager.SecretsManager).Config.Region)) - os.Unsetenv("CHAMBER_AWS_REGION") - os.Unsetenv("AWS_REGION") - os.Unsetenv("AWS_DEFAULT_REGION") + assert.Equal(t, "us-east-1", s.config.Region) }) t.Run("Should use AWS_REGION if it is set", func(t *testing.T) { os.Setenv("AWS_REGION", "us-west-1") + defer os.Unsetenv("AWS_REGION") s, err := NewSecretsManagerStore(1) assert.Nil(t, err) - assert.Equal(t, "us-west-1", aws.StringValue(s.svc.(*secretsmanager.SecretsManager).Config.Region)) - - os.Unsetenv("AWS_REGION") + assert.Equal(t, "us-west-1", s.config.Region) }) t.Run("Should use CHAMBER_AWS_SSM_ENDPOINT if set", func(t *testing.T) { os.Setenv("CHAMBER_AWS_SSM_ENDPOINT", "mycustomendpoint") + defer os.Unsetenv("CHAMBER_AWS_SSM_ENDPOINT") s, err := NewSecretsManagerStore(1) assert.Nil(t, err) - endpoint, err := s.svc.(*secretsmanager.SecretsManager).Config.EndpointResolver.EndpointFor(endpoints.SecretsmanagerServiceID, endpoints.UsWest2RegionID) + endpoint, err := s.config.EndpointResolverWithOptions.ResolveEndpoint(secretsmanager.ServiceID, "us-west-2") assert.Nil(t, err) assert.Equal(t, "mycustomendpoint", endpoint.URL) - - os.Unsetenv("CHAMBER_AWS_SSM_ENDPOINT") }) t.Run("Should use default AWS SSM endpoint if CHAMBER_AWS_SSM_ENDPOINT not set", func(t *testing.T) { s, err := NewSecretsManagerStore(1) assert.Nil(t, err) - endpoint, err := s.svc.(*secretsmanager.SecretsManager).Config.EndpointResolver.EndpointFor(endpoints.SecretsmanagerServiceID, endpoints.UsWest2RegionID) - assert.Nil(t, err) - assert.Equal(t, "https://secretsmanager.us-west-2.amazonaws.com", endpoint.URL) + _, err = s.config.EndpointResolverWithOptions.ResolveEndpoint(secretsmanager.ServiceID, "us-west-2") + var notFoundError *aws.EndpointNotFoundError + assert.ErrorAs(t, err, ¬FoundError) }) } func TestSecretsManagerWrite(t *testing.T) { - mock := &mockSecretsManagerClient{secrets: map[string]mockSecret{}, outputs: map[string]secretsmanager.DescribeSecretOutput{}} - store := NewTestSecretsManagerStore(mock) + secrets := make(map[string]mockSecret) + outputs := make(map[string]secretsmanager.DescribeSecretOutput) + store := NewTestSecretsManagerStore(secrets, outputs) t.Run("Setting a new key should work", func(t *testing.T) { key := "mykey" secretId := SecretId{Service: "test", Key: key} err := store.Write(secretId, "value") assert.Nil(t, err) - assert.Contains(t, mock.secrets, secretId.Service) - assert.Equal(t, "value", (*mock.secrets[secretId.Service].currentSecret)[key]) - keyMetadata, err := getHydratedKeyMetadata(mock.secrets[secretId.Service].currentSecret, &key) + assert.Contains(t, secrets, secretId.Service) + assert.Equal(t, "value", (*secrets[secretId.Service].currentSecret)[key]) + keyMetadata, err := getHydratedKeyMetadata(secrets[secretId.Service].currentSecret, &key) assert.Nil(t, err) assert.Equal(t, 1, keyMetadata.Version) - assert.Equal(t, 1, len(mock.secrets[secretId.Service].history)) + assert.Equal(t, 1, len(secrets[secretId.Service].history)) }) t.Run("Setting a key twice should create a new version", func(t *testing.T) { @@ -225,27 +237,27 @@ func TestSecretsManagerWrite(t *testing.T) { secretId := SecretId{Service: "test", Key: key} err := store.Write(secretId, "value") assert.Nil(t, err) - assert.Contains(t, mock.secrets, secretId.Service) - assert.Equal(t, "value", (*mock.secrets[secretId.Service].currentSecret)[key]) - keyMetadata, err := getHydratedKeyMetadata(mock.secrets[secretId.Service].currentSecret, &key) + assert.Contains(t, secrets, secretId.Service) + assert.Equal(t, "value", (*secrets[secretId.Service].currentSecret)[key]) + keyMetadata, err := getHydratedKeyMetadata(secrets[secretId.Service].currentSecret, &key) assert.Nil(t, err) assert.Equal(t, 1, keyMetadata.Version) - assert.Equal(t, 2, len(mock.secrets[secretId.Service].history)) + assert.Equal(t, 2, len(secrets[secretId.Service].history)) err = store.Write(secretId, "newvalue") assert.Nil(t, err) - assert.Contains(t, mock.secrets, secretId.Service) - assert.Equal(t, "newvalue", (*mock.secrets[secretId.Service].currentSecret)[key]) - keyMetadata, err = getHydratedKeyMetadata(mock.secrets[secretId.Service].currentSecret, &key) + assert.Contains(t, secrets, secretId.Service) + assert.Equal(t, "newvalue", (*secrets[secretId.Service].currentSecret)[key]) + keyMetadata, err = getHydratedKeyMetadata(secrets[secretId.Service].currentSecret, &key) assert.Nil(t, err) assert.Equal(t, 2, keyMetadata.Version) - assert.Equal(t, 3, len(mock.secrets[secretId.Service].history)) + assert.Equal(t, 3, len(secrets[secretId.Service].history)) }) t.Run("Setting a key on a secret with rotation enabled should fail", func(t *testing.T) { service := "rotationtest" - mock.secrets[service] = mockSecret{} - mock.outputs[service] = secretsmanager.DescribeSecretOutput{RotationEnabled: aws.Bool(true)} + secrets[service] = mockSecret{} + outputs[service] = secretsmanager.DescribeSecretOutput{RotationEnabled: aws.Bool(true)} secretId := SecretId{Service: service, Key: "doesnotmatter"} err := store.Write(secretId, "value") assert.EqualError(t, err, "Cannot write to a secret with rotation enabled") @@ -253,8 +265,9 @@ func TestSecretsManagerWrite(t *testing.T) { } func TestSecretsManagerRead(t *testing.T) { - mock := &mockSecretsManagerClient{secrets: map[string]mockSecret{}} - store := NewTestSecretsManagerStore(mock) + secrets := make(map[string]mockSecret) + outputs := make(map[string]secretsmanager.DescribeSecretOutput) + store := NewTestSecretsManagerStore(secrets, outputs) secretId := SecretId{Service: "test", Key: "key"} store.Write(secretId, "value") store.Write(secretId, "second value") @@ -292,15 +305,16 @@ func TestSecretsManagerRead(t *testing.T) { } func TestSecretsManagerList(t *testing.T) { - mock := &mockSecretsManagerClient{secrets: map[string]mockSecret{}} - store := NewTestSecretsManagerStore(mock) + secrets := make(map[string]mockSecret) + outputs := make(map[string]secretsmanager.DescribeSecretOutput) + store := NewTestSecretsManagerStore(secrets, outputs) - secrets := []SecretId{ + testSecrets := []SecretId{ {Service: "test", Key: "a"}, {Service: "test", Key: "b"}, {Service: "test", Key: "c"}, } - for _, secret := range secrets { + for _, secret := range testSecrets { store.Write(secret, "value") } @@ -342,15 +356,16 @@ func TestSecretsManagerList(t *testing.T) { } func TestSecretsManagerListRaw(t *testing.T) { - mock := &mockSecretsManagerClient{secrets: map[string]mockSecret{}} - store := NewTestSecretsManagerStore(mock) + secrets := make(map[string]mockSecret) + outputs := make(map[string]secretsmanager.DescribeSecretOutput) + store := NewTestSecretsManagerStore(secrets, outputs) - secrets := []SecretId{ + testSecrets := []SecretId{ {Service: "test", Key: "a"}, {Service: "test", Key: "b"}, {Service: "test", Key: "c"}, } - for _, secret := range secrets { + for _, secret := range testSecrets { store.Write(secret, "value") } @@ -383,17 +398,18 @@ func TestSecretsManagerListRaw(t *testing.T) { } func TestSecretsManagerHistory(t *testing.T) { - mock := &mockSecretsManagerClient{secrets: map[string]mockSecret{}} - store := NewTestSecretsManagerStore(mock) + secrets := make(map[string]mockSecret) + outputs := make(map[string]secretsmanager.DescribeSecretOutput) + store := NewTestSecretsManagerStore(secrets, outputs) - secrets := []SecretId{ + testSecrets := []SecretId{ {Service: "test", Key: "new"}, {Service: "test", Key: "update"}, {Service: "test", Key: "update"}, {Service: "test", Key: "update"}, } - for _, s := range secrets { + for _, s := range testSecrets { store.Write(s, "value") } @@ -420,8 +436,9 @@ func TestSecretsManagerHistory(t *testing.T) { } func TestSecretsManagerDelete(t *testing.T) { - mock := &mockSecretsManagerClient{secrets: map[string]mockSecret{}} - store := NewTestSecretsManagerStore(mock) + secrets := make(map[string]mockSecret) + outputs := make(map[string]secretsmanager.DescribeSecretOutput) + store := NewTestSecretsManagerStore(secrets, outputs) secretId := SecretId{Service: "test", Key: "key"} store.Write(secretId, "value") diff --git a/store/shared.go b/store/shared.go index 39ce1651..2f6be6a8 100644 --- a/store/shared.go +++ b/store/shared.go @@ -1,12 +1,12 @@ package store import ( + "context" "os" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" ) const ( @@ -14,47 +14,45 @@ const ( CustomSSMEndpointEnvVar = "CHAMBER_AWS_SSM_ENDPOINT" ) -func getSession(numRetries int) (*session.Session, *string, error) { - var region *string - - endpointResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) { +func getConfig(numRetries int) (aws.Config, string, error) { + endpointResolver := func(service, region string, options ...interface{}) (aws.Endpoint, error) { customSsmEndpoint, ok := os.LookupEnv(CustomSSMEndpointEnvVar) if ok { - return endpoints.ResolvedEndpoint{ + return aws.Endpoint{ URL: customSsmEndpoint, + Source: aws.EndpointSourceCustom, }, nil } - return endpoints.DefaultResolver().EndpointFor(service, region, optFns...) + return aws.Endpoint{}, &aws.EndpointNotFoundError{} } + var region string if regionOverride, ok := os.LookupEnv(RegionEnvVar); ok { - region = aws.String(regionOverride) + region = regionOverride } - retSession, err := session.NewSessionWithOptions( - session.Options{ - Config: aws.Config{ - Region: region, - MaxRetries: aws.Int(numRetries), - EndpointResolver: endpoints.ResolverFunc(endpointResolver), - }, - SharedConfigState: session.SharedConfigEnable, - }, + + cfg, err := config.LoadDefaultConfig(context.TODO(), + config.WithRegion(region), + config.WithRetryMaxAttempts(numRetries), + config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(endpointResolver)), ) if err != nil { - return nil, nil, err + return aws.Config{}, "", err } // If region is still not set, attempt to determine it via ec2 metadata API - if aws.StringValue(retSession.Config.Region) == "" { - session := session.New() - ec2metadataSvc := ec2metadata.New(session) - if regionOverride, err := ec2metadataSvc.Region(); err == nil { - region = aws.String(regionOverride) + if cfg.Region == "" { + imdsConfig, err := config.LoadDefaultConfig(context.TODO()) + if err != nil { + ec2metadataSvc := imds.NewFromConfig(imdsConfig) + if regionOverride, err := ec2metadataSvc.GetRegion(context.TODO(), &imds.GetRegionInput{}); err == nil { + region = regionOverride.Region + } } } - return retSession, region, nil + return cfg, region, err } func uniqueStringSlice(slice []string) []string { diff --git a/store/ssmstore.go b/store/ssmstore.go index 2cf0cb7d..1c007f82 100644 --- a/store/ssmstore.go +++ b/store/ssmstore.go @@ -1,6 +1,7 @@ package store import ( + "context" "fmt" "os" "regexp" @@ -8,11 +9,9 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" ) const ( @@ -20,7 +19,8 @@ const ( DefaultKeyID = "alias/parameter_store_key" // DefaultMinThrottleDelay is the default delay before retrying throttled requests - DefaultMinThrottleDelay = client.DefaultRetryerMinThrottleDelay + // DefaultMinThrottleDelay = client.DefaultRetryerMinThrottleDelay + DefaultMinThrottleDelay = 0 ) // validPathKeyFormat is the format that is expected for key names inside parameter store @@ -40,7 +40,8 @@ var labelMatchRegex = regexp.MustCompile(`^(\/[\w\-\.]+)+:(.+)$`) // SSMStore implements the Store interface for storing secrets in SSM Parameter // Store type SSMStore struct { - svc ssmiface.SSMAPI + svc apiSSM + config aws.Config usePaths bool } @@ -55,13 +56,18 @@ func NewSSMStoreWithMinThrottleDelay(numRetries int, minThrottleDelay time.Durat } func ssmStoreUsingRetryer(numRetries int, minThrottleDelay time.Duration) (*SSMStore, error) { - ssmSession, region, err := getSession(numRetries) + cfg, _, err := getConfig(numRetries) if err != nil { return nil, err } - retryer := client.DefaultRetryer{NumMaxRetries: numRetries, MinThrottleDelay: minThrottleDelay} + // FIXME minThrottleDelay is ignored + // retryer := retry.NewStandard( + // func(o *retry.StandardOptions) { + // o.MaxAttempts = numRetries + // }, + // ) usePaths := true _, ok := os.LookupEnv("CHAMBER_NO_PATHS") @@ -69,13 +75,11 @@ func ssmStoreUsingRetryer(numRetries int, minThrottleDelay time.Duration) (*SSMS usePaths = false } - svc := ssm.New(ssmSession, &aws.Config{ - Retryer: retryer, - Region: region, - }) + svc := ssm.NewFromConfig(cfg) return &SSMStore{ svc: svc, + config: cfg, usePaths: usePaths, }, nil } @@ -108,14 +112,14 @@ func (s *SSMStore) Write(id SecretId, value string) error { putParameterInput := &ssm.PutParameterInput{ KeyId: aws.String(s.KMSKey()), Name: aws.String(s.idToName(id)), - Type: aws.String("SecureString"), + Type: types.ParameterTypeSecureString, Value: aws.String(value), Overwrite: aws.Bool(true), Description: aws.String(strconv.Itoa(version)), } // This API call returns an empty struct - _, err = s.svc.PutParameter(putParameterInput) + _, err = s.svc.PutParameter(context.TODO(), putParameterInput) if err != nil { return err } @@ -146,7 +150,7 @@ func (s *SSMStore) Delete(id SecretId) error { Name: aws.String(s.idToName(id)), } - _, err = s.svc.DeleteParameter(deleteParameterInput) + _, err = s.svc.DeleteParameter(context.TODO(), deleteParameterInput) if err != nil { return err } @@ -161,7 +165,12 @@ func (s *SSMStore) readVersion(id SecretId, version int) (Secret, error) { } var result Secret - if err := s.svc.GetParameterHistoryPages(getParameterHistoryInput, func(o *ssm.GetParameterHistoryOutput, lastPage bool) bool { + paginator := ssm.NewGetParameterHistoryPaginator(s.svc, getParameterHistoryInput) + for paginator.HasMorePages() { + o, err := paginator.NextPage(context.TODO()) + if err != nil { + return Secret{}, ErrSecretNotFound + } for _, history := range o.Parameters { thisVersion := 0 if history.Description != nil { @@ -177,12 +186,9 @@ func (s *SSMStore) readVersion(id SecretId, version int) (Secret, error) { Key: *history.Name, }, } - return false + break } } - return true - }); err != nil { - return Secret{}, ErrSecretNotFound } if result.Value != nil { return result, nil @@ -193,11 +199,11 @@ func (s *SSMStore) readVersion(id SecretId, version int) (Secret, error) { func (s *SSMStore) readLatest(id SecretId) (Secret, error) { getParametersInput := &ssm.GetParametersInput{ - Names: []*string{aws.String(s.idToName(id))}, + Names: []string{s.idToName(id)}, WithDecryption: aws.Bool(true), } - resp, err := s.svc.GetParameters(getParametersInput) + resp, err := s.svc.GetParameters(context.TODO(), getParametersInput) if err != nil { return Secret{}, err } @@ -206,7 +212,7 @@ func (s *SSMStore) readLatest(id SecretId) (Secret, error) { return Secret{}, ErrSecretNotFound } param := resp.Parameters[0] - var parameter *ssm.ParameterMetadata + var parameter *types.ParameterMetadata var describeParametersInput *ssm.DescribeParametersInput // To get metadata, we need to use describe parameters @@ -216,42 +222,44 @@ func (s *SSMStore) readLatest(id SecretId) (Secret, error) { // if that key uses paths, so instead get all the keys for a path, // then find the one you are looking for :( describeParametersInput = &ssm.DescribeParametersInput{ - ParameterFilters: []*ssm.ParameterStringFilter{ + ParameterFilters: []types.ParameterStringFilter{ { Key: aws.String("Path"), Option: aws.String("OneLevel"), - Values: []*string{aws.String(basePath(s.idToName(id)))}, + Values: []string{basePath(s.idToName(id))}, }, }, } } else { describeParametersInput = &ssm.DescribeParametersInput{ - Filters: []*ssm.ParametersFilter{ + Filters: []types.ParametersFilter{ { - Key: aws.String("Name"), - Values: []*string{aws.String(s.idToName(id))}, + Key: types.ParametersFilterKeyName, + Values: []string{s.idToName(id)}, }, }, - MaxResults: aws.Int64(1), + MaxResults: aws.Int32(1), } } - if err := s.svc.DescribeParametersPages(describeParametersInput, func(o *ssm.DescribeParametersOutput, lastPage bool) bool { + paginator := ssm.NewDescribeParametersPaginator(s.svc, describeParametersInput) + for paginator.HasMorePages() { + o, err := paginator.NextPage(context.TODO()) + if err != nil { + return Secret{}, err + } for _, param := range o.Parameters { if *param.Name == s.idToName(id) { - parameter = param - return false + parameter = ¶m + break } } - return true - }); err != nil { - return Secret{}, err } if parameter == nil { return Secret{}, ErrSecretNotFound } - secretMeta := parameterMetaToSecretMeta(parameter) + secretMeta := parameterMetaToSecretMeta(*parameter) return Secret{ Value: param.Value, @@ -265,28 +273,33 @@ func (s *SSMStore) ListServices(service string, includeSecretName bool) ([]strin if s.usePaths { describeParametersInput = &ssm.DescribeParametersInput{ - MaxResults: aws.Int64(50), - ParameterFilters: []*ssm.ParameterStringFilter{ + MaxResults: aws.Int32(50), + ParameterFilters: []types.ParameterStringFilter{ { Key: aws.String("Name"), Option: aws.String("BeginsWith"), - Values: []*string{aws.String("/" + service)}, + Values: []string{"/" + service}, }, }, } } else { describeParametersInput = &ssm.DescribeParametersInput{ - MaxResults: aws.Int64(50), - Filters: []*ssm.ParametersFilter{ + MaxResults: aws.Int32(50), + Filters: []types.ParametersFilter{ { - Key: aws.String("Name"), - Values: []*string{aws.String(service + ".")}, + Key: types.ParametersFilterKeyName, + Values: []string{service + "."}, }, }, } } - err := s.svc.DescribeParametersPages(describeParametersInput, func(resp *ssm.DescribeParametersOutput, lastPage bool) bool { + paginator := ssm.NewDescribeParametersPaginator(s.svc, describeParametersInput) + for paginator.HasMorePages() { + resp, err := paginator.NextPage(context.TODO()) + if err != nil { + return nil, err + } for _, meta := range resp.Parameters { if !s.validateName(*meta.Name) { continue @@ -297,10 +310,6 @@ func (s *SSMStore) ListServices(service string, includeSecretName bool) ([]strin Meta: secretMeta, } } - return true - }) - if err != nil { - return nil, err } if includeSecretName { @@ -327,26 +336,31 @@ func (s *SSMStore) List(serviceName string, includeValues bool) ([]Secret, error if s.usePaths { describeParametersInput = &ssm.DescribeParametersInput{ - ParameterFilters: []*ssm.ParameterStringFilter{ + ParameterFilters: []types.ParameterStringFilter{ { Key: aws.String("Path"), Option: aws.String("OneLevel"), - Values: []*string{aws.String("/" + service)}, + Values: []string{"/" + service}, }, }, } } else { describeParametersInput = &ssm.DescribeParametersInput{ - Filters: []*ssm.ParametersFilter{ + Filters: []types.ParametersFilter{ { - Key: aws.String("Name"), - Values: []*string{aws.String(service + ".")}, + Key: types.ParametersFilterKeyName, + Values: []string{service + "."}, }, }, } } - err := s.svc.DescribeParametersPages(describeParametersInput, func(resp *ssm.DescribeParametersOutput, lastPage bool) bool { + paginator := ssm.NewDescribeParametersPaginator(s.svc, describeParametersInput) + for paginator.HasMorePages() { + resp, err := paginator.NextPage(context.TODO()) + if err != nil { + return nil, err + } for _, meta := range resp.Parameters { if !s.validateName(*meta.Name) { continue @@ -357,10 +371,6 @@ func (s *SSMStore) List(serviceName string, includeValues bool) ([]Secret, error Meta: secretMeta, } } - return true - }) - if err != nil { - return nil, err } if includeValues { @@ -373,11 +383,11 @@ func (s *SSMStore) List(serviceName string, includeValues bool) ([]Secret, error batch := secretKeys[i:batchEnd] getParametersInput := &ssm.GetParametersInput{ - Names: stringsToAWSStrings(batch), + Names: batch, WithDecryption: aws.Bool(true), } - resp, err := s.svc.GetParameters(getParametersInput) + resp, err := s.svc.GetParameters(context.TODO(), getParametersInput) if err != nil { return nil, err } @@ -405,16 +415,22 @@ func (s *SSMStore) ListRaw(serviceName string) ([]RawSecret, error) { WithDecryption: aws.Bool(true), } if label != "" { - getParametersByPathInput.ParameterFilters = []*ssm.ParameterStringFilter{ + getParametersByPathInput.ParameterFilters = []types.ParameterStringFilter{ { Key: aws.String("Label"), Option: aws.String("Equals"), - Values: []*string{aws.String(label)}, + Values: []string{label}, }, } } - err := s.svc.GetParametersByPathPages(getParametersByPathInput, func(resp *ssm.GetParametersByPathOutput, lastPage bool) bool { + paginator := ssm.NewGetParametersByPathPaginator(s.svc, getParametersByPathInput) + for paginator.HasMorePages() { + resp, err := paginator.NextPage(context.TODO()) + if err != nil { + return nil, err + } + for _, param := range resp.Parameters { if !s.validateName(*param.Name) { continue @@ -425,28 +441,6 @@ func (s *SSMStore) ListRaw(serviceName string) ([]RawSecret, error) { Key: *param.Name, } } - return true - }) - - if err != nil { - // If the error is an access-denied exception - awsErr, isAwserr := err.(awserr.Error) - if isAwserr { - if awsErr.Code() == "AccessDeniedException" && strings.Contains(awsErr.Message(), "is not authorized to perform: ssm:GetParametersByPath on resource") { - // Fall-back to using the old list method in case some users haven't updated their IAM permissions yet, but warn about it and - // tell them to fix their permissions - fmt.Fprintf( - os.Stderr, - "Warning: %s\nFalling-back to using ssm:DescribeParameters. This may cause delays or failures due to AWS rate-limiting.\n"+ - "This is behavior deprecated and will be removed in a future version of chamber. Please update your IAM permissions to grant ssm:GetParametersByPath.\n\n", - awsErr) - - // Delegate to List - return s.listRawViaList(service) - } - } - - return nil, err } rawSecrets := make([]RawSecret, len(secrets)) @@ -458,7 +452,7 @@ func (s *SSMStore) ListRaw(serviceName string) ([]RawSecret, error) { return rawSecrets, nil } - // Delete to List (which uses the DescribeParameters API) + // Delegate to List (which uses the DescribeParameters API) return s.listRawViaList(service) } @@ -472,7 +466,13 @@ func (s *SSMStore) History(id SecretId) ([]ChangeEvent, error) { WithDecryption: aws.Bool(false), } - if err := s.svc.GetParameterHistoryPages(getParameterHistoryInput, func(o *ssm.GetParameterHistoryOutput, lastPage bool) bool { + paginator := ssm.NewGetParameterHistoryPaginator(s.svc, getParameterHistoryInput) + for paginator.HasMorePages() { + o, err := paginator.NextPage(context.TODO()) + if err != nil { + return events, ErrSecretNotFound + } + for _, history := range o.Parameters { // Disregard error here, if Atoi fails (secret created outside of // Chamber), then we use version 0 @@ -487,9 +487,6 @@ func (s *SSMStore) History(id SecretId) ([]ChangeEvent, error) { Version: version, }) } - return true - }); err != nil { - return events, ErrSecretNotFound } return events, nil @@ -550,7 +547,7 @@ func serviceName(key string) string { return strings.Join(pathParts[1:end], "/") } -func parameterMetaToSecretMeta(p *ssm.ParameterMetadata) SecretMetadata { +func parameterMetaToSecretMeta(p types.ParameterMetadata) SecretMetadata { version := 0 if p.Description != nil { version, _ = strconv.Atoi(*p.Description) @@ -579,14 +576,6 @@ func values(m map[string]Secret) []Secret { return values } -func stringsToAWSStrings(slice []string) []*string { - ret := []*string{} - for _, s := range slice { - ret = append(ret, aws.String(s)) - } - return ret -} - func getChangeType(version int) ChangeEventType { if version == 1 { return Created diff --git a/store/ssmstore_test.go b/store/ssmstore_test.go index 982a350d..43221df6 100644 --- a/store/ssmstore_test.go +++ b/store/ssmstore_test.go @@ -1,6 +1,7 @@ package store import ( + "context" "errors" "os" "sort" @@ -8,39 +9,32 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/client" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/service/ssm" - "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ssm" + "github.com/aws/aws-sdk-go-v2/service/ssm/types" "github.com/stretchr/testify/assert" ) -type mockSSMClient struct { - ssmiface.SSMAPI - parameters map[string]mockParameter -} - type mockParameter struct { - currentParam *ssm.Parameter - history []*ssm.ParameterHistory - meta *ssm.ParameterMetadata + currentParam *types.Parameter + history []types.ParameterHistory + meta *types.ParameterMetadata } -func (m *mockSSMClient) PutParameter(i *ssm.PutParameterInput) (*ssm.PutParameterOutput, error) { - current, ok := m.parameters[*i.Name] +func mockPutParameter(i *ssm.PutParameterInput, parameters map[string]mockParameter) (*ssm.PutParameterOutput, error) { + current, ok := parameters[*i.Name] if !ok { current = mockParameter{ - history: []*ssm.ParameterHistory{}, + history: []types.ParameterHistory{}, } } - current.currentParam = &ssm.Parameter{ + current.currentParam = &types.Parameter{ Name: i.Name, Type: i.Type, Value: i.Value, } - current.meta = &ssm.ParameterMetadata{ + current.meta = &types.ParameterMetadata{ Description: i.Description, KeyId: i.KeyId, LastModifiedDate: aws.Time(time.Now()), @@ -48,7 +42,7 @@ func (m *mockSSMClient) PutParameter(i *ssm.PutParameterInput) (*ssm.PutParamete Name: i.Name, Type: i.Type, } - history := &ssm.ParameterHistory{ + history := types.ParameterHistory{ Description: current.meta.Description, KeyId: current.meta.KeyId, LastModifiedDate: current.meta.LastModifiedDate, @@ -59,42 +53,42 @@ func (m *mockSSMClient) PutParameter(i *ssm.PutParameterInput) (*ssm.PutParamete } current.history = append(current.history, history) - m.parameters[*i.Name] = current + parameters[*i.Name] = current return &ssm.PutParameterOutput{}, nil } -func (m *mockSSMClient) GetParameters(i *ssm.GetParametersInput) (*ssm.GetParametersOutput, error) { - parameters := []*ssm.Parameter{} +func mockGetParameters(i *ssm.GetParametersInput, parameters map[string]mockParameter) (*ssm.GetParametersOutput, error) { + returnParameters := []types.Parameter{} - for _, param := range m.parameters { + for _, param := range parameters { if paramNameInSlice(param.meta.Name, i.Names) { if *i.WithDecryption == false { - parameters = append(parameters, &ssm.Parameter{ + returnParameters = append(returnParameters, types.Parameter{ Name: param.meta.Name, Value: nil, }) } else { - parameters = append(parameters, param.currentParam) + returnParameters = append(returnParameters, *param.currentParam) } } } if len(parameters) == 0 { return &ssm.GetParametersOutput{ - Parameters: parameters, + Parameters: returnParameters, }, ErrSecretNotFound } return &ssm.GetParametersOutput{ - Parameters: parameters, + Parameters: returnParameters, }, nil } -func (m *mockSSMClient) GetParameterHistory(i *ssm.GetParameterHistoryInput) (*ssm.GetParameterHistoryOutput, error) { - history := []*ssm.ParameterHistory{} +func mockGetParameterHistory(i *ssm.GetParameterHistoryInput, parameters map[string]mockParameter) (*ssm.GetParameterHistoryOutput, error) { + history := []types.ParameterHistory{} - param, ok := m.parameters[*i.Name] + param, ok := parameters[*i.Name] if !ok { return &ssm.GetParameterHistoryOutput{ NextToken: nil, @@ -110,7 +104,7 @@ func (m *mockSSMClient) GetParameterHistory(i *ssm.GetParameterHistoryInput) (*s } for _, hist := range param.history { - history = append(history, &ssm.ParameterHistory{ + history = append(history, types.ParameterHistory{ Description: hist.Description, KeyId: hist.KeyId, LastModifiedDate: hist.LastModifiedDate, @@ -126,10 +120,10 @@ func (m *mockSSMClient) GetParameterHistory(i *ssm.GetParameterHistoryInput) (*s }, nil } -func (m *mockSSMClient) DescribeParameters(i *ssm.DescribeParametersInput) (*ssm.DescribeParametersOutput, error) { - parameters := []*ssm.ParameterMetadata{} +func mockDescribeParameters(i *ssm.DescribeParametersInput, parameters map[string]mockParameter) (*ssm.DescribeParametersOutput, error) { + returnMetadata := []types.ParameterMetadata{} - for _, param := range m.parameters { + for _, param := range parameters { match, err := matchFilters(i.Filters, param) if err != nil { return &ssm.DescribeParametersOutput{}, err @@ -140,20 +134,20 @@ func (m *mockSSMClient) DescribeParameters(i *ssm.DescribeParametersInput) (*ssm } if match && matchStringFilters { - parameters = append(parameters, param.meta) + returnMetadata = append(returnMetadata, *param.meta) } } return &ssm.DescribeParametersOutput{ - Parameters: parameters, + Parameters: returnMetadata, NextToken: nil, }, nil } -func (m *mockSSMClient) GetParametersByPath(i *ssm.GetParametersByPathInput) (*ssm.GetParametersByPathOutput, error) { - parameters := []*ssm.Parameter{} +func mockGetParametersByPath(i *ssm.GetParametersByPathInput, parameters map[string]mockParameter) (*ssm.GetParametersByPathOutput, error) { + returnParameters := []types.Parameter{} - for _, param := range m.parameters { + for _, param := range parameters { // Match ParameterFilters doesMatchStringFilters, err := matchStringFilters(i.ParameterFilters, param) if err != nil { @@ -163,107 +157,81 @@ func (m *mockSSMClient) GetParametersByPath(i *ssm.GetParametersByPathInput) (*s doesMatchPathFilter := *i.Path == "/" || strings.HasPrefix(*param.meta.Name, *i.Path) if doesMatchStringFilters && doesMatchPathFilter { - parameters = append(parameters, param.currentParam) + returnParameters = append(returnParameters, *param.currentParam) } } return &ssm.GetParametersByPathOutput{ - Parameters: parameters, + Parameters: returnParameters, NextToken: nil, }, nil } -func (m *mockSSMClient) GetParametersByPathPages(i *ssm.GetParametersByPathInput, fn func(*ssm.GetParametersByPathOutput, bool) bool) error { - o, err := m.GetParametersByPath(i) - if err != nil { - return err - } - fn(o, true) - return nil -} - -func (m *mockSSMClient) DescribeParametersPages(i *ssm.DescribeParametersInput, fn func(*ssm.DescribeParametersOutput, bool) bool) error { - o, err := m.DescribeParameters(i) - if err != nil { - return err - } - fn(o, true) - return nil -} - -func (m *mockSSMClient) GetParameterHistoryPages(i *ssm.GetParameterHistoryInput, fn func(*ssm.GetParameterHistoryOutput, bool) bool) error { - o, err := m.GetParameterHistory(i) - if err != nil { - return err - } - fn(o, true) - return nil -} - -func (m *mockSSMClient) DeleteParameter(i *ssm.DeleteParameterInput) (*ssm.DeleteParameterOutput, error) { - _, ok := m.parameters[*i.Name] +func mockDeleteParameter(i *ssm.DeleteParameterInput, parameters map[string]mockParameter) (*ssm.DeleteParameterOutput, error) { + _, ok := parameters[*i.Name] if !ok { return &ssm.DeleteParameterOutput{}, errors.New("secret not found") } - delete(m.parameters, *i.Name) + delete(parameters, *i.Name) return &ssm.DeleteParameterOutput{}, nil } -func paramNameInSlice(name *string, slice []*string) bool { +func paramNameInSlice(name *string, slice []string) bool { for _, val := range slice { - if *val == *name { + if val == *name { return true } } return false } -func prefixInSlice(val *string, prefixes []*string) bool { +func anyPrefixInValue(val *string, prefixes []string) bool { for _, prefix := range prefixes { - if strings.HasPrefix(*val, *prefix) { + if strings.HasPrefix(*val, prefix) { return true } } return false } -func pathInSlice(val *string, paths []*string) bool { +func pathInSlice(val *string, paths []string) bool { tokens := strings.Split(*val, "/") if len(tokens) < 2 { return false } matchPath := "/" + tokens[1] for _, path := range paths { - if matchPath == *path { + if matchPath == path { return true } } return false } -func matchFilters(filters []*ssm.ParametersFilter, param mockParameter) (bool, error) { +func matchFilters(filters []types.ParametersFilter, param mockParameter) (bool, error) { for _, filter := range filters { var compareTo *string - switch *filter.Key { + switch filter.Key { case "Name": compareTo = param.meta.Name case "Type": - compareTo = param.meta.Type + typeString := string(param.meta.Type) + compareTo = &typeString case "KeyId": compareTo = param.meta.KeyId default: return false, errors.New("invalid filter key") } - if !prefixInSlice(compareTo, filter.Values) { + if !anyPrefixInValue(compareTo, filter.Values) { return false, nil } } return true, nil } -func matchStringFilters(filters []*ssm.ParameterStringFilter, param mockParameter) (bool, error) { +func matchStringFilters(filters []types.ParameterStringFilter, param mockParameter) (bool, error) { for _, filter := range filters { var compareTo *string switch *filter.Key { @@ -282,7 +250,7 @@ func matchStringFilters(filters []*ssm.ParameterStringFilter, param mockParamete if *filter.Option == "BeginsWith" { result := false for _, value := range filter.Values { - if strings.HasPrefix(*param.meta.Name, *value) { + if strings.HasPrefix(*param.meta.Name, value) { result = true } } @@ -294,85 +262,105 @@ func matchStringFilters(filters []*ssm.ParameterStringFilter, param mockParamete return true, nil } -func NewTestSSMStore(mock ssmiface.SSMAPI) *SSMStore { +func NewTestSSMStore(parameters map[string]mockParameter, usePaths bool) *SSMStore { return &SSMStore{ - svc: mock, + usePaths: usePaths, + svc: &apiSSMMock{ + DeleteParameterFunc: func(ctx context.Context, params *ssm.DeleteParameterInput, optFns ...func(*ssm.Options)) (*ssm.DeleteParameterOutput, error) { + return mockDeleteParameter(params, parameters) + }, + DescribeParametersFunc: func(ctx context.Context, params *ssm.DescribeParametersInput, optFns ...func(*ssm.Options)) (*ssm.DescribeParametersOutput, error) { + return mockDescribeParameters(params, parameters); + }, + GetParameterHistoryFunc: func(ctx context.Context, params *ssm.GetParameterHistoryInput, optFns ...func(*ssm.Options)) (*ssm.GetParameterHistoryOutput, error) { + return mockGetParameterHistory(params, parameters); + }, + GetParametersFunc: func(ctx context.Context, params *ssm.GetParametersInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersOutput, error) { + return mockGetParameters(params, parameters); + }, + GetParametersByPathFunc: func(ctx context.Context, params *ssm.GetParametersByPathInput, optFns ...func(*ssm.Options)) (*ssm.GetParametersByPathOutput, error) { + return mockGetParametersByPath(params, parameters); + }, + PutParameterFunc: func(ctx context.Context, params *ssm.PutParameterInput, optFns ...func(*ssm.Options)) (*ssm.PutParameterOutput, error) { + return mockPutParameter(params, parameters); + }, + }, } } func TestNewSSMStore(t *testing.T) { t.Run("Using region override should take precedence over other settings", func(t *testing.T) { os.Setenv("CHAMBER_AWS_REGION", "us-east-1") + defer os.Unsetenv("CHAMBER_AWS_REGION") os.Setenv("AWS_REGION", "us-west-1") + defer os.Unsetenv("AWS_REGION") os.Setenv("AWS_DEFAULT_REGION", "us-west-2") + defer os.Unsetenv("AWS_DEFAULT_REGION") s, err := NewSSMStore(1) assert.Nil(t, err) - assert.Equal(t, "us-east-1", aws.StringValue(s.svc.(*ssm.SSM).Config.Region)) - os.Unsetenv("CHAMBER_AWS_REGION") - os.Unsetenv("AWS_REGION") - os.Unsetenv("AWS_DEFAULT_REGION") + assert.Equal(t, "us-east-1", s.config.Region) }) t.Run("Should use AWS_REGION if it is set", func(t *testing.T) { os.Setenv("AWS_REGION", "us-west-1") + defer os.Unsetenv("AWS_REGION") s, err := NewSSMStore(1) assert.Nil(t, err) - assert.Equal(t, "us-west-1", aws.StringValue(s.svc.(*ssm.SSM).Config.Region)) - - os.Unsetenv("AWS_REGION") + assert.Equal(t, "us-west-1", s.config.Region) }) t.Run("Should use CHAMBER_AWS_SSM_ENDPOINT if set", func(t *testing.T) { os.Setenv("CHAMBER_AWS_SSM_ENDPOINT", "mycustomendpoint") + defer os.Unsetenv("CHAMBER_AWS_SSM_ENDPOINT") s, err := NewSSMStore(1) assert.Nil(t, err) - endpoint, err := s.svc.(*ssm.SSM).Config.EndpointResolver.EndpointFor(endpoints.SsmServiceID, endpoints.UsWest2RegionID) + endpoint, err := s.config.EndpointResolverWithOptions.ResolveEndpoint(ssm.ServiceID, "us-west-2") assert.Nil(t, err) assert.Equal(t, "mycustomendpoint", endpoint.URL) - - os.Unsetenv("CHAMBER_AWS_SSM_ENDPOINT") }) t.Run("Should use default AWS SSM endpoint if CHAMBER_AWS_SSM_ENDPOINT not set", func(t *testing.T) { s, err := NewSSMStore(1) assert.Nil(t, err) - endpoint, err := s.svc.(*ssm.SSM).Config.EndpointResolver.EndpointFor(endpoints.SsmServiceID, endpoints.UsWest2RegionID) - assert.Nil(t, err) - assert.Equal(t, "https://ssm.us-west-2.amazonaws.com", endpoint.URL) + _, err = s.config.EndpointResolverWithOptions.ResolveEndpoint(ssm.ServiceID, "us-west-2") + var notFoundError *aws.EndpointNotFoundError + assert.ErrorAs(t, err, ¬FoundError) }) - t.Run("Should set aws sdk min throttle delay to default", func(t *testing.T) { - s, err := NewSSMStore(1) - assert.Nil(t, err) - assert.Equal(t, DefaultMinThrottleDelay, s.svc.(*ssm.SSM).Config.Retryer.(client.DefaultRetryer).MinThrottleDelay) - }) + // FIXME minThrottleDelay is ignored + // t.Run("Should set aws sdk min throttle delay to default", func(t *testing.T) { + // s, err := NewSSMStore(1) + // assert.Nil(t, err) + // assert.Equal(t, DefaultMinThrottleDelay, s.svc.(*ssm.SSM).Config.Retryer.(client.DefaultRetryer).MinThrottleDelay) + // }) } -func TestNewSSMStoreMinThrottleDelay(t *testing.T) { - t.Run("Should configure aws sdk retryer - num max retries and min throttle delay", func(t *testing.T) { - s, err := NewSSMStoreWithMinThrottleDelay(2, time.Duration(1000)*time.Millisecond) - assert.Nil(t, err) - assert.Equal(t, 2, s.svc.(*ssm.SSM).Config.Retryer.(client.DefaultRetryer).NumMaxRetries) - assert.Equal(t, time.Duration(1000)*time.Millisecond, s.svc.(*ssm.SSM).Config.Retryer.(client.DefaultRetryer).MinThrottleDelay) - }) -} +// FIXME minThrottleDelay is ignored +// func TestNewSSMStoreMinThrottleDelay(t *testing.T) { +// t.Run("Should configure aws sdk retryer - num max retries and min throttle delay", func(t *testing.T) { +// s, err := NewSSMStoreWithMinThrottleDelay(2, time.Duration(1000)*time.Millisecond) +// assert.Nil(t, err) +// assert.Equal(t, 2, s.config.Retryer().MaxAttempts()) +// assert.Equal(t, time.Duration(1000)*time.Millisecond, s.svc.(*ssm.SSM).Config.Retryer.(client.DefaultRetryer).MinThrottleDelay) +// }) +// } func TestWrite(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStore(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, false) t.Run("Setting a new key should work", func(t *testing.T) { secretId := SecretId{Service: "test", Key: "mykey"} err := store.Write(secretId, "value") assert.Nil(t, err) - assert.Contains(t, mock.parameters, store.idToName(secretId)) - assert.Equal(t, "value", *mock.parameters[store.idToName(secretId)].currentParam.Value) - assert.Equal(t, "1", *mock.parameters[store.idToName(secretId)].meta.Description) - assert.Equal(t, 1, len(mock.parameters[store.idToName(secretId)].history)) + assert.Contains(t, parameters, store.idToName(secretId)) + assert.Equal(t, "value", *parameters[store.idToName(secretId)].currentParam.Value) + assert.Equal(t, "1", *parameters[store.idToName(secretId)].meta.Description) + assert.Equal(t, 1, len(parameters[store.idToName(secretId)].history)) }) t.Run("Setting a key twice should create a new version", func(t *testing.T) { @@ -382,16 +370,16 @@ func TestWrite(t *testing.T) { err = store.Write(secretId, "newvalue") assert.Nil(t, err) - assert.Contains(t, mock.parameters, store.idToName(secretId)) - assert.Equal(t, "newvalue", *mock.parameters[store.idToName(secretId)].currentParam.Value) - assert.Equal(t, "2", *mock.parameters[store.idToName(secretId)].meta.Description) - assert.Equal(t, 2, len(mock.parameters[store.idToName(secretId)].history)) + assert.Contains(t, parameters, store.idToName(secretId)) + assert.Equal(t, "newvalue", *parameters[store.idToName(secretId)].currentParam.Value) + assert.Equal(t, "2", *parameters[store.idToName(secretId)].meta.Description) + assert.Equal(t, 2, len(parameters[store.idToName(secretId)].history)) }) } func TestRead(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStore(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, false) secretId := SecretId{Service: "test", Key: "key"} store.Write(secretId, "value") store.Write(secretId, "second value") @@ -429,8 +417,8 @@ func TestRead(t *testing.T) { } func TestList(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStore(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, false) secrets := []SecretId{ {Service: "test", Key: "a"}, @@ -479,8 +467,8 @@ func TestList(t *testing.T) { } func TestListRaw(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStore(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, false) secrets := []SecretId{ {Service: "test", Key: "a"}, @@ -517,8 +505,8 @@ func TestListRaw(t *testing.T) { } func TestListRawWithPaths(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStoreWithPaths(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, true) secrets := []SecretId{ {Service: "test", Key: "a"}, @@ -555,8 +543,8 @@ func TestListRawWithPaths(t *testing.T) { } func TestHistory(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStore(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, false) secrets := []SecretId{ {Service: "test", Key: "new"}, @@ -591,25 +579,18 @@ func TestHistory(t *testing.T) { }) } -func NewTestSSMStoreWithPaths(mock ssmiface.SSMAPI) *SSMStore { - return &SSMStore{ - svc: mock, - usePaths: true, - } -} - func TestWritePaths(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStoreWithPaths(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, true) t.Run("Setting a new key should work", func(t *testing.T) { secretId := SecretId{Service: "test", Key: "mykey"} err := store.Write(secretId, "value") assert.Nil(t, err) - assert.Contains(t, mock.parameters, store.idToName(secretId)) - assert.Equal(t, "value", *mock.parameters[store.idToName(secretId)].currentParam.Value) - assert.Equal(t, "1", *mock.parameters[store.idToName(secretId)].meta.Description) - assert.Equal(t, 1, len(mock.parameters[store.idToName(secretId)].history)) + assert.Contains(t, parameters, store.idToName(secretId)) + assert.Equal(t, "value", *parameters[store.idToName(secretId)].currentParam.Value) + assert.Equal(t, "1", *parameters[store.idToName(secretId)].meta.Description) + assert.Equal(t, 1, len(parameters[store.idToName(secretId)].history)) }) t.Run("Setting a key twice should create a new version", func(t *testing.T) { @@ -619,16 +600,16 @@ func TestWritePaths(t *testing.T) { err = store.Write(secretId, "newvalue") assert.Nil(t, err) - assert.Contains(t, mock.parameters, store.idToName(secretId)) - assert.Equal(t, "newvalue", *mock.parameters[store.idToName(secretId)].currentParam.Value) - assert.Equal(t, "2", *mock.parameters[store.idToName(secretId)].meta.Description) - assert.Equal(t, 2, len(mock.parameters[store.idToName(secretId)].history)) + assert.Contains(t, parameters, store.idToName(secretId)) + assert.Equal(t, "newvalue", *parameters[store.idToName(secretId)].currentParam.Value) + assert.Equal(t, "2", *parameters[store.idToName(secretId)].meta.Description) + assert.Equal(t, 2, len(parameters[store.idToName(secretId)].history)) }) } func TestReadPaths(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStoreWithPaths(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, true) secretId := SecretId{Service: "test", Key: "key"} store.Write(secretId, "value") store.Write(secretId, "second value") @@ -666,8 +647,8 @@ func TestReadPaths(t *testing.T) { } func TestListPaths(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStoreWithPaths(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, true) secrets := []SecretId{ {Service: "test", Key: "a"}, @@ -716,8 +697,8 @@ func TestListPaths(t *testing.T) { } func TestHistoryPaths(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStoreWithPaths(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, true) secrets := []SecretId{ {Service: "test", Key: "new"}, @@ -753,8 +734,8 @@ func TestHistoryPaths(t *testing.T) { } func TestDelete(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - store := NewTestSSMStore(mock) + parameters := map[string]mockParameter{} + store := NewTestSSMStore(parameters, false) secretId := SecretId{Service: "test", Key: "key"} store.Write(secretId, "value") @@ -773,9 +754,8 @@ func TestDelete(t *testing.T) { } func TestValidations(t *testing.T) { - mock := &mockSSMClient{parameters: map[string]mockParameter{}} - pathStore := NewTestSSMStore(mock) - pathStore.usePaths = true + parameters := map[string]mockParameter{} + pathStore := NewTestSSMStore(parameters, true) validPathFormat := []string{ "/foo", @@ -813,7 +793,8 @@ func TestValidations(t *testing.T) { }) } - noPathStore := NewTestSSMStore(mock) + parameters = map[string]mockParameter{} + noPathStore := NewTestSSMStore(parameters, false) noPathStore.usePaths = false validNoPathFormat := []string{