Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
139523: roachtest: add provider specific flags to `roachtest run` r=jeffswenson a=jeffswenson

Previously, `ConfigureClusterFlags` was a method on the provider options, but it applied settings to the underlying provider. Now, these flags were moved to the `ConfigureProviderFlags` method on the provider instance.

The motivation for this change is the AWS profile configuration. The AWS profile is needed to support running roachtests with AWS credentials provisioned via SSO.

The static AWS and GCP providers were removed and code that mutates them
was moved to methods on the providers themselves. The `InfraProvider`
interface was added to the GCP package in order to provide access to
methods on the provider that are used for clusters on multiple clouds.

Release note: none
Epic: none

139777: builtins: fix to_reg* handling of number arguments r=rafiss a=rafiss

fixes #124908
Release note (bug fix): The to_regclass, to_regtype, to_regrole, and related functions now return NULL for any numerical input argument.

Co-authored-by: Jeff Swenson <[email protected]>
Co-authored-by: Rafi Shamim <[email protected]>
  • Loading branch information
3 people committed Jan 24, 2025
3 parents ebb1d25 + 8969b65 + dbe1971 commit 38ec124
Show file tree
Hide file tree
Showing 18 changed files with 199 additions and 110 deletions.
4 changes: 2 additions & 2 deletions pkg/cmd/roachprod/cli/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ func buildSSHKeysListCmd() *cobra.Command {
Use: "list",
Short: "list every SSH public key installed on clusters managed by roachprod",
Run: wrap(func(cmd *cobra.Command, args []string) error {
authorizedKeys, err := gce.GetUserAuthorizedKeys()
authorizedKeys, err := gce.Infrastructure.GetUserAuthorizedKeys()
if err != nil {
return err
}
Expand Down Expand Up @@ -1168,7 +1168,7 @@ func buildSSHKeysRemoveCmd() *cobra.Command {
Run: wrap(func(cmd *cobra.Command, args []string) error {
user := args[0]

existingKeys, err := gce.GetUserAuthorizedKeys()
existingKeys, err := gce.Infrastructure.GetUserAuthorizedKeys()
if err != nil {
return fmt.Errorf("failed to fetch existing keys: %w", err)
}
Expand Down
17 changes: 9 additions & 8 deletions pkg/cmd/roachprod/cli/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,11 +157,12 @@ func initCreateCmdFlags(createCmd *cobra.Command) {

// Allow each Provider to inject additional configuration flags
for _, providerName := range vm.AllProviderNames() {
if vm.Providers[providerName].Active() {
provider := vm.Providers[providerName]
if provider.Active() {
providerOptsContainer[providerName].ConfigureCreateFlags(createCmd.Flags())
// createCmd only accepts a single GCE project, as opposed to all the other
// commands.
providerOptsContainer[providerName].ConfigureClusterFlags(createCmd.Flags(), vm.SingleProject)
provider.ConfigureProviderFlags(createCmd.Flags(), vm.SingleProject)
}
}
}
Expand All @@ -170,7 +171,8 @@ func initClusterFlagsForMultiProjects(
rootCmd *cobra.Command, excludeFromClusterFlagsMulti []*cobra.Command,
) {
for _, providerName := range vm.AllProviderNames() {
if vm.Providers[providerName].Active() {
provider := vm.Providers[providerName]
if provider.Active() {
for _, cmd := range rootCmd.Commands() {
excludeCmd := false
for _, c := range excludeFromClusterFlagsMulti {
Expand All @@ -182,7 +184,7 @@ func initClusterFlagsForMultiProjects(
if excludeCmd {
continue
}
providerOptsContainer[providerName].ConfigureClusterFlags(cmd.Flags(), vm.AcceptMultipleProjects)
provider.ConfigureProviderFlags(cmd.Flags(), vm.AcceptMultipleProjects)
}
}
}
Expand Down Expand Up @@ -387,10 +389,9 @@ func initGCCmdFlags(gcCmd *cobra.Command) {
"dry-run", "n", dryrun, "dry run (don't perform any actions)")
gcCmd.Flags().StringVar(&config.SlackToken, "slack-token", "", "Slack bot token")
// Allow each Provider to inject additional configuration flags
for _, providerName := range vm.AllProviderNames() {
if vm.Providers[providerName].Active() {
// set up cluster cleanup flag for gcCmd
providerOptsContainer[providerName].ConfigureClusterCleanupFlags(gcCmd.Flags())
for _, provider := range vm.Providers {
if provider.Active() {
provider.ConfigureClusterCleanupFlags(gcCmd.Flags())
}
}
}
Expand Down
1 change: 1 addition & 0 deletions pkg/cmd/roachtest/roachtestflags/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ go_library(
visibility = ["//visibility:public"],
deps = [
"//pkg/cmd/roachtest/spec",
"//pkg/roachprod/vm",
"//pkg/util/randutil",
"@com_github_cockroachdb_errors//:errors",
"@com_github_spf13_pflag//:pflag",
Expand Down
4 changes: 4 additions & 0 deletions pkg/cmd/roachtest/roachtestflags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"time"

"github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec"
"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
"github.com/cockroachdb/cockroach/pkg/util/randutil"
"github.com/spf13/pflag"
)
Expand Down Expand Up @@ -584,6 +585,9 @@ func AddListFlags(cmdFlags *pflag.FlagSet) {
// command flag set.
func AddRunFlags(cmdFlags *pflag.FlagSet) {
globalMan.AddFlagsToCommand(runCmdID, cmdFlags)
for _, provider := range vm.Providers {
provider.ConfigureProviderFlags(cmdFlags, vm.SingleProject)
}
}

// AddRunOpsFlags adds all flags registered for the run-operations command to
Expand Down
6 changes: 3 additions & 3 deletions pkg/roachprod/roachprod.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ func Sync(l *logger.Logger, options vm.ListOptions) (*cloud.Cloud, error) {
if !config.Quiet {
l.Printf("Refreshing DNS entries...")
}
if err := gce.SyncDNS(l, vms); err != nil {
if err := gce.Infrastructure.SyncDNS(l, vms); err != nil {
l.Errorf("failed to update DNS: %v", err)
}
} else {
Expand Down Expand Up @@ -708,7 +708,7 @@ func SetupSSH(ctx context.Context, l *logger.Logger, clusterName string, sync bo
}
// Fetch public keys from gcloud to set up ssh access for all users into the
// shared ubuntu user.
authorizedKeys, err := gce.GetUserAuthorizedKeys()
authorizedKeys, err := gce.Infrastructure.GetUserAuthorizedKeys()
if err != nil {
return errors.Wrap(err, "failed to retrieve authorized keys from gcloud")
}
Expand Down Expand Up @@ -1132,7 +1132,7 @@ func urlGenerator(
) ([]string, error) {
var urls []string
for i, node := range nodes {
host := vm.Name(c.Name, int(node)) + "." + gce.DNSDomain()
host := vm.Name(c.Name, int(node)) + "." + gce.Infrastructure.DNSDomain()

// There are no DNS entries for local clusters.
if c.IsLocal() {
Expand Down
48 changes: 22 additions & 26 deletions pkg/roachprod/vm/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ import (
// ProviderName is aws.
const ProviderName = "aws"

// providerInstance is the instance to be registered into vm.Providers by Init.
var providerInstance = &Provider{}

//go:embed config.json
var configJson []byte

Expand All @@ -57,9 +54,8 @@ func Init() error {
"(https://docs.aws.amazon.com/cli/latest/userguide/installing.html)"
const noCredentials = "missing AWS credentials, expected ~/.aws/credentials file or AWS_ACCESS_KEY_ID env var"

configVal := awsConfigValue{awsConfig: *DefaultConfig}
providerInstance.Config = &configVal.awsConfig
providerInstance.IAMProfile = "roachprod-testing"
providerInstance := &Provider{}
providerInstance.Config.awsConfig = *DefaultConfig

haveRequiredVersion := func() bool {
// `aws --version` takes around 400ms on my machine.
Expand Down Expand Up @@ -232,6 +228,7 @@ func DefaultProviderOpts() *ProviderOpts {
RemoteUserName: "ubuntu",
DefaultEBSVolume: defaultEBSVolumeValue,
CreateRateLimit: 2,
IAMProfile: "roachprod-testing",
}
}

Expand All @@ -250,6 +247,10 @@ type ProviderOpts struct {
EBSVolumes ebsVolumeList
UseMultipleDisks bool

// IAMProfile designates the name of the instance profile to use for created
// EC2 instances if non-empty.
IAMProfile string

// Use specified ImageAMI when provisioning.
// Overrides config.json AMI.
ImageAMI string
Expand All @@ -274,11 +275,7 @@ type Provider struct {
Profile string

// Path to json for aws configuration, defaults to predefined configuration
Config *awsConfig

// IAMProfile designates the name of the instance profile to use for created
// EC2 instances if non-empty.
IAMProfile string
Config awsConfigValue

// aws accounts to perform action in, used by gcCmd only as it clean ups multiple aws accounts
AccountIDs []string
Expand Down Expand Up @@ -335,7 +332,8 @@ func (p *Provider) GetPreemptedSpotVMs(
//
// Sample error message:
//
// ‹An error occurred (InvalidInstanceID.NotFound) when calling the DescribeInstances operation: The instance IDs 'i-02e9adfac0e5fa18f, i-0bc7869fda0299caa' do not exist›
// ‹An error occurred (InvalidInstanceID.NotFound) when calling the DescribeInstances operation: The instance IDs 'i-02e9adfac0e5fa18f, i-0bc7869fda0299caa'
// do not exist›
func getInstanceIDsNotFound(errorMsg string) []string {
// Regular expression pattern to find instance IDs between single quotes
re := regexp.MustCompile(`'([^']*)'`)
Expand Down Expand Up @@ -512,26 +510,24 @@ func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) {
" created. Try lowering this limit when hitting 'Request limit exceeded' errors.")
flags.BoolVar(&o.UseSpot, ProviderName+"-use-spot",
false, "use AWS Spot VMs, which are significantly cheaper, but can be preempted by AWS.")
flags.StringVar(&providerInstance.IAMProfile, ProviderName+"-iam-profile", providerInstance.IAMProfile,
flags.StringVar(&o.IAMProfile, ProviderName+"-iam-profile", o.IAMProfile,
"the IAM instance profile to associate with created VMs if non-empty")
}

// ConfigureClusterCleanupFlags implements ProviderOpts.
func (p *Provider) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) {
flags.StringSliceVar(&p.AccountIDs, ProviderName+"-account-ids", []string{},
"AWS account ids as a comma-separated string")
}

// ConfigureClusterFlags implements vm.ProviderOpts.
func (o *ProviderOpts) ConfigureClusterFlags(flags *pflag.FlagSet, _ vm.MultipleProjectsOption) {
flags.StringVar(&providerInstance.Profile, ProviderName+"-profile", os.Getenv("AWS_PROFILE"),
// ConfigureProviderFlags is part of the vm.Provider interface.
func (p *Provider) ConfigureProviderFlags(flags *pflag.FlagSet, _ vm.MultipleProjectsOption) {
flags.StringVar(&p.Profile, ProviderName+"-profile", os.Getenv("AWS_PROFILE"),
"Profile to manage cluster in")
configFlagVal := awsConfigValue{awsConfig: *DefaultConfig}
flags.Var(&configFlagVal, ProviderName+"-config",
flags.Var(&p.Config, ProviderName+"-config",
"Path to json for aws configuration, defaults to predefined configuration")
}

// ConfigureClusterCleanupFlags implements ProviderOpts.
func (o *ProviderOpts) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) {
flags.StringSliceVar(&providerInstance.AccountIDs, ProviderName+"-account-ids", []string{},
"AWS account ids as a comma-separated string")
}

// CleanSSH is part of vm.Provider. This implementation is a no-op,
// since we depend on the user's local identity file.
func (p *Provider) CleanSSH(l *logger.Logger) error {
Expand Down Expand Up @@ -1365,8 +1361,8 @@ func (p *Provider) runInstance(
args = append(args, "--cpu-options", cpuOptions)
}

if p.IAMProfile != "" {
args = append(args, "--iam-instance-profile", "Name="+p.IAMProfile)
if providerOpts.IAMProfile != "" {
args = append(args, "--iam-instance-profile", "Name="+providerOpts.IAMProfile)
}
ebsVolumes := assignEBSVolumes(&opts, providerOpts)
args, err = genDeviceMapping(ebsVolumes, args)
Expand Down
2 changes: 0 additions & 2 deletions pkg/roachprod/vm/aws/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,6 @@ func (c *awsConfigValue) Set(path string) (err error) {
if err != nil {
return err
}
// Update the provider's config with the user-specified config.
providerInstance.Config = &c.awsConfig
return nil
}

Expand Down
20 changes: 10 additions & 10 deletions pkg/roachprod/vm/azure/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,16 @@ func (p *Provider) CreateProviderOpts() vm.ProviderOpts {
return DefaultProviderOpts()
}

// ConfigureProviderFlags implements vm.ProviderFlags and is a no-op.
func (p *Provider) ConfigureProviderFlags(*pflag.FlagSet, vm.MultipleProjectsOption) {
}

// ConfigureClusterCleanupFlags is part of ProviderOpts.
func (o *Provider) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) {
flags.StringSliceVar(&providerInstance.SubscriptionNames, ProviderName+"-subscription-names", []string{},
"Azure subscription names as a comma-separated string")
}

// ConfigureCreateFlags implements vm.ProviderFlags.
func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) {
flags.DurationVar(&providerInstance.OperationTimeout, ProviderName+"-timeout", providerInstance.OperationTimeout,
Expand All @@ -81,13 +91,3 @@ func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) {
flags.StringVar(&o.DiskCaching, ProviderName+"-disk-caching", "none",
"Disk caching behavior for attached storage. Valid values are: none, read-only, read-write. Not applicable to Ultra disks.")
}

// ConfigureClusterFlags implements vm.ProviderFlags and is a no-op.
func (o *ProviderOpts) ConfigureClusterFlags(*pflag.FlagSet, vm.MultipleProjectsOption) {
}

// ConfigureClusterCleanupFlags is part of ProviderOpts.
func (o *ProviderOpts) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) {
flags.StringSliceVar(&providerInstance.SubscriptionNames, ProviderName+"-subscription-names", []string{},
"Azure subscription names as a comma-separated string")
}
1 change: 1 addition & 0 deletions pkg/roachprod/vm/flagstub/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@ go_library(
"//pkg/roachprod/logger",
"//pkg/roachprod/vm",
"@com_github_cockroachdb_errors//:errors",
"@com_github_spf13_pflag//:pflag",
],
)
9 changes: 9 additions & 0 deletions pkg/roachprod/vm/flagstub/flagstub.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/cockroachdb/cockroach/pkg/roachprod/logger"
"github.com/cockroachdb/cockroach/pkg/roachprod/vm"
"github.com/cockroachdb/errors"
"github.com/spf13/pflag"
)

// New wraps a delegate vm.Provider to only return its name and
Expand All @@ -27,6 +28,14 @@ type provider struct {
unimplemented string
}

// ConfigureProviderFlags implements vm.Provider.
func (p *provider) ConfigureProviderFlags(*pflag.FlagSet, vm.MultipleProjectsOption) {
}

func (p *provider) ConfigureClusterCleanupFlags(*pflag.FlagSet) {

}

func (p *provider) SupportsSpotVMs() bool {
return false
}
Expand Down
1 change: 1 addition & 0 deletions pkg/roachprod/vm/gce/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ go_library(
"dns.go",
"fast_dns.go",
"gcloud.go",
"infra_provider.go",
"utils.go",
],
importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm/gce",
Expand Down
Loading

0 comments on commit 38ec124

Please sign in to comment.