diff --git a/client/client.go b/client/client.go index ed7cce7cb..3d62648d0 100644 --- a/client/client.go +++ b/client/client.go @@ -67,6 +67,8 @@ var allRegions = []string{ "sa-east-1", } +const defaultRegion = "us-east-1" + type Services struct { Autoscaling AutoscalingClient Cloudfront CloudfrontClient @@ -95,15 +97,36 @@ type Services struct { S3Manager S3ManagerClient } +type ServicesAccountRegionMap map[string]map[string]*Services + +// ServicesManager will hold the entire map of (account X region) services +type ServicesManager struct { + services ServicesAccountRegionMap +} + +func (s *ServicesManager) ServicesByAccountAndRegion(accountId string, region string) *Services { + if region == "" { + region = defaultRegion + } + return s.services[accountId][region] +} + +func (s *ServicesManager) InitServicesForAccountAndRegion(accountId string, region string, services Services) { + if s.services[accountId] == nil { + s.services[accountId] = make(map[string]*Services, len(allRegions)) + } + s.services[accountId][region] = &services +} + type Client struct { // Those are already normalized values after configure and this is why we don't want to hold // config directly. - regions []string - logLevel *string - maxRetries int - maxBackoff int - services map[string]*Services - logger hclog.Logger + regions []string + logLevel *string + maxRetries int + maxBackoff int + ServicesManager ServicesManager + logger hclog.Logger // this is set by table clientList AccountID string @@ -133,9 +156,11 @@ func (s3Manager S3Manager) GetBucketRegion(ctx context.Context, bucket string, o func NewAwsClient(logger hclog.Logger, regions []string) Client { return Client{ - services: map[string]*Services{}, - logger: logger, - regions: regions, + ServicesManager: ServicesManager{ + services: ServicesAccountRegionMap{}, + }, + logger: logger, + regions: regions, } } @@ -144,39 +169,35 @@ func (c *Client) Logger() hclog.Logger { } func (c *Client) Services() *Services { - return c.services[c.AccountID] + return c.ServicesManager.ServicesByAccountAndRegion(c.AccountID, c.Region) } func (c *Client) withAccountID(accountID string) *Client { return &Client{ - regions: c.regions, - logLevel: c.logLevel, - maxRetries: c.maxRetries, - maxBackoff: c.maxBackoff, - services: c.services, - logger: c.logger.With("account_id", accountID), - AccountID: accountID, - Region: c.Region, + regions: c.regions, + logLevel: c.logLevel, + maxRetries: c.maxRetries, + maxBackoff: c.maxBackoff, + ServicesManager: c.ServicesManager, + logger: c.logger.With("account_id", accountID), + AccountID: accountID, + Region: c.Region, } } func (c *Client) withAccountIDAndRegion(accountID string, region string) *Client { return &Client{ - regions: c.regions, - logLevel: c.logLevel, - maxRetries: c.maxRetries, - maxBackoff: c.maxBackoff, - services: c.services, - logger: c.logger.With("account_id", accountID, "Region", region), - AccountID: accountID, - Region: region, + regions: c.regions, + logLevel: c.logLevel, + maxRetries: c.maxRetries, + maxBackoff: c.maxBackoff, + ServicesManager: c.ServicesManager, + logger: c.logger.With("account_id", accountID, "Region", region), + AccountID: accountID, + Region: region, } } -func (c Client) SetAccountServices(accountId string, s Services) { - c.services[accountId] = &s -} - func Configure(logger hclog.Logger, providerConfig interface{}) (schema.ClientMeta, error) { ctx := context.Background() awsConfig := providerConfig.(*Config) @@ -199,19 +220,31 @@ func Configure(logger hclog.Logger, providerConfig interface{}) (schema.ClientMe var awsCfg aws.Config // This is a try to solve https://aws.amazon.com/premiumsupport/knowledge-center/iam-validate-access-credentials/ // with this https://github.com/aws/aws-sdk-go-v2/issues/515#issuecomment-607387352 - defaultRegion := "us-east-1" switch { case account.ID != "default" && account.RoleARN != "": // assume role if specified (SDK takes it from default or env var: AWS_PROFILE) - awsCfg, err = config.LoadDefaultConfig(ctx, config.WithDefaultRegion(defaultRegion)) + awsCfg, err = config.LoadDefaultConfig( + ctx, + config.WithDefaultRegion(defaultRegion), + config.WithRetryer(newRetryer(awsConfig.MaxRetries, awsConfig.MaxBackoff)), + ) if err != nil { return nil, err } awsCfg.Credentials = stscreds.NewAssumeRoleProvider(sts.NewFromConfig(awsCfg), account.RoleARN) case account.ID != "default": - awsCfg, err = config.LoadDefaultConfig(ctx, config.WithDefaultRegion(defaultRegion), config.WithSharedConfigProfile(account.ID)) + awsCfg, err = config.LoadDefaultConfig( + ctx, + config.WithDefaultRegion(defaultRegion), + config.WithSharedConfigProfile(account.ID), + config.WithRetryer(newRetryer(awsConfig.MaxRetries, awsConfig.MaxBackoff)), + ) default: - awsCfg, err = config.LoadDefaultConfig(ctx, config.WithDefaultRegion(defaultRegion)) + awsCfg, err = config.LoadDefaultConfig( + ctx, + config.WithDefaultRegion(defaultRegion), + config.WithRetryer(newRetryer(awsConfig.MaxRetries, awsConfig.MaxBackoff)), + ) } if err != nil { @@ -221,7 +254,6 @@ func Configure(logger hclog.Logger, providerConfig interface{}) (schema.ClientMe if awsConfig.AWSDebug { awsCfg.ClientLogMode = aws.LogRequest | aws.LogResponse | aws.LogRetries } - awsCfg.Retryer = newRetryer(awsConfig.MaxRetries, awsConfig.MaxBackoff) svc := sts.NewFromConfig(awsCfg) output, err := svc.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}, func(o *sts.Options) { o.Region = "aws-global" @@ -246,7 +278,9 @@ func Configure(logger hclog.Logger, providerConfig interface{}) (schema.ClientMe client.AccountID = *output.Account client.Region = client.regions[0] } - client.SetAccountServices(*output.Account, initServices(awsCfg)) + for _, region := range client.regions { + client.ServicesManager.InitServicesForAccountAndRegion(*output.Account, region, initServices(awsCfg)) + } } return &client, nil @@ -284,7 +318,10 @@ func initServices(awsCfg aws.Config) Services { func newRetryer(maxRetries int, maxBackoff int) func() aws.Retryer { return func() aws.Retryer { - return retry.AddWithMaxBackoffDelay(retry.AddWithMaxAttempts(retry.NewStandard(), maxRetries), time.Second*time.Duration(maxBackoff)) + return retry.NewStandard(func(o *retry.StandardOptions) { + o.MaxAttempts = maxRetries + o.MaxBackoff = time.Second * time.Duration(maxBackoff) + }) } } diff --git a/client/mocks/mock_test.go b/client/mocks/mock_test.go index dea1941c3..eaa2aa63b 100644 --- a/client/mocks/mock_test.go +++ b/client/mocks/mock_test.go @@ -351,8 +351,8 @@ func TestResources(t *testing.T) { Configure: func(logger hclog.Logger, i interface{}) (schema.ClientMeta, error) { c := client.NewAwsClient(logging.New(&hclog.LoggerOptions{ Level: hclog.Warn, - }), []string{"test-1"}) - c.SetAccountServices("testAccount", tc.mockBuilder(t, ctrl)) + }), []string{"us-east-1"}) + c.ServicesManager.InitServicesForAccountAndRegion("testAccount", "us-east-1", tc.mockBuilder(t, ctrl)) return &c, nil }, }) diff --git a/client/multiplexers.go b/client/multiplexers.go index ce321c504..0ee83b9d8 100644 --- a/client/multiplexers.go +++ b/client/multiplexers.go @@ -5,7 +5,7 @@ import "github.com/cloudquery/cq-provider-sdk/provider/schema" func AccountMultiplex(meta schema.ClientMeta) []schema.ClientMeta { var l = make([]schema.ClientMeta, 0) client := meta.(*Client) - for accountID := range client.services { + for accountID := range client.ServicesManager.services { l = append(l, client.withAccountID(accountID)) } return l @@ -14,7 +14,7 @@ func AccountMultiplex(meta schema.ClientMeta) []schema.ClientMeta { func AccountRegionMultiplex(meta schema.ClientMeta) []schema.ClientMeta { var l = make([]schema.ClientMeta, 0) client := meta.(*Client) - for accountID := range client.services { + for accountID := range client.ServicesManager.services { for _, region := range client.regions { l = append(l, client.withAccountIDAndRegion(accountID, region)) } diff --git a/resources/provider_test.go b/resources/provider_test.go index d0dd0b5b0..ecb2edc2c 100644 --- a/resources/provider_test.go +++ b/resources/provider_test.go @@ -28,8 +28,8 @@ func awsTestHelper(t *testing.T, table *schema.Table, builder func(*testing.T, * Configure: func(logger hclog.Logger, i interface{}) (schema.ClientMeta, error) { c := client.NewAwsClient(logging.New(&hclog.LoggerOptions{ Level: hclog.Warn, - }), []string{"test-1"}) - c.SetAccountServices("testAccount", builder(t, ctrl)) + }), []string{"us-east-1"}) + c.ServicesManager.InitServicesForAccountAndRegion("testAccount", "us-east-1", builder(t, ctrl)) return &c, nil }, })