diff --git a/GNUmakefile b/GNUmakefile index 04a0206..e913084 100755 --- a/GNUmakefile +++ b/GNUmakefile @@ -23,7 +23,7 @@ test: fmtcheck xargs -t -n4 go test $(TESTARGS) -timeout=30s -parallel=4 testacc: fmtcheck - TF_ACC=1 go test $(TEST) -v $(TESTARGS) -timeout 120m + TF_ACC=1 go test $(TEST) -v $(TESTARGS) -timeout 120m -parallel=4 vet: @echo "go vet ." diff --git a/dsfhub/provider_test.go b/dsfhub/provider_test.go index 806f7f3..b99626e 100644 --- a/dsfhub/provider_test.go +++ b/dsfhub/provider_test.go @@ -17,18 +17,18 @@ var testAccProviderConfigure sync.Once func init() { testAccProvider = Provider() testAccProviders = map[string]*schema.Provider{ - "dsf": testAccProvider, + "dsfhub": testAccProvider, } } -//func TestProvider(t *testing.T) { -// log.Printf("======================== BEGIN TEST ========================") -// log.Printf("[DEBUG] Running test TestProvider") -// if err := Provider().InternalValidate(); err != nil { -// log.Printf("[INFO] err: %s \n", err) -// t.Fatalf("err: %s", err) -// } -//} +func TestProvider(t *testing.T) { + log.Printf("======================== BEGIN TEST ========================") + log.Printf("[DEBUG] Running test TestProvider") + if err := Provider().InternalValidate(); err != nil { + log.Printf("[INFO] err: %s \n", err) + t.Fatalf("err: %s", err) + } +} func TestProvider_impl(t *testing.T) { var _ *schema.Provider = Provider() @@ -38,12 +38,12 @@ func testAccPreCheck(t *testing.T) { log.Printf("======================== BEGIN TEST ========================") log.Printf("[INFO] Running test testAccPreCheck \n") testAccProviderConfigure.Do(func() { - if v := os.Getenv("DSF_TOKEN"); v == "" { - t.Fatal("DSF_TOKEN must be set for acceptance tests") + if v := os.Getenv("DSFHUB_TOKEN"); v == "" { + t.Fatal("DSFHUB_TOKEN must be set for acceptance tests") } - if v := os.Getenv("DSF_HOST"); v == "" { - t.Fatal("DSF_HOST must be set for acceptance tests") + if v := os.Getenv("DSFHUB_HOST"); v == "" { + t.Fatal("DSFHUB_HOST must be set for acceptance tests") } err := testAccProvider.Configure(context.Background(), terraform.NewResourceConfigRaw(nil)) diff --git a/dsfhub/resource_cloud_account.go b/dsfhub/resource_cloud_account.go index 8911cb2..8addec3 100644 --- a/dsfhub/resource_cloud_account.go +++ b/dsfhub/resource_cloud_account.go @@ -2,19 +2,21 @@ package dsfhub import ( "bytes" + "context" "fmt" "log" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) func resourceCloudAccount() *schema.Resource { return &schema.Resource{ - Create: resourceCloudAccountCreate, - Read: resourceCloudAccountRead, - Update: resourceCloudAccountUpdate, - Delete: resourceCloudAccountDelete, + CreateContext: resourceCloudAccountCreateContext, + ReadContext: resourceCloudAccountReadContext, + UpdateContext: resourceCloudAccountUpdateContext, + DeleteContext: resourceCloudAccountDeleteContext, Importer: &schema.ResourceImporter{ StateContext: schema.ImportStatePassthroughContext, }, @@ -109,14 +111,13 @@ func resourceCloudAccount() *schema.Resource { Description: "The Access key ID of AWS secret access key used to authenticate", Optional: true, Default: nil, - Computed: true, }, "access_key": { Type: schema.TypeString, Description: "The Secret access key used to authenticate", Optional: true, Default: nil, - Computed: true, + Required: false, }, "amazon_secret": { Type: schema.TypeSet, @@ -161,7 +162,6 @@ func resourceCloudAccount() 
*schema.Resource { Description: "This is also referred to as the Client ID and it’s the unique identifier for the registered application being used to execute Python SDK commands against Azure’s API services. You can find this number under Azure Active Directory -> App Registrations -> Owned Applications", Optional: true, Default: nil, - Computed: true, }, "auth_mechanism": { Type: schema.TypeString, @@ -339,7 +339,8 @@ func resourceCloudAccount() *schema.Resource { Description: "The Secret access key used to authenticate", Required: false, Optional: true, - Default: false, + Default: nil, + Sensitive: true, }, "ssl": { Type: schema.TypeBool, @@ -496,31 +497,36 @@ func resourceCloudAccount() *schema.Resource { } } -func resourceCloudAccountCreate(d *schema.ResourceData, m interface{}) error { +func resourceCloudAccountCreateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) if isOk, err := checkResourceRequiredFields(requiredCloudAccountJson, ignoreCloudAccountParamsByServerType, d); !isOk { - return err + return diag.FromErr(err) } + + // check provided fields against schema cloudAccount := ResourceWrapper{} serverType := d.Get("server_type").(string) createResource(&cloudAccount, serverType, d) + // create resource log.Printf("[INFO] Creating CloudAccount for serverType: %s and gatewayId: %s gatewayId: \n", serverType, cloudAccount.Data.GatewayID) createCloudAccountResponse, err := client.CreateCloudAccount(cloudAccount) - if err != nil { log.Printf("[ERROR] adding CloudAccount for serverType: %s and gatewayId: %s | err: %s", serverType, cloudAccount.Data.GatewayID, err) - return err + return diag.FromErr(err) } + // set ID cloudAccountId := createCloudAccountResponse.Data.AssetData.AssetID d.SetId(cloudAccountId) // Set the rest of the state from the resource read - return resourceCloudAccountRead(d, m) + resourceCloudAccountReadContext(ctx, d, m) + + return nil } -func resourceCloudAccountRead(d *schema.ResourceData, m interface{}) error { +func resourceCloudAccountReadContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) cloudAccountId := d.Id() @@ -530,7 +536,7 @@ func resourceCloudAccountRead(d *schema.ResourceData, m interface{}) error { if err != nil { log.Printf("[ERROR] Reading cloudAccountReadResponse with cloudAccountId: %s | err: %s\n", cloudAccountId, err) - return err + return diag.FromErr(err) } if cloudAccountReadResponse != nil { @@ -546,7 +552,9 @@ func resourceCloudAccountRead(d *schema.ResourceData, m interface{}) error { d.Set("asset_id", cloudAccountReadResponse.Data.AssetData.AssetID) d.Set("asset_source", cloudAccountReadResponse.Data.AssetData.AssetSource) d.Set("available_regions", cloudAccountReadResponse.Data.AssetData.AvailableRegions) - d.Set("credential_endpoint", cloudAccountReadResponse.Data.AssetData.CredentialsEndpoint) + if cloudAccountReadResponse.Data.AssetData.CredentialsEndpoint != "" { + d.Set("credential_endpoint", cloudAccountReadResponse.Data.AssetData.CredentialsEndpoint) + } d.Set("criticality", cloudAccountReadResponse.Data.AssetData.Criticality) d.Set("gateway_id", cloudAccountReadResponse.Data.GatewayID) d.Set("jsonar_uid", cloudAccountReadResponse.Data.AssetData.JsonarUID) @@ -642,38 +650,45 @@ func resourceCloudAccountRead(d *schema.ResourceData, m interface{}) error { connections.Add(connection) } - d.Set("ca_connection", connections) + d.Set("asset_connection", connections) log.Printf("[INFO] Finished reading CloudAccount 
with cloudAccountId: %s\n", cloudAccountId) return nil } -func resourceCloudAccountUpdate(d *schema.ResourceData, m interface{}) error { +func resourceCloudAccountUpdateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) + + // check provided fields against schema cloudAccountId := d.Id() if isOk, err := checkResourceRequiredFields(requiredCloudAccountJson, ignoreCloudAccountParamsByServerType, d); !isOk { - return err + return diag.FromErr(err) } + + // convert provided fields into API payload cloudAccount := ResourceWrapper{} serverType := d.Get("server_type").(string) createResource(&cloudAccount, serverType, d) + // update resource log.Printf("[INFO] Updating CloudAccount for serverType: %s and gatewayId: %s assetId: %s\n", cloudAccount.Data.ServerType, cloudAccount.Data.GatewayID, cloudAccount.Data.AssetData.AssetID) _, err := client.UpdateCloudAccount(cloudAccountId, cloudAccount) - if err != nil { log.Printf("[ERROR] Updating CloudAccount for serverType: %s and gatewayId: %s assetId: %s | err:%s\n", cloudAccount.Data.ServerType, cloudAccount.Data.GatewayID, cloudAccount.Data.AssetData.AssetID, err) - return err + return diag.FromErr(err) } + // set ID d.SetId(cloudAccountId) // Set the rest of the state from the resource read - return resourceCloudAccountRead(d, m) + resourceCloudAccountReadContext(ctx, d, m) + + return nil } -func resourceCloudAccountDelete(d *schema.ResourceData, m interface{}) error { +func resourceCloudAccountDeleteContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) cloudAccountId := d.Id() diff --git a/dsfhub/resource_cloud_account_test.go b/dsfhub/resource_cloud_account_test.go index e83a043..da8bff5 100644 --- a/dsfhub/resource_cloud_account_test.go +++ b/dsfhub/resource_cloud_account_test.go @@ -3,36 +3,40 @@ package dsfhub import ( "fmt" "log" + "os" "testing" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" ) -const cloudAccountResourceName = "cloud_account" -const cloudAccountType = "aws" -const cloudAccountResourceTypeAndName = cloudAccountResourceName + "." 
+ cloudAccountType +func TestAccDSFCloudAccount_Aws(t *testing.T) { + gatewayId := os.Getenv("GATEWAY_ID") + if gatewayId == "" { + t.Skip("GATEWAY_ID environment variable must be set") + } + + const ( + assetId = "arn:aws:iam::123456789012" + resourceName = "aws-cloud-account" + ) -func TestAccCloudAccount_basic(t *testing.T) { - log.Printf("======================== BEGIN TEST ========================") - log.Printf("[INFO] Running test TestAccCloudAccount_basic \n") - resource.Test(t, resource.TestCase{ + resourceTypeAndName := fmt.Sprintf("%s.%s", dsfCloudAccountResourceType, resourceName) + + resource.ParallelTest(t, resource.TestCase{ PreCheck: func() { testAccPreCheck(t) }, Providers: testAccProviders, CheckDestroy: testAccCloudAccountDestroy, Steps: []resource.TestStep{ + {Config: testAccDSFCloudAccountConfig_Aws(resourceName, gatewayId, assetId, "default")}, + {Config: testAccDSFCloudAccountConfig_Aws(resourceName, gatewayId, assetId, "iam_role")}, + // {Config: testAccDSFCloudAccountConfig_Aws(resourceName, gatewayId, assetId, "key")}, //TODO: fix "key" failing refresh + {Config: testAccDSFCloudAccountConfig_Aws(resourceName, gatewayId, assetId, "profile")}, + // validate import { - Config: testAccCheckCloudAccountConfigBasic(t), - Check: resource.ComposeTestCheckFunc( - testCheckCloudAccountExists(cloudAccountResourceName), - resource.TestCheckResourceAttr(cloudAccountResourceTypeAndName, cloudAccountResourceName, cloudAccountType), - ), - }, - { - ResourceName: cloudAccountResourceTypeAndName, + ResourceName: resourceTypeAndName, ImportState: true, ImportStateVerify: true, - ImportStateIdFunc: testAccCloudAccountId, }, }, }) @@ -41,7 +45,7 @@ func TestAccCloudAccount_basic(t *testing.T) { func testAccCloudAccountId(state *terraform.State) (string, error) { log.Printf("[INFO] Running test testAccCloudAccountId \n") for _, rs := range state.RootModule().Resources { - if rs.Type != cloudAccountType { + if rs.Type != dsfCloudAccountResourceType { continue } return fmt.Sprintf("%s", rs.Primary.ID), nil @@ -69,24 +73,11 @@ func testCheckCloudAccountExists(dataSourceId string) resource.TestCheckFunc { } } -func testAccCheckCloudAccountConfigBasic(t *testing.T) string { - log.Printf("[INFO] Running test testAccCheckCloudAccountConfigBasic \n") - return fmt.Sprintf(` -resource "%s" "my_test_data_source" { - admin_email = "%s" - arn = "%s" - asset_display_name = "%s" - gateway_id = %s - server_host_name = "%s" - server_type = "%s" -}`, cloudAccountResourceName, testAdminEmail, testArn, testAssetDisplayName, testGatewayId, testServerHostName, testDSServerType) -} - func testAccCloudAccountDestroy(state *terraform.State) error { log.Printf("[INFO] Running test testAccCloudAccountDestroy \n") client := testAccProvider.Meta().(*Client) for _, res := range state.RootModule().Resources { - if res.Type != "dsf_data_source" { + if res.Type != "dsfhub_data_source" { continue } cloudAccountId := res.Primary.ID diff --git a/dsfhub/resource_cloud_account_test_data.go b/dsfhub/resource_cloud_account_test_data.go new file mode 100644 index 0000000..7a878fe --- /dev/null +++ b/dsfhub/resource_cloud_account_test_data.go @@ -0,0 +1,53 @@ +package dsfhub + +import "fmt" + +// Output a terraform config for an AWS cloud account resource. +// +// Supports all authentication mechanisms: "key", "profile", "iam_role", and +// "default". 
+func testAccDSFCloudAccountConfig_Aws(resourceName string, gatewayId string, assetId string, authMechanism string) string { + var assetConnectionBlock string + + if authMechanism == "key" { + assetConnectionBlock = ` + asset_connection { + access_id = "my-access-id" + auth_mechanism = "` + authMechanism + `" + reason = "default" + region = "us-east-1" + secret_key = "my-secret-key" + } + ` + } else if authMechanism == "profile" { + assetConnectionBlock = ` + asset_connection { + auth_mechanism = "` + authMechanism + `" + reason = "default" + region = "us-east-2" + username = "dsfhubuser" + } + ` + } else { + assetConnectionBlock = ` + asset_connection { + auth_mechanism = "` + authMechanism + `" + reason = "default" + region = "us-west-1" + } + ` + } + + return fmt.Sprintf(` +resource "`+dsfCloudAccountResourceType+`" "%[1]s" { + server_type = "AWS" + + admin_email = "`+testAdminEmail+`" + arn = "%[3]s" + asset_display_name = "%[3]s" + asset_id = "%[3]s" + gateway_id = "%[2]s" + + `+assetConnectionBlock+` +}`, resourceName, gatewayId, assetId) +} diff --git a/dsfhub/resource_common.go b/dsfhub/resource_common.go index 25d556f..b4bfba7 100644 --- a/dsfhub/resource_common.go +++ b/dsfhub/resource_common.go @@ -2,13 +2,17 @@ package dsfhub import ( "bytes" + "context" "encoding/json" "fmt" "log" "reflect" + "regexp" + "strconv" "strings" "time" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/retry" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) @@ -301,7 +305,7 @@ func checkResourceRequiredFields(requiredFieldsJson string, ignoreParamsByServer serverType := d.Get("server_type").(string) serverTypeObj, found := requiredFields.ServerType[serverType] if !found { - return false, fmt.Errorf("[DEBUG] Unsupported serverType: %s\n", serverType) + return false, fmt.Errorf("unsupported serverType: %s", serverType) } for _, field := range serverTypeObj.Required { curField := d.Get(field) @@ -323,7 +327,7 @@ func checkResourceRequiredFields(requiredFieldsJson string, ignoreParamsByServer log.Printf("[DEBUG] Checking for authMechanism: %s\n", authMechanism) authMechanismFields, found := serverTypeObj.AuthMechanisms[authMechanism] if !found { - return false, fmt.Errorf("[DEBUG] Unsupported authMechanism '%v' for serverType '%v'\n", authMechanism, serverType) + return false, fmt.Errorf("unsupported authMechanism '%v' for serverType '%v'", authMechanism, serverType) } for _, field := range authMechanismFields { log.Printf("[DEBUG] Checking for field: '%s', value: '%s'\n", field, connection[field]) @@ -339,7 +343,7 @@ func checkResourceRequiredFields(requiredFieldsJson string, ignoreParamsByServer } } if len(missingParams) > 0 { - return false, fmt.Errorf("[DEBUG] Missing required fields for dsf_data_source with serverType '%s', missing fields: %s\n", serverType, "\""+strings.Join(missingParams, ", ")+"\"") + return false, fmt.Errorf("missing required fields for dsfhub_data_source with serverType '%s', missing fields: %s", serverType, "\""+strings.Join(missingParams, ", ")+"\"") } else { return true, nil } @@ -408,13 +412,64 @@ func contains(l []string, x string) bool { return false } -func connectDisconnectGateway(d *schema.ResourceData, dsfDataSource ResourceWrapper, m interface{}) error { - // func connectDisconnectGateway(d *schema.ResourceData, dsfDataSource ResourceWrapper, m interface{}) { +func waitUntilAuditState(ctx context.Context, desiredState bool, resourceType string, assetId string, m interface{}) error { client := m.(*Client)
- // give enough time for connect/disconnect gateway playbook to complete - wait := 6 * time.Second + pendingState := strconv.FormatBool(!desiredState) + targetState := strconv.FormatBool(desiredState) + stateChangeConf := &retry.StateChangeConf{ + Pending: []string{ + pendingState, + }, + Target: []string{ + targetState, + }, + Refresh: auditStateRefreshFunc(*client, resourceType, assetId), + Timeout: 8 * time.Minute, + Delay: 10 * time.Second, + MinTimeout: 5 * time.Second, + } + + _, err := stateChangeConf.WaitForStateContext(ctx) + if err != nil { + log.Printf("[ERROR] error waiting for audit collection state to update to %v for asset %v", desiredState, assetId) + return err + } + + return nil +} + +func auditStateRefreshFunc(client Client, resourceType string, assetId string) retry.StateRefreshFunc { + return func() (any, string, error) { + var result *ResourceWrapper + var err error + + switch resourceType { + case dsfDataSourceResourceType: + { + log.Printf("[INFO] checking audit state for data_source asset %v", assetId) + result, err = client.ReadDSFDataSource(assetId) + } + case dsfLogAggregatorResourceType: + { + log.Printf("[INFO] checking audit state for log_aggregator asset %v", assetId) + result, err = client.ReadLogAggregator(assetId) + } + default: + { + return nil, "", fmt.Errorf("invalid resourceType: %v", resourceType) + } + } + if err != nil { + return 0, "", err + } + + return result, strconv.FormatBool(result.Data.AssetData.AuditPullEnabled), nil + } +} + +func connectDisconnectGateway(ctx context.Context, d *schema.ResourceData, resourceType string, m interface{}) error { assetId := d.Get("asset_id").(string) auditPullEnabled := d.Get("audit_pull_enabled").(bool) auditType := d.Get("audit_type").(string) @@ -430,47 +485,25 @@ func connectDisconnectGateway(d *schema.ResourceData, dsfDataSource ResourceWrap // if audit_pull_enabled has been changed, connect/disconnect from gateway as needed if auditPullEnabledChanged { if auditPullEnabled { - // allow time for asset syncs to gateways to finish - time.Sleep(wait) - - // connect gateway - _, err := client.EnableAuditDSFDataSource(assetId) + err := connectGateway(ctx, m, assetId, resourceType) if err != nil { - log.Printf("[INFO] Error enabling audit for assetId: %s\n", assetId) return err } - time.Sleep(wait) - - // disconnect gateway } else { - _, err := client.DisableAuditDSFDataSource(assetId) + err := disconnectGateway(ctx, m, assetId, resourceType) if err != nil { - log.Printf("[INFO] Error disabling audit for assetId: %s\n", assetId) return err } - time.Sleep(wait) } // if asset is already connected, check whether relevant fields have been updated and reconnect to gateway } else if auditPullEnabled { if auditTypeChanged { origAuditType, newAuditType := d.GetChange("audit_type") log.Printf("[INFO] auditType value has changed from %s to %s, reconnecting asset to gateway\n", origAuditType, newAuditType) - - // disconnect - _, err1 := client.DisableAuditDSFDataSource(assetId) - if err1 != nil { - log.Printf("[INFO] Error disabling audit for assetId: %s\n", assetId) - return err1 - } - time.Sleep(wait) - - // reconnect - _, err2 := client.EnableAuditDSFDataSource(assetId) - if err2 != nil { - log.Printf("[INFO] Error enabling audit for assetId: %s\n", assetId) - return err2 + err := reconnectGateway(ctx, m, assetId, resourceType) + if err != nil { + return err } - time.Sleep(wait) } } else { log.Printf("[INFO] Asset %s does not need to be connected to or disconnected from gateway", assetId) @@ -478,6 +511,55 @@ func 
connectDisconnectGateway(d *schema.ResourceData, dsfDataSource ResourceWrap return nil } +func connectGateway(ctx context.Context, m interface{}, assetId string, resourceType string) error { + client := m.(*Client) + _, err := client.EnableAuditDSFDataSource(assetId) + if err != nil { + log.Printf("[INFO] Error enabling audit for assetId: %s\n", assetId) + return err + } + + err = waitUntilAuditState(ctx, true, resourceType, assetId, m) + if err != nil { + return err + } + + return nil +} + +func disconnectGateway(ctx context.Context, m interface{}, assetId string, resourceType string) error { + client := m.(*Client) + _, err := client.DisableAuditDSFDataSource(assetId) + if err != nil { + log.Printf("[INFO] Error disabling audit for assetId: %s\n", assetId) + return err + } + + err = waitUntilAuditState(ctx, false, resourceType, assetId, m) + if err != nil { + log.Printf("[INFO] Error while waiting for audit state to update for assetId: %s\n", assetId) + return err + } + + return nil +} + +func reconnectGateway(ctx context.Context, m interface{}, assetId string, resourceType string) error { + log.Printf("[INFO] Re-enabling audit for assetId: %s\n", assetId) + + err := disconnectGateway(ctx, m, assetId, resourceType) + if err != nil { + return err + } + + err = connectGateway(ctx, m, assetId, resourceType) + if err != nil { + return err + } + + return nil +} + // ConnectionData resource hash functions func resourceConnectionDataAmazonSecretHash(v interface{}) int { var buf bytes.Buffer @@ -584,3 +666,20 @@ func resourceAssetDataServiceEndpointsHash(v interface{}) int { } return PositiveHash(buf.String()) } + +// testAccParseResourceAttributeReference parses a terraform field and +// determines whether it is a reference to another resource. If the field is +// a reference, return the input string and if not, return it wrapped in +// double-quotes. +func testAccParseResourceAttributeReference(field string) string { + var regExpr string = `dsfhub_[A-Za-z0-9_-]+\.[A-Za-z0-9_-]+` //e.g.
dsfhub_cloud_account.my-cloud-account, dsfhub_cloud_account.my-cloud-account.asset_id + var parsedField string + + isReference, _ := regexp.Match(regExpr, []byte(field)) + if isReference { + parsedField = field + } else { + parsedField = fmt.Sprintf("\"%s\"", field) + } + return parsedField +} diff --git a/dsfhub/resource_constants.go b/dsfhub/resource_constants.go new file mode 100644 index 0000000..2460620 --- /dev/null +++ b/dsfhub/resource_constants.go @@ -0,0 +1,8 @@ +package dsfhub + +const ( + dsfDataSourceResourceType = "dsfhub_data_source" + dsfLogAggregatorResourceType = "dsfhub_log_aggregator" + dsfCloudAccountResourceType = "dsfhub_cloud_account" + dsfSecretManagerResourceType = "dsfhub_secret_manager" +) diff --git a/dsfhub/resource_data_source.go b/dsfhub/resource_data_source.go index ed8bdda..ac1a618 100644 --- a/dsfhub/resource_data_source.go +++ b/dsfhub/resource_data_source.go @@ -2,19 +2,21 @@ package dsfhub import ( "bytes" + "context" "fmt" "log" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) func resourceDSFDataSource() *schema.Resource { return &schema.Resource{ - Create: resourceDSFDataSourceCreate, - Read: resourceDSFDataSourceRead, - Update: resourceDSFDataSourceUpdate, - Delete: resourceDSFDataSourceDelete, + CreateContext: resourceDSFDataSourceCreateContext, + ReadContext: resourceDSFDataSourceReadContext, + UpdateContext: resourceDSFDataSourceUpdateContext, + DeleteContext: resourceDSFDataSourceDeleteContext, Importer: &schema.ResourceImporter{ StateContext: schema.ImportStatePassthroughContext, }, @@ -1322,45 +1324,61 @@ func resourceDSFDataSource() *schema.Resource { } } -func resourceDSFDataSourceCreate(d *schema.ResourceData, m interface{}) error { +func resourceDSFDataSourceCreateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { + var diags diag.Diagnostics client := m.(*Client) + + // check provided fields against schema if isOk, err := checkResourceRequiredFields(requiredDataSourceFieldsJson, ignoreDataSourceParamsByServerType, d); !isOk { - return err + return diag.FromErr(err) } + + // convert provided fields into API payload dsfDataSource := ResourceWrapper{} serverType := d.Get("server_type").(string) createResource(&dsfDataSource, serverType, d) + // auditPullEnabled set to false as connect/disconnect logic handled below dsfDataSource.Data.AssetData.AuditPullEnabled = false + + // create resource log.Printf("[INFO] Creating DSF data source for serverType: %s and gatewayId: %s \n", dsfDataSource.Data.ServerType, dsfDataSource.Data.GatewayID) dsfDataSourceResponse, err := client.CreateDSFDataSource(dsfDataSource) - if err != nil { log.Printf("[INFO] Creating DSF data source for serverType: %s and gatewayId: %s assetId: %s\n", dsfDataSource.Data.ServerType, dsfDataSource.Data.GatewayID, dsfDataSource.Data.AssetData.AssetID) - return err + return diag.FromErr(err) } // Connect/disconnect asset to gateway - connectDisconnectGateway(d, dsfDataSource, m) + err = connectDisconnectGateway(ctx, d, dsfDataSourceResourceType, m) + if err != nil { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Warning, + Summary: fmt.Sprintf("Error while updating audit state for asset: %s", d.Get("asset_id")), + Detail: fmt.Sprintf("Error: %s\n", err), + }) + } // Set ID dsfDataSourceId := dsfDataSourceResponse.Data.AssetData.AssetID d.SetId(dsfDataSourceId) // Set the rest of the state from the 
resource read - return resourceDSFDataSourceRead(d, m) + log.Printf("[DEBUG] Writing data source asset details to state") + resourceDSFDataSourceReadContext(ctx, d, m) + + return diags } -func resourceDSFDataSourceRead(d *schema.ResourceData, m interface{}) error { +func resourceDSFDataSourceReadContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) dsfDataSourceId := d.Id() log.Printf("[INFO] Reading DSF data source with dsfDataSourceId: %s\n", dsfDataSourceId) dsfDataSourceReadResponse, err := client.ReadDSFDataSource(dsfDataSourceId) - if err != nil { log.Printf("[ERROR] Reading dsfDataSourceReadResponse | err: %s\n", err) - return err + return diag.FromErr(err) } if dsfDataSourceReadResponse != nil { @@ -1373,7 +1391,10 @@ func resourceDSFDataSourceRead(d *schema.ResourceData, m interface{}) error { d.Set("admin_email", dsfDataSourceReadResponse.Data.AssetData.AdminEmail) //d.Set("application", dsfDataSourceReadResponse.Data.AssetData.Application) //d.Set("archive", dsfDataSourceReadResponse.Data.AssetData.Archive) - d.Set("arn", dsfDataSourceReadResponse.Data.AssetData.Arn) + + if dsfDataSourceReadResponse.Data.AssetData.Arn != "" { + d.Set("arn", dsfDataSourceReadResponse.Data.AssetData.Arn) + } d.Set("asset_display_name", dsfDataSourceReadResponse.Data.AssetData.AssetDisplayName) d.Set("asset_id", dsfDataSourceReadResponse.Data.AssetData.AssetID) d.Set("asset_source", dsfDataSourceReadResponse.Data.AssetData.AssetSource) @@ -1428,7 +1449,9 @@ func resourceDSFDataSourceRead(d *schema.ResourceData, m interface{}) error { d.Set("sdm_enabled", dsfDataSourceReadResponse.Data.AssetData.SdmEnabled) d.Set("searches", dsfDataSourceReadResponse.Data.AssetData.Searches) d.Set("server_host_name", dsfDataSourceReadResponse.Data.AssetData.ServerHostName) - d.Set("server_ip", dsfDataSourceReadResponse.Data.AssetData.ServerIP) + if dsfDataSourceReadResponse.Data.AssetData.ServerIP != "" { + d.Set("server_ip", dsfDataSourceReadResponse.Data.AssetData.ServerIP) + } if dsfDataSourceReadResponse.Data.AssetData.ServerPort != nil { var serverPort string if serverPortNum, ok := dsfDataSourceReadResponse.Data.AssetData.ServerPort.(float64); ok { @@ -1493,7 +1516,9 @@ func resourceDSFDataSourceRead(d *schema.ResourceData, m interface{}) error { //connection["azure_storage_account"] = v.ConnectionData.AzureStorageAccount //connection["azure_storage_container"] = v.ConnectionData.AzureStorageContainer //connection["azure_storage_secret_key"] = v.ConnectionData.AzureStorageSecretKey - connection["base_dn"] = v.ConnectionData.BaseDn + if v.ConnectionData.BaseDn != "" { + connection["base_dn"] = v.ConnectionData.BaseDn + } connection["bucket"] = v.ConnectionData.Bucket connection["ca_certs_path"] = v.ConnectionData.CaCertsPath connection["ca_file"] = v.ConnectionData.CaFile @@ -1565,14 +1590,18 @@ func resourceDSFDataSourceRead(d *schema.ResourceData, m interface{}) error { //connection["secure_connection"] = v.ConnectionData.SecureConnection connection["self_signed_cert"] = v.ConnectionData.SelfSignedCert connection["self_signed"] = v.ConnectionData.SelfSigned - connection["server_ip"] = v.ConnectionData.ServerIp + if v.ConnectionData.ServerIp != "" { + connection["server_ip"] = v.ConnectionData.ServerIp + } connection["server_port"] = v.ConnectionData.ServerPort connection["service_key"] = v.ConnectionData.ServiceKey connection["snowflake_role"] = v.ConnectionData.SnowflakeRole connection["ssl_server_cert"] = v.ConnectionData.SslServerCert connection["ssl"] 
= v.ConnectionData.Ssl //connection["store_aws_credentials"] = v.ConnectionData.StoreAwsCredentials - connection["subscription_id"] = v.ConnectionData.SubscriptionID + if v.ConnectionData.SubscriptionID != "" { + connection["subscription_id"] = v.ConnectionData.SubscriptionID + } connection["tenant_id"] = v.ConnectionData.TenantID connection["thrift_transport"] = v.ConnectionData.ThriftTransport connection["tmp_user"] = v.ConnectionData.TmpUser @@ -1645,42 +1674,60 @@ func resourceDSFDataSourceRead(d *schema.ResourceData, m interface{}) error { return nil } -func resourceDSFDataSourceUpdate(d *schema.ResourceData, m interface{}) error { +func resourceDSFDataSourceUpdateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { + var diags diag.Diagnostics client := m.(*Client) + + // check provided fields against schema dsfDataSourceId := d.Id() if isOk, err := checkResourceRequiredFields(requiredDataSourceFieldsJson, ignoreDataSourceParamsByServerType, d); !isOk { - return err + return diag.FromErr(err) } + + // convert provided fields into API payload dsfDataSource := ResourceWrapper{} serverType := d.Get("server_type").(string) createResource(&dsfDataSource, serverType, d) + // auditPullEnabled set to current value from state + auditPullEnabled, _ := d.GetChange("audit_pull_enabled") + dsfDataSource.Data.AssetData.AuditPullEnabled = auditPullEnabled.(bool) + + // update resource log.Printf("[INFO] Updating DSF data source for serverType: %s and gatewayId: %s assetId: %s\n", dsfDataSource.Data.ServerType, dsfDataSource.Data.GatewayID, dsfDataSource.Data.AssetData.AssetID) _, err := client.UpdateDSFDataSource(dsfDataSourceId, dsfDataSource) - if err != nil { log.Printf("[ERROR] Updating data source for serverType: %s and gatewayId: %s assetId: %s | err:%s\n", dsfDataSource.Data.ServerType, dsfDataSource.Data.GatewayID, dsfDataSource.Data.AssetData.AssetID, err) - return err + return diag.FromErr(err) } // Connect/disconnect asset to gateway - connectDisconnectGateway(d, dsfDataSource, m) + err = connectDisconnectGateway(ctx, d, dsfDataSourceResourceType, m) + if err != nil { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Warning, + Summary: fmt.Sprintf("Error while updating audit state for asset: %s", d.Get("asset_id")), + Detail: fmt.Sprintf("Error: %s\n", err), + }) + } // Set ID d.SetId(dsfDataSourceId) // Set the rest of the state from the resource read - return resourceDSFDataSourceRead(d, m) + log.Printf("[DEBUG] Writing data source asset details to state") + resourceDSFDataSourceReadContext(ctx, d, m) + + return diags } -func resourceDSFDataSourceDelete(d *schema.ResourceData, m interface{}) error { +func resourceDSFDataSourceDeleteContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) dsfDataSourceId := d.Id() - log.Printf("[INFO] Deleting data-source with dsfDataSourceId: %s", dsfDataSourceId) - - dsfDataSourceDeleteResponse, err := client.DeleteDSFDataSource(dsfDataSourceId) - if dsfDataSourceDeleteResponse != nil { + _, err := client.DeleteDSFDataSource(dsfDataSourceId) + // if an error is returned, assume it has already been deleted + if err != nil { log.Printf("[INFO] DSF data source has already been deleted with dsfDataSourceId: %s | err: %s\n", dsfDataSourceId, err) } return nil @@ -1787,9 +1834,9 @@ func resourceDataSourceConnectionHash(v interface{}) int { // buf.WriteString(fmt.Sprintf("%v-", v.(string))) //} - //if v, ok := m["base_dn"]; ok { - // 
buf.WriteString(fmt.Sprintf("%v-", v.(string))) - //} + // if v, ok := m["base_dn"]; ok { + // buf.WriteString(fmt.Sprintf("%v-", v.(string))) + // } if v, ok := m["bucket"]; ok { buf.WriteString(fmt.Sprintf("%v-", v.(string))) diff --git a/dsfhub/resource_data_source_test.go b/dsfhub/resource_data_source_test.go index e74bb31..d64275c 100644 --- a/dsfhub/resource_data_source_test.go +++ b/dsfhub/resource_data_source_test.go @@ -2,36 +2,229 @@ package dsfhub import ( "fmt" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" - "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" "log" + "os" "testing" + "time" + + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource" + "github.com/hashicorp/terraform-plugin-sdk/v2/terraform" ) -const dsfDataSourceResourceName = "dsf_data_source" -const dsfDataSourceType = "aws-rds-mysql" -const dsfDataSourceResourceTypeAndName = dsfDataSourceResourceName + "." + dsfDataSourceType - -func TestAccDSFDataSource_basic(t *testing.T) { - log.Printf("======================== BEGIN TEST ========================") - log.Printf("[INFO] Running test TestAccDSFDataSource_basic \n") - resource.Test(t, resource.TestCase{ - PreCheck: func() { testAccPreCheck(t) }, - Providers: testAccProviders, - CheckDestroy: testAccDSFDataSourceDestroy, +func TestAccDSFDataSource_Basic(t *testing.T) { + gatewayId := os.Getenv("GATEWAY_ID") + if gatewayId == "" { + t.Skip("GATEWAY_ID environment variable must be set") + } + + const resourceName = "basic_test_data_source" + resourceTypeAndName := fmt.Sprintf("%s.%s", dsfDataSourceResourceType, resourceName) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, Steps: []resource.TestStep{ { - Config: testAccCheckDSFDataSourceConfigBasic(t), + Config: testAccDSFDataSourceConfig_Basic( + resourceName, + testAdminEmail, + testArn, + gatewayId, + testServerHostName, + testDSServerType, + ), Check: resource.ComposeTestCheckFunc( - testCheckDSFDataSourceExists(dsfDataSourceResourceName), - resource.TestCheckResourceAttr(dsfDataSourceResourceTypeAndName, dsfDataSourceResourceName, dsfDataSourceType), + resource.TestCheckResourceAttr(resourceTypeAndName, "audit_pull_enabled", "false"), ), }, { - ResourceName: dsfDataSourceResourceTypeAndName, + ResourceName: resourceTypeAndName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccDSFDataSource_AwsRdsOracleConnectDisconnectGateway(t *testing.T) { + gatewayId := os.Getenv("GATEWAY_ID") + if gatewayId == "" { + t.Skip("GATEWAY_ID environment variable must be set") + } + + const resourceName = "rds_oracle_connect_disconnect_gateway" + resourceTypeAndName := fmt.Sprintf("%s.%s", dsfDataSourceResourceType, resourceName) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + Steps: []resource.TestStep{ + // onboard and connect to gateway + { + Config: testAccDSFDataSourceConfig_AwsRdsOracle(resourceName, gatewayId, resourceName, "UNIFIED", "true"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceTypeAndName, "audit_pull_enabled", "true"), + resource.TestCheckResourceAttr(resourceTypeAndName, "gateway_service", "gateway-odbc@oracle_unified.service"), + ), + }, + // update audit_type -> reconnect asset to gateway + { + Config: testAccDSFDataSourceConfig_AwsRdsOracle(resourceName, gatewayId, resourceName, "UNIFIED_AGGREGATED", "true"), + Check: resource.ComposeTestCheckFunc( + 
resource.TestCheckResourceAttr(resourceTypeAndName, "audit_pull_enabled", "true"), + resource.TestCheckResourceAttr(resourceTypeAndName, "gateway_service", "gateway-odbc@oracle_unified_aggregated.service"), + ), + }, + // disconnect asset + { + Config: testAccDSFDataSourceConfig_AwsRdsOracle(resourceName, gatewayId, resourceName, "UNIFIED_AGGREGATED", "false"), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceTypeAndName, "audit_pull_enabled", "false"), + resource.TestCheckResourceAttr(resourceTypeAndName, "gateway_service", ""), + ), + }, + // validate import + { + ResourceName: resourceTypeAndName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccDSFDataSource_AwsRdsAuroraPostgresqlClusterCloudWatch(t *testing.T) { + gatewayId := os.Getenv("GATEWAY_ID") + if gatewayId == "" { + t.Skip("GATEWAY_ID environment variable must be set") + } + + const ( + assetId = "arn:aws:rds:us-east-2:123456789012:cluster:my-aurorapostgresql-cluster" + resourceName = "aurora_postgresql_cluster_onboarding" + + instanceAssetId = assetId + "-writer" + instanceResourceName = resourceName + "_instance" + + logGroupAssetId = "arn:aws:logs:us-east-2:123456789012:log-group:/aws/rds/cluster/my-cluster/postgresql:*" + logGroupResourceName = resourceName + "_log_group" + ) + + resourceTypeAndName := fmt.Sprintf("%s.%s", dsfDataSourceResourceType, resourceName) + instanceResourceTypeAndName := fmt.Sprintf("%s.%s", dsfDataSourceResourceType, instanceResourceName) + logGroupResourceTypeAndName := fmt.Sprintf("%s.%s", dsfLogAggregatorResourceType, logGroupResourceName) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + Steps: []resource.TestStep{ + // onboard and connect to gateway + { + Config: testAccDSFDataSourceConfig_AwsRdsAuroraPostgresqlCluster(resourceName, gatewayId, assetId, "LOG_GROUP", resourceName) + + testAccDSFDataSourceConfig_AwsRdsAuroraPostgresql(instanceResourceName, gatewayId, instanceAssetId, resourceName) + + testAccDSFLogAggregatorConfig_AwsLogGroup(logGroupResourceName, gatewayId, logGroupAssetId, resourceTypeAndName+".asset_id", true, "LOG_GROUP", ""), + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(logGroupResourceTypeAndName, "audit_pull_enabled", "true"), + resource.TestCheckResourceAttr(logGroupResourceTypeAndName, "gateway_service", "gateway-aws@aurora-postgresql.service"), + ), + }, + // refresh and verify DB assets are connected + { + RefreshState: true, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceTypeAndName, "audit_pull_enabled", "true"), + resource.TestCheckResourceAttr(instanceResourceTypeAndName, "audit_pull_enabled", "true"), + ), + }, + // validate import + { + ResourceName: resourceTypeAndName, + ImportState: true, + ImportStateVerify: true, + }, + }, + }) +} + +func TestAccDSFDataSource_AwsRdsAuroraMysqlClusterCloudWatchSlowQuery(t *testing.T) { + gatewayId := os.Getenv("GATEWAY_ID") + if gatewayId == "" { + t.Skip("GATEWAY_ID environment variable must be set") + } + + const ( + assetId = "arn:aws:rds:us-east-2:123456789012:cluster:my-auroramysql-cluster" + resourceName = "aurora_mysql_cluster_onboarding" + + instanceAssetId = assetId + "-writer" + instanceResourceName = resourceName + "_instance" + + logGroupAssetId = "arn:aws:logs:us-east-2:123456789012:log-group:/aws/rds/cluster/my-aurora-cluster/audit:*" + logGroupResourceName = resourceName + "_log_group" + + slowLogGroupAssetId = 
"arn:aws:logs:us-east-2:123456789012:log-group:/aws/rds/cluster/my-aurora-cluster/slowquery:*" + slowLogGroupResourceName = resourceName + "_slow_log_group" + ) + + resourceTypeAndName := fmt.Sprintf("%s.%s", dsfDataSourceResourceType, resourceName) + //TODO: check that instance asset is connected once fixed: https://onejira.imperva.com/browse/SR-2046 + // instanceResourceTypeAndName := fmt.Sprintf("%s.%s", dsfDataSourceResourceType, instanceResourceName) + logGroupResourceTypeAndName := fmt.Sprintf("%s.%s", dsfLogAggregatorResourceType, logGroupResourceName) + slowLogGroupResourceTypeAndName := fmt.Sprintf("%s.%s", dsfLogAggregatorResourceType, slowLogGroupResourceName) + + resource.ParallelTest(t, resource.TestCase{ + PreCheck: func() { testAccPreCheck(t) }, + Providers: testAccProviders, + Steps: []resource.TestStep{ + // onboard and connect to gateway + { + Config: testAccDSFDataSourceConfig_AwsRdsAuroraMysqlCluster(resourceName, gatewayId, assetId, "", resourceName) + + testAccDSFDataSourceConfig_AwsRdsAuroraMysql(instanceResourceName, gatewayId, instanceAssetId, resourceName) + + testAccDSFLogAggregatorConfig_AwsLogGroup(logGroupResourceName, gatewayId, logGroupAssetId, resourceTypeAndName+".asset_id", true, "LOG_GROUP", "") + + testAccDSFLogAggregatorConfig_AwsLogGroup(slowLogGroupResourceName, gatewayId, slowLogGroupAssetId, resourceTypeAndName+".asset_id", true, "AWS_RDS_AURORA_MYSQL_SLOW", logGroupResourceTypeAndName), + // verify log group assets are connected + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(logGroupResourceTypeAndName, "audit_pull_enabled", "true"), + resource.TestCheckResourceAttr(logGroupResourceTypeAndName, "gateway_service", "gateway-aws@aurora-mysql.service"), + resource.TestCheckResourceAttr(slowLogGroupResourceTypeAndName, "audit_pull_enabled", "true"), + resource.TestCheckResourceAttr(slowLogGroupResourceTypeAndName, "gateway_service", "gateway-aws@aurora-mysql-slow-query.service"), + ), + }, + // refresh and verify DB assets are connected + { + RefreshState: true, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceTypeAndName, "audit_pull_enabled", "true"), + // resource.TestCheckResourceAttr(instanceResourceTypeAndName, "audit_pull_enabled", "true"), + ), + }, + // disconnect assets + { + Config: testAccDSFDataSourceConfig_AwsRdsAuroraMysqlCluster(resourceName, gatewayId, assetId, "", resourceName) + + testAccDSFDataSourceConfig_AwsRdsAuroraMysql(instanceResourceName, gatewayId, instanceAssetId, resourceName) + + testAccDSFLogAggregatorConfig_AwsLogGroup(logGroupResourceName, gatewayId, logGroupAssetId, resourceTypeAndName+".asset_id", false, "LOG_GROUP", "") + + testAccDSFLogAggregatorConfig_AwsLogGroup(slowLogGroupResourceName, gatewayId, slowLogGroupAssetId, resourceTypeAndName+".asset_id", false, "AWS_RDS_AURORA_MYSQL_SLOW", logGroupResourceTypeAndName), + // verify log group assets are disconnected + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(logGroupResourceTypeAndName, "audit_pull_enabled", "false"), + resource.TestCheckResourceAttr(logGroupResourceTypeAndName, "gateway_service", ""), + resource.TestCheckResourceAttr(slowLogGroupResourceTypeAndName, "audit_pull_enabled", "false"), + resource.TestCheckResourceAttr(slowLogGroupResourceTypeAndName, "gateway_service", ""), + ), + }, + // refresh and verify DB assets are disconnected + { + RefreshState: true, + Check: resource.ComposeTestCheckFunc( + resource.TestCheckResourceAttr(resourceTypeAndName, 
"audit_pull_enabled", "false"), + // resource.TestCheckResourceAttr(instanceResourceTypeAndName, "audit_pull_enabled", "false"), + ), + }, + // validate import + { + ResourceName: resourceTypeAndName, ImportState: true, ImportStateVerify: true, - ImportStateIdFunc: testAccDSFDataSourceId, }, }, }) @@ -40,7 +233,7 @@ func TestAccDSFDataSource_basic(t *testing.T) { func testAccDSFDataSourceId(state *terraform.State) (string, error) { log.Printf("[INFO] Running test testAccDSFDataSourceId \n") for _, rs := range state.RootModule().Resources { - if rs.Type != dsfDataSourceType { + if rs.Type != dsfDataSourceResourceType { continue } return fmt.Sprintf("%s", rs.Primary.ID), nil @@ -48,53 +241,25 @@ func testAccDSFDataSourceId(state *terraform.State) (string, error) { return "", fmt.Errorf("error finding DSF dataSourceId") } -func testCheckDSFDataSourceExists(dataSourceId string) resource.TestCheckFunc { - log.Printf("[INFO] Running test testCheckDSFDataSourceExists \n") - return func(state *terraform.State) error { - res, ok := state.RootModule().Resources[dataSourceId] - if !ok { - return fmt.Errorf("DSF Data Source resource not found by dataSourceId: %s", dataSourceId) - } - serverType, ok := res.Primary.Attributes["server_type"] - if !ok || serverType == "" { - return fmt.Errorf("DSF Data Source Server Type does not exist for dataSourceId %s", dataSourceId) - } - client := testAccProvider.Meta().(*Client) - _, err := client.ReadDSFDataSource(res.Primary.ID) - if err != nil { - return fmt.Errorf("DSF Data Source Server Type: %s (dataSourceId: %s) does not exist", serverType, dataSourceId) - } - return nil - } -} - -func testAccCheckDSFDataSourceConfigBasic(t *testing.T) string { - log.Printf("[INFO] Running test testAccCheckDSFDataSourceConfigBasic \n") - return fmt.Sprintf(` -resource "%s" "my_test_data_source" { - admin_email = "%s" - arn = "%s" - asset_display_name = "%s" - gateway_id = %s - server_host_name = "%s" - server_type = "%s" -}`, dsfDataSourceResourceName, testAdminEmail, testArn, testAssetDisplayName, testGatewayId, testServerHostName, testDSServerType) -} - +// Confirm assets are destroyed after an acceptance test run func testAccDSFDataSourceDestroy(state *terraform.State) error { - log.Printf("[INFO] Running test testAccDSFDataSourceDestroy \n") + log.Printf("[INFO] Running test testAccDSFDataSourceDestroy") + // allow "disableAsset" playbook enough time to run + time.Sleep(5 + time.Second) + + // check if asset still exists on hub client := testAccProvider.Meta().(*Client) for _, res := range state.RootModule().Resources { - if res.Type != "dsf_data_source" { + if res.Type != dsfDataSourceResourceType { continue } - dsfDataSourceId := res.Primary.ID - readDSFDataSourceResponse, err := client.ReadDSFDataSource(dsfDataSourceId) + assetId := res.Primary.ID + readDSFDataSourceResponse, err := client.ReadDSFDataSource(assetId) if readDSFDataSourceResponse.Errors == nil { - return fmt.Errorf("DSF Data Source %s should have received an error in the response", dsfDataSourceId) + return fmt.Errorf("DSF Data Source %s should have received an error in the response", assetId) } if err == nil { - return fmt.Errorf("DSF Data Source %s still exists for gatewayId: %s", dsfDataSourceId, testGatewayId) + return fmt.Errorf("DSF Data Source %s still exists", assetId) } } return nil diff --git a/dsfhub/resource_data_source_test_data.go b/dsfhub/resource_data_source_test_data.go new file mode 100644 index 0000000..b9ed30e --- /dev/null +++ b/dsfhub/resource_data_source_test_data.go @@ -0,0 
diff --git a/dsfhub/resource_data_source_test_data.go b/dsfhub/resource_data_source_test_data.go new file mode 100644 index 0000000..b9ed30e --- /dev/null +++ b/dsfhub/resource_data_source_test_data.go @@ -0,0 +1,157 @@ +package dsfhub + +import "fmt" + +// Output a terraform config for a basic data source resource. +func testAccDSFDataSourceConfig_Basic(resourceName string, adminEmail string, assetId string, gatewayId string, serverHostName string, serverType string) string { + return fmt.Sprintf(` +resource "`+dsfDataSourceResourceType+`" "%[1]s" { + admin_email = "%[2]s" + asset_id = "%[3]s" + asset_display_name = "%[3]s" + gateway_id = "%[4]s" + server_host_name = "%[5]s" + server_type = "%[6]s" +}`, resourceName, adminEmail, assetId, gatewayId, serverHostName, serverType) +} + +// Output a terraform config for an AWS RDS ORACLE data source resource. +func testAccDSFDataSourceConfig_AwsRdsOracle(resourceName string, gatewayId string, assetId string, auditType string, auditPullEnabled string) string { + // convert audit_pull_enabled to "null" if empty + if auditPullEnabled == "" { + auditPullEnabled = "null" + } + + return fmt.Sprintf(` +resource "`+dsfDataSourceResourceType+`" "%[1]s" { + server_type = "AWS RDS ORACLE" + + admin_email = "`+testAdminEmail+`" + asset_display_name = "%[3]s" + asset_id = "%[3]s" + audit_pull_enabled = %[5]s + audit_type = "%[4]s" + gateway_id = "%[2]s" + server_host_name = "test.com" + server_port = "1521" + service_name = "ORCL" + + asset_connection { + auth_mechanism = "password" + password = "password" + reason = "default" + username = "username" + } +} +`, resourceName, gatewayId, assetId, auditType, auditPullEnabled) +} + +// Output a terraform config for an AWS RDS AURORA POSTGRESQL CLUSTER data +// source resource. +func testAccDSFDataSourceConfig_AwsRdsAuroraPostgresqlCluster(resourceName string, gatewayId string, assetId string, auditType string, clusterId string) string { + return fmt.Sprintf(` +resource "`+dsfDataSourceResourceType+`" "%[1]s" { + server_type = "AWS RDS AURORA POSTGRESQL CLUSTER" + + admin_email = "`+testAdminEmail+`" + asset_display_name = "%[3]s" + asset_id = "%[3]s" + audit_type = "%[4]s" + cluster_id = "%[5]s" + cluster_name = "%[5]s" + gateway_id = "%[2]s" + region = "us-east-2" + server_host_name = "my-cluster.cluster-xxxxk8rsfzja.us-east-2.rds.amazonaws.com" + server_port = "5432" + + asset_connection { + auth_mechanism = "password" + password = "my-password" + reason = "default" + username = "my-user" + } +} +`, resourceName, gatewayId, assetId, auditType, clusterId) +} + +// Output a terraform config for an AWS RDS AURORA POSTGRESQL data source +// resource. +func testAccDSFDataSourceConfig_AwsRdsAuroraPostgresql(resourceName string, gatewayId string, assetId string, clusterId string) string { + return fmt.Sprintf(` +resource "`+dsfDataSourceResourceType+`" "%[1]s" { + server_type = "AWS RDS AURORA POSTGRESQL" + + admin_email = "`+testAdminEmail+`" + asset_display_name = "%[3]s" + asset_id = "%[3]s" + cluster_id = "%[4]s" + cluster_name = "%[4]s" + gateway_id = "%[2]s" + region = "us-east-2" + server_host_name = "my-cluster.cluster-xxxxk8rsfzja.us-east-2.rds.amazonaws.com" + server_port = "5432" + + asset_connection { + auth_mechanism = "password" + password = "my-password" + reason = "default" + username = "my-user" + } +} +`, resourceName, gatewayId, assetId, clusterId) +} +
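The cluster and instance configs above are tied to their parents through attribute references such as dsfhub_data_source.<name>.asset_id; as its doc comment in resource_common.go describes, testAccParseResourceAttributeReference decides whether such a string is passed through as a reference or quoted as a literal. A short illustration of that behavior (example values only):

// Illustration: reference strings pass through untouched, literals get quoted.
func ExampleTestAccParseResourceAttributeReference() {
	fmt.Println(testAccParseResourceAttributeReference("dsfhub_data_source.my-oracle-db.asset_id"))
	fmt.Println(testAccParseResourceAttributeReference("arn:aws:rds:us-east-2:123456789012:db:my-db"))
	// Output:
	// dsfhub_data_source.my-oracle-db.asset_id
	// "arn:aws:rds:us-east-2:123456789012:db:my-db"
}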
+// Output a terraform config for an AWS RDS AURORA MYSQL CLUSTER data source +// resource. +func testAccDSFDataSourceConfig_AwsRdsAuroraMysqlCluster(resourceName string, gatewayId string, assetId string, auditType string, clusterId string) string { + return fmt.Sprintf(` +resource "`+dsfDataSourceResourceType+`" "%[1]s" { + server_type = "AWS RDS AURORA MYSQL CLUSTER" + + admin_email = "`+testAdminEmail+`" + asset_display_name = "%[3]s" + asset_id = "%[3]s" + audit_type = "%[4]s" + cluster_id = "%[5]s" + cluster_name = "%[5]s" + gateway_id = "%[2]s" + region = "us-east-2" + server_host_name = "my-cluster.cluster-xxxxk8rsfzja.us-east-2.rds.amazonaws.com" + server_port = "3306" + + asset_connection { + auth_mechanism = "password" + password = "my-password" + reason = "default" + username = "my-user" + } +} +`, resourceName, gatewayId, assetId, auditType, clusterId) +} + +// Output a terraform config for an AWS RDS AURORA MYSQL data source resource. +func testAccDSFDataSourceConfig_AwsRdsAuroraMysql(resourceName string, gatewayId string, assetId string, clusterId string) string { + return fmt.Sprintf(` +resource "`+dsfDataSourceResourceType+`" "%[1]s" { + server_type = "AWS RDS AURORA MYSQL" + + admin_email = "`+testAdminEmail+`" + asset_display_name = "%[3]s" + asset_id = "%[3]s" + #TODO: re-add cluster fields when supported by USC: https://onejira.imperva.com/browse/USC-2389 + #cluster_id = "%[4]s" + #cluster_name = "%[4]s" + gateway_id = "%[2]s" + region = "us-east-2" + server_host_name = "my-cluster.cluster-xxxxk8rsfzja.us-east-2.rds.amazonaws.com" + server_port = "3306" + + asset_connection { + auth_mechanism = "password" + password = "my-password" + reason = "default" + username = "my-user" + } +} +`, resourceName, gatewayId, assetId, clusterId) +} diff --git a/dsfhub/resource_log_aggregator.go b/dsfhub/resource_log_aggregator.go index 56be452..2630354 100644 --- a/dsfhub/resource_log_aggregator.go +++ b/dsfhub/resource_log_aggregator.go @@ -2,18 +2,21 @@ package dsfhub import ( "bytes" + "context" "fmt" + "log" + + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" - "log" ) func resourceLogAggregator() *schema.Resource { return &schema.Resource{ - Create: resourceLogAggregatorCreate, - Read: resourceLogAggregatorRead, - Update: resourceLogAggregatorUpdate, - Delete: resourceLogAggregatorDelete, + CreateContext: resourceLogAggregatorCreateContext, + ReadContext: resourceLogAggregatorReadContext, + UpdateContext: resourceLogAggregatorUpdateContext, + DeleteContext: resourceLogAggregatorDeleteContext, Importer: &schema.ResourceImporter{ StateContext: schema.ImportStatePassthroughContext, }, @@ -625,36 +628,52 @@ func resourceLogAggregator() *schema.Resource { } } -func resourceLogAggregatorCreate(d *schema.ResourceData, m interface{}) error { +func resourceLogAggregatorCreateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { + var diags diag.Diagnostics client := m.(*Client) + + // check provided fields against schema if isOk, err := checkResourceRequiredFields(requiredLogAggregatorJson, ignoreLogAggregatorParamsByServerType, d); !isOk { - return err + return diag.FromErr(err) } + // convert provided fields into API payload logAggregator := ResourceWrapper{} serverType := d.Get("server_type").(string) createResource(&logAggregator, serverType, d) + // auditPullEnabled set to false as connect/disconnect logic handled below logAggregator.Data.AssetData.AuditPullEnabled = false + + // create resource log.Printf("[INFO]
Creating LogAggregator for serverType: %s and gatewayId: %s\n", logAggregator.Data.ServerType, logAggregator.Data.GatewayID) createLogAggregatorResponse, err := client.CreateLogAggregator(logAggregator) - if err != nil { log.Printf("[ERROR] adding LogAggregator for serverType: %s and gatewayId: %s | err: %s", serverType, logAggregator.Data.GatewayID, err) - return err + return diag.FromErr(err) } // Connect/disconnect asset to gateway - connectDisconnectGateway(d, logAggregator, m) + err = connectDisconnectGateway(ctx, d, dsfLogAggregatorResourceType, m) + if err != nil { + diags = append(diags, diag.Diagnostic{ + Severity: diag.Warning, + Summary: fmt.Sprintf("Error while updating audit state for asset: %s", d.Get("asset_id")), + Detail: fmt.Sprintf("Error: %s\n", err), + }) + } // Set ID logAggregatorId := createLogAggregatorResponse.Data.ID d.SetId(logAggregatorId) // Set the rest of the state from the resource read - return resourceLogAggregatorRead(d, m) + log.Printf("[DEBUG] Writing log aggregator asset details to state") + resourceLogAggregatorReadContext(ctx, d, m) + + return diags } -func resourceLogAggregatorRead(d *schema.ResourceData, m interface{}) error { +func resourceLogAggregatorReadContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { client := m.(*Client) logAggregatorId := d.Id() @@ -664,7 +683,7 @@ func resourceLogAggregatorRead(d *schema.ResourceData, m interface{}) error { if err != nil { log.Printf("[ERROR] Reading logAggregatorReadResponse with logAggregatorId: %s | err: %s\n", logAggregatorId, err) - return err + return diag.FromErr(err) } if logAggregatorReadResponse != nil { @@ -683,7 +702,9 @@ func resourceLogAggregatorRead(d *schema.ResourceData, m interface{}) error { d.Set("audit_type", logAggregatorReadResponse.Data.AssetData.AuditType) d.Set("available_regions", logAggregatorReadResponse.Data.AssetData.AvailableRegions) d.Set("bucket_account_id", logAggregatorReadResponse.Data.AssetData.BucketAccountId) - d.Set("credential_endpoint", logAggregatorReadResponse.Data.AssetData.CredentialsEndpoint) + if logAggregatorReadResponse.Data.AssetData.CredentialsEndpoint != "" { + d.Set("credential_endpoint", logAggregatorReadResponse.Data.AssetData.CredentialsEndpoint) + } d.Set("criticality", logAggregatorReadResponse.Data.AssetData.Criticality) d.Set("gateway_id", logAggregatorReadResponse.Data.GatewayID) d.Set("gateway_service", logAggregatorReadResponse.Data.AssetData.GatewayService) @@ -799,42 +820,61 @@ func resourceLogAggregatorRead(d *schema.ResourceData, m interface{}) error { connections.Add(connection) } - d.Set("ca_connection", connections) + d.Set("asset_connection", connections) log.Printf("[INFO] Finished reading logAggregator with logAggregatorId: %s\n", logAggregatorId) return nil } -func resourceLogAggregatorUpdate(d *schema.ResourceData, m interface{}) error { +func resourceLogAggregatorUpdateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics { + var diags diag.Diagnostics client := m.(*Client) + + // check provided fields against schema logAggregatorId := d.Id() if isOk, err := checkResourceRequiredFields(requiredLogAggregatorJson, ignoreLogAggregatorParamsByServerType, d); !isOk { - return err + return diag.FromErr(err) } + + // convert provided fields into API payload logAggregator := ResourceWrapper{} serverType := d.Get("server_type").(string) createResource(&logAggregator, serverType, d) + // auditPullEnabled set to current value from state + auditPullEnabled, _ := 
diff --git a/dsfhub/resource_log_aggregator_test.go b/dsfhub/resource_log_aggregator_test.go
index c0bcb31..6eab665 100644
--- a/dsfhub/resource_log_aggregator_test.go
+++ b/dsfhub/resource_log_aggregator_test.go
@@ -3,36 +3,48 @@ package dsfhub

 import (
 	"fmt"
 	"log"
+	"os"
+	"regexp"
 	"testing"

 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
 )

-const logAggregatorResourceName = "log_aggregator"
-const logAggregatorType = "aws"
-const logAggregatorResourceTypeAndName = logAggregatorResourceName + "." + logAggregatorType
+func TestAccDSFLogAggregator_AwsLogGroup(t *testing.T) {
+	gatewayId := os.Getenv("GATEWAY_ID")
+	if gatewayId == "" {
+		t.Skip("GATEWAY_ID environment variable must be set")
+	}
+
+	const (
+		assetId            = "arn:aws:logs:us-east-2:123456789012:log-group:/aws/rds/instance/my-database/audit:*"
+		resourceName       = "my-database-log-group"
+		serverHostName     = "oracle-rds-db.xxxxx8rsfzja.us-east-2.rds.amazonaws.com"
+		parentAssetId      = "arn:aws:rds:us-east-2:123456789012:db:oracle-rds-db"
+		parentResourceName = "my-oracle-db"
+	)
+
+	resourceTypeAndName := fmt.Sprintf("%s.%s", dsfLogAggregatorResourceType, resourceName)
+	parentResourceTypeAndName := fmt.Sprintf("%s.%s", dsfDataSourceResourceType, parentResourceName)

-func TestAccLogAggregator_basic(t *testing.T) {
-	log.Printf("======================== BEGIN TEST ========================")
-	log.Printf("[INFO] Running test TestAccLogAggregator_basic \n")
-	resource.Test(t, resource.TestCase{
-		PreCheck:     func() { testAccPreCheck(t) },
-		Providers:    testAccProviders,
-		CheckDestroy: testAccLogAggregatorDestroy,
+	resource.ParallelTest(t, resource.TestCase{
+		PreCheck:  func() { testAccPreCheck(t) },
+		Providers: testAccProviders,
 		Steps: []resource.TestStep{
+			// Expect failure: missing parent_asset_id
 			{
-				Config: testAccCheckLogAggregatorConfigBasic(t),
-				Check: resource.ComposeTestCheckFunc(
-					testCheckLogAggregatorExists(logAggregatorResourceName),
-					resource.TestCheckResourceAttr(logAggregatorResourceTypeAndName, logAggregatorResourceName, logAggregatorType),
-				),
+				Config:      testAccDSFLogAggregatorConfig_AwsLogGroup(resourceName, gatewayId, assetId, "", true, "LOG_GROUP", ""),
+				ExpectError: regexp.MustCompile("Error: missing required fields for dsfhub_data_source"),
 			},
+			// Onboard with AWS parent asset
 			{
-				ResourceName:      logAggregatorResourceTypeAndName,
-				ImportState:       true,
-				ImportStateVerify: true,
-				ImportStateIdFunc: testAccLogAggregatorId,
+				Config: testAccDSFDataSourceConfig_AwsRdsOracle(parentResourceName, gatewayId, parentAssetId, "LOG_GROUP", "") +
+					testAccDSFLogAggregatorConfig_AwsLogGroup(resourceName, gatewayId, assetId, parentResourceTypeAndName+".asset_id", true, "LOG_GROUP", ""),
+				Check: resource.ComposeTestCheckFunc(
+					resource.TestCheckResourceAttr(resourceTypeAndName, "audit_pull_enabled", "true"),
+					resource.TestCheckResourceAttr(resourceTypeAndName, "gateway_service", "gateway-aws@oracle-rds.service"),
+				),
 			},
 		},
 	})
@@ -41,7 +53,7 @@ func TestAccLogAggregator_basic(t *testing.T) {
 func testAccLogAggregatorId(state *terraform.State) (string, error) {
 	log.Printf("[INFO] Running test testAccLogAggregatorId \n")
 	for _, rs := range state.RootModule().Resources {
-		if rs.Type != logAggregatorType {
+		if rs.Type != dsfLogAggregatorResourceType {
 			continue
 		}
 		return fmt.Sprintf("%s", rs.Primary.ID), nil
@@ -69,24 +81,11 @@ func testCheckLogAggregatorExists(dataSourceId string) resource.TestCheckFunc {
 	}
 }

-func testAccCheckLogAggregatorConfigBasic(t *testing.T) string {
-	log.Printf("[INFO] Running test testAccCheckLogAggregatorConfigBasic \n")
-	return fmt.Sprintf(`
-resource "%s" "my_test_data_source" {
-	admin_email        = "%s"
-	arn                = "%s"
-	asset_display_name = "%s"
-	gateway_id         = %s
-	server_host_name   = "%s"
-	server_type        = "%s"
-}`, logAggregatorResourceName, testAdminEmail, testArn, testAssetDisplayName, testGatewayId, testServerHostName, testDSServerType)
-}
-
 func testAccLogAggregatorDestroy(state *terraform.State) error {
 	log.Printf("[INFO] Running test testAccLogAggregatorDestroy \n")
 	client := testAccProvider.Meta().(*Client)
 	for _, res := range state.RootModule().Resources {
-		if res.Type != "dsf_data_source" {
+		if res.Type != "dsfhub_log_aggregator" {
 			continue
 		}
 		logAggregatorId := res.Primary.ID
diff --git a/dsfhub/resource_log_aggregator_test_data.go b/dsfhub/resource_log_aggregator_test_data.go
new file mode 100644
index 0000000..f4c2f86
--- /dev/null
+++ b/dsfhub/resource_log_aggregator_test_data.go
@@ -0,0 +1,30 @@
+package dsfhub
+
+import "fmt"
+
+// Output a terraform config for an AWS LOG GROUP log aggregator resource.
+func testAccDSFLogAggregatorConfig_AwsLogGroup(resourceName string, gatewayId string, assetId string, parentAssetId string, auditPullEnabled bool, auditType string, dependsOn string) string {
+	// handle reference to other assets
+	parentAssetIdVal := testAccParseResourceAttributeReference(parentAssetId)
+
+	return fmt.Sprintf(`
+resource "`+dsfLogAggregatorResourceType+`" "%[1]s" {
+	depends_on  = [`+dependsOn+`]
+	server_type = "AWS LOG GROUP"
+
+	admin_email        = "`+testAdminEmail+`"
+	arn                = "%[3]s"
+	asset_display_name = "%[3]s"
+	asset_id           = "%[3]s"
+	audit_pull_enabled = %[5]t
+	audit_type         = "%[6]s"
+	gateway_id         = "%[2]s"
+	parent_asset_id    = `+parentAssetIdVal+`
+
+	asset_connection {
+		auth_mechanism = "default"
+		reason         = "default"
+		region         = "us-east-2"
+	}
+}`, resourceName, gatewayId, assetId, parentAssetId, auditPullEnabled, auditType)
+}
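The generator above renders parent_asset_id via testAccParseResourceAttributeReference, so callers can pass either a literal ID or a reference such as dsfhub_data_source.my-oracle-db.asset_id (as the second test step earlier does). The helper's implementation is outside this diff; a plausible sketch of its behavior, with the prefix check as an assumption:

package dsfhub

import (
	"fmt"
	"strings"
)

// Hypothetical sketch of testAccParseResourceAttributeReference (the real
// implementation is not shown in this diff): emit a bare HCL expression when
// the value references another resource's attribute, otherwise a quoted literal.
func testAccParseResourceAttributeReference(value string) string {
	if strings.HasPrefix(value, "dsfhub_") {
		return value // e.g. dsfhub_data_source.my-oracle-db.asset_id
	}
	return fmt.Sprintf("%q", value) // plain values become quoted HCL strings
}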
diff --git a/dsfhub/resource_log_secret_manager_test_data.go b/dsfhub/resource_log_secret_manager_test_data.go
new file mode 100644
index 0000000..942b9b4
--- /dev/null
+++ b/dsfhub/resource_log_secret_manager_test_data.go
@@ -0,0 +1,49 @@
+package dsfhub
+
+import (
+	"fmt"
+	"testing"
+)
+
+// Output a terraform config for a basic secret manager resource.
+func testAccSecretManagerConfig_Basic(t *testing.T) string {
+	return fmt.Sprintf(`
+resource "%s" "my_test_data_source" {
+	admin_email        = "%s"
+	asset_display_name = "%s"
+	asset_id           = "%s"
+	gateway_id         = "%s"
+	server_host_name   = "%s"
+	server_ip          = "%s"
+	server_port        = "%s"
+	server_type        = "%s"
+	asset_connection {
+		reason         = "%s"
+		auth_mechanism = "%s"
+		role_name      = "%s"
+	}
+}`, dsfSecretManagerResourceType, testAdminEmail, testAssetDisplayName, testSMAssetId, testGatewayId, testServerHostName, testServerIP, testServerPort, testSMServerType, testSMConnectionReason, testSMAuthMechanism, testSMRoleName)
+}
+
+// Output a terraform config for a HASHICORP secret manager resource.
+func testAccDSFSecretManagerConfig_Hashicorp(resourceName string, gatewayId string, assetId string, serverHostName string, serverPort string, authMechanism string, roleName string) string {
+	return fmt.Sprintf(`
+resource "`+dsfSecretManagerResourceType+`" "%[1]s" {
+	server_type = "HASHICORP"
+
+	admin_email        = "`+testAdminEmail+`"
+	asset_display_name = "%[3]s"
+	asset_id           = "%[3]s"
+	gateway_id         = "%[2]s"
+	server_host_name   = "%[4]s"
+	server_ip          = "%[4]s"
+	server_port        = "%[5]s"
+
+	asset_connection {
+		reason         = "default"
+		auth_mechanism = "%[6]s"
+		role_name      = "%[7]s"
+	}
+}`,
+		resourceName, gatewayId, assetId, serverHostName, serverPort, authMechanism, roleName)
+}
diff --git a/dsfhub/resource_secret_manager.go b/dsfhub/resource_secret_manager.go
index d230cd0..1a31e73 100644
--- a/dsfhub/resource_secret_manager.go
+++ b/dsfhub/resource_secret_manager.go
@@ -2,18 +2,21 @@ package dsfhub

 import (
 	"bytes"
+	"context"
 	"fmt"
+	"log"
+
+	"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
 	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
-	"log"
 )

 func resourceSecretManager() *schema.Resource {
 	return &schema.Resource{
-		Create: resourceSecretManagerCreate,
-		Read:   resourceSecretManagerRead,
-		Update: resourceSecretManagerUpdate,
-		Delete: resourceSecretManagerDelete,
+		CreateContext: resourceSecretManagerCreateContext,
+		ReadContext:   resourceSecretManagerReadContext,
+		UpdateContext: resourceSecretManagerUpdateContext,
+		DeleteContext: resourceSecretManagerDeleteContext,
 		Importer: &schema.ResourceImporter{
 			StateContext: schema.ImportStatePassthroughContext,
 		},
@@ -496,31 +499,38 @@
 	}
 }

-func resourceSecretManagerCreate(d *schema.ResourceData, m interface{}) error {
+func resourceSecretManagerCreateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
 	client := m.(*Client)
+
+	// check provided fields against schema
 	if isOk, err := checkResourceRequiredFields(requiredSecretManagerFieldsJson, ignoreSecretManagerParamsByServerType, d); !isOk {
-		return err
+		return diag.FromErr(err)
 	}
+
+	// convert provided fields into API payload
 	secretManager := ResourceWrapper{}
 	serverType := d.Get("server_type").(string)
 	createResource(&secretManager, serverType, d)

+	// create resource
 	log.Printf("[INFO] Creating SecretManager for serverType: %s and gatewayId: %s gatewayId: \n", serverType, secretManager.Data.GatewayID)
 	createSecretManagerResponse, err := client.CreateSecretManager(secretManager)
-
 	if err != nil {
 		log.Printf("[ERROR] adding secret manager for serverType: %s and gatewayId: %s | err: %s\n", serverType, secretManager.Data.GatewayID, err)
-		return err
+		return diag.FromErr(err)
 	}

+	// set ID
 	secretManagerId := createSecretManagerResponse.Data.ID
 	d.SetId(secretManagerId)

 	// Set the rest of the state from the resource read
-	return resourceSecretManagerRead(d, m)
+	resourceSecretManagerReadContext(ctx, d, m)
+
+	return nil
 }

-func resourceSecretManagerRead(d *schema.ResourceData, m interface{}) error {
+func resourceSecretManagerReadContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
 	client := m.(*Client)

 	secretManagerId := d.Id()
@@ -530,7 +540,7 @@ func resourceSecretManagerRead(d *schema.ResourceData, m interface{}) error {

 	if err != nil {
 		log.Printf("[ERROR] Reading secretManagerReadResponse with secretManagerId: %s | err: %s\n", secretManagerId, err)
-		return err
+		return diag.FromErr(err)
 	}

 	if secretManagerReadResponse != nil {
@@ -658,31 +668,38 @@ func resourceSecretManagerRead(d *schema.ResourceData, m interface{}) error {
 	return nil
 }

-func resourceSecretManagerUpdate(d *schema.ResourceData, m interface{}) error {
+func resourceSecretManagerUpdateContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
 	client := m.(*Client)
+
+	// check provided fields against schema
 	secretManagerId := d.Id()
 	if isOk, err := checkResourceRequiredFields(requiredSecretManagerFieldsJson, ignoreSecretManagerParamsByServerType, d); !isOk {
-		return err
+		return diag.FromErr(err)
 	}
+
+	// convert provided fields into API payload
 	secretManager := ResourceWrapper{}
 	serverType := d.Get("server_type").(string)
 	createResource(&secretManager, serverType, d)

+	// update resource
 	log.Printf("[INFO] Updating DSF data source for serverType: %s and gatewayId: %s assetId: %s\n", secretManager.Data.ServerType, secretManager.Data.GatewayID, secretManager.Data.AssetData.AssetID)
 	_, err := client.UpdateSecretManager(secretManagerId, secretManager)
-
 	if err != nil {
 		log.Printf("[ERROR] Updating secret manager for serverType: %s and gatewayId: %s assetId: %s | err:%s\n", secretManager.Data.ServerType, secretManager.Data.GatewayID, secretManager.Data.AssetData.AssetID, err)
-		return err
+		return diag.FromErr(err)
 	}

+	// set ID
 	d.SetId(secretManagerId)

 	// Set the rest of the state from the resource read
-	return resourceSecretManagerRead(d, m)
+	resourceSecretManagerReadContext(ctx, d, m)
+
+	return nil
 }

-func resourceSecretManagerDelete(d *schema.ResourceData, m interface{}) error {
+func resourceSecretManagerDeleteContext(ctx context.Context, d *schema.ResourceData, m interface{}) diag.Diagnostics {
 	client := m.(*Client)
 	secretManagerId := d.Id()
diff --git a/dsfhub/resource_secret_manager_test.go b/dsfhub/resource_secret_manager_test.go
index f7ffc12..e9a63f7 100644
--- a/dsfhub/resource_secret_manager_test.go
+++ b/dsfhub/resource_secret_manager_test.go
@@ -2,36 +2,32 @@ package dsfhub

 import (
 	"fmt"
-	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
-	"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
 	"log"
+	"os"
 	"testing"
+
+	"github.com/hashicorp/terraform-plugin-sdk/v2/helper/resource"
+	"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
 )

-const secretManagerResourceName = "secret_manager"
-const secretManagerType = "HASHICORP"
-const secretManagerResourceTypeAndName = secretManagerResourceName + "." + secretManagerType
+func TestAccDSFSecretManager_Hashicorp(t *testing.T) {
+	gatewayId := os.Getenv("GATEWAY_ID")
+	if gatewayId == "" {
+		t.Skip("GATEWAY_ID environment variable must be set")
+	}
+
+	const (
+		serverPort   = "8200"
+		assetId      = testOnPremServerHostName + ":HASHICORP::" + serverPort
+		resourceName = "example-hashicorp"
+	)

-func TestAccSecretManager_basic(t *testing.T) {
-	log.Printf("======================== BEGIN TEST ========================")
-	log.Printf("[INFO] Running test TestAccSecretManager_basic \n")
-	resource.Test(t, resource.TestCase{
-		PreCheck:     func() { testAccPreCheck(t) },
-		Providers:    testAccProviders,
-		CheckDestroy: testAccSecretManagerDestroy,
+	resource.ParallelTest(t, resource.TestCase{
+		PreCheck:  func() { testAccPreCheck(t) },
+		Providers: testAccProviders,
 		Steps: []resource.TestStep{
 			{
-				Config: testAccCheckSecretManagerConfigBasic(t),
-				Check: resource.ComposeTestCheckFunc(
-					testCheckSecretManagerExists(secretManagerResourceName),
-					resource.TestCheckResourceAttr(secretManagerResourceTypeAndName, secretManagerResourceName, secretManagerType),
-				),
-			},
-			{
-				ResourceName:      secretManagerResourceTypeAndName,
-				ImportState:       true,
-				ImportStateVerify: true,
-				ImportStateIdFunc: testAccSecretManagerId,
+				Config: testAccDSFSecretManagerConfig_Hashicorp(resourceName, gatewayId, assetId, testOnPremServerHostName, serverPort, "ec2", "vault-role-for-ec2"),
 			},
 		},
 	})
@@ -40,7 +36,7 @@ func TestAccSecretManager_basic(t *testing.T) {
 func testAccSecretManagerId(state *terraform.State) (string, error) {
 	log.Printf("[INFO] Running test testAccSecretManagerId \n")
 	for _, rs := range state.RootModule().Resources {
-		if rs.Type != dsfDataSourceType {
+		if rs.Type != dsfSecretManagerResourceType {
 			continue
 		}
 		return fmt.Sprintf("%s", rs.Primary.ID), nil
@@ -68,26 +64,6 @@ func testCheckSecretManagerExists(secretManagerId string) resource.TestCheckFunc {
 	}
 }

-func testAccCheckSecretManagerConfigBasic(t *testing.T) string {
-	log.Printf("[INFO] Running test testAccCheckSecretManagerConfigBasic \n")
-	return fmt.Sprintf(`
-resource "%s" "my_test_data_source" {
-	admin_email        = "%s"
-	asset_display_name = "%s"
-	asset_id           = "%s"
-	gateway_id         = "%s""
-	server_host_name   = "%s"
-	server_ip          = "%s"
-	server_port        = "%s"
-	server_type        = "%s"
-	sm_connection {
-		reason         = "%s"
-		auth_mechanism = "%s"
-		role_name      = "%s"
-	}
-}`, secretManagerResourceName, testAdminEmail, testAssetDisplayName, testSMAssetId, testGatewayId, testServerHostName, testServerIP, testServerPort, testSMServerType, testSMConnectionReason, testSMAuthMechanism, testSMRoleName)
-}
-
 func testAccSecretManagerDestroy(state *terraform.State) error {
 	log.Printf("[INFO] Running test testAccDSFDataSourceDestroy \n")
 	client := testAccProvider.Meta().(*Client)
diff --git a/dsfhub/test_constants.go b/dsfhub/test_constants.go
index 7c74a6c..0e27dbd 100644
--- a/dsfhub/test_constants.go
+++ b/dsfhub/test_constants.go
@@ -1,19 +1,23 @@
 package dsfhub

-const testInvalidDSFHUBHost = "https://invalid.host.com"
-const testAdminEmail = "test@email.com"
-const testArn = "arn:aws:rds:us-east-2:123456789:db:your-db"
-const testServerHostName = "your-db-name.abcde12345.us-east-2.rds.amazonaws.com"
-const testServerIP = "1.2.3.4"
-const testServerPort = "8200"
-const testDSServerType = "AWS RDS MYSQL"
-const testGatewayId = "e33bfbe4-a93a-c4e5-8e9c-6e5558c2e2cd"
+const (
+	testAdminEmail     = "test@email.com"
+	testArn            = "arn:aws:rds:us-east-2:123456789:db:your-db"
+	testServerHostName = "your-db-name.abcde12345.us-east-2.rds.amazonaws.com"
"your-db-name.abcde12345.us-east-2.rds.amazonaws.com" + testOnPremServerHostName = "server.company.com" + testServerIP = "1.2.3.4" + testServerPort = "8200" -const testAssetDisplayName = "arn:aws:rds:us-east-2:123456789:db:your-db" + testInvalidDSFHUBHost = "https://invalid.host.com" + testDSServerType = "AWS RDS MYSQL" + testGatewayId = "e33bfbe4-a93a-c4e5-8e9c-6e5558c2e2cd" -const testSMConnectionReason = "default" -const testSMRoleName = "vault-role-for-ec2" -const testSMAuthMechanism = "ec2" -const testSMAssetId = "your-host-name-here" + testAssetDisplayName = "arn:aws:rds:us-east-2:123456789:db:your-db" -const testSMServerType = "HASHICORP" + testSMConnectionReason = "default" + testSMRoleName = "vault-role-for-ec2" + testSMAuthMechanism = "ec2" + testSMAssetId = "your-host-name-here" + + testSMServerType = "HASHICORP" +)