diff --git a/pkg/cmd/roachprod/cli/commands.go b/pkg/cmd/roachprod/cli/commands.go index 7082e5f9166b..267043fee715 100644 --- a/pkg/cmd/roachprod/cli/commands.go +++ b/pkg/cmd/roachprod/cli/commands.go @@ -1117,7 +1117,7 @@ func buildSSHKeysListCmd() *cobra.Command { Use: "list", Short: "list every SSH public key installed on clusters managed by roachprod", Run: wrap(func(cmd *cobra.Command, args []string) error { - authorizedKeys, err := gce.GetUserAuthorizedKeys() + authorizedKeys, err := gce.Infrastructure.GetUserAuthorizedKeys() if err != nil { return err } @@ -1168,7 +1168,7 @@ func buildSSHKeysRemoveCmd() *cobra.Command { Run: wrap(func(cmd *cobra.Command, args []string) error { user := args[0] - existingKeys, err := gce.GetUserAuthorizedKeys() + existingKeys, err := gce.Infrastructure.GetUserAuthorizedKeys() if err != nil { return fmt.Errorf("failed to fetch existing keys: %w", err) } diff --git a/pkg/cmd/roachprod/cli/flags.go b/pkg/cmd/roachprod/cli/flags.go index ca4b111ea74d..f3b94d8b60d8 100644 --- a/pkg/cmd/roachprod/cli/flags.go +++ b/pkg/cmd/roachprod/cli/flags.go @@ -157,11 +157,12 @@ func initCreateCmdFlags(createCmd *cobra.Command) { // Allow each Provider to inject additional configuration flags for _, providerName := range vm.AllProviderNames() { - if vm.Providers[providerName].Active() { + provider := vm.Providers[providerName] + if provider.Active() { providerOptsContainer[providerName].ConfigureCreateFlags(createCmd.Flags()) // createCmd only accepts a single GCE project, as opposed to all the other // commands. - providerOptsContainer[providerName].ConfigureClusterFlags(createCmd.Flags(), vm.SingleProject) + provider.ConfigureProviderFlags(createCmd.Flags(), vm.SingleProject) } } } @@ -170,7 +171,8 @@ func initClusterFlagsForMultiProjects( rootCmd *cobra.Command, excludeFromClusterFlagsMulti []*cobra.Command, ) { for _, providerName := range vm.AllProviderNames() { - if vm.Providers[providerName].Active() { + provider := vm.Providers[providerName] + if provider.Active() { for _, cmd := range rootCmd.Commands() { excludeCmd := false for _, c := range excludeFromClusterFlagsMulti { @@ -182,7 +184,7 @@ func initClusterFlagsForMultiProjects( if excludeCmd { continue } - providerOptsContainer[providerName].ConfigureClusterFlags(cmd.Flags(), vm.AcceptMultipleProjects) + provider.ConfigureProviderFlags(cmd.Flags(), vm.AcceptMultipleProjects) } } } @@ -387,10 +389,9 @@ func initGCCmdFlags(gcCmd *cobra.Command) { "dry-run", "n", dryrun, "dry run (don't perform any actions)") gcCmd.Flags().StringVar(&config.SlackToken, "slack-token", "", "Slack bot token") // Allow each Provider to inject additional configuration flags - for _, providerName := range vm.AllProviderNames() { - if vm.Providers[providerName].Active() { - // set up cluster cleanup flag for gcCmd - providerOptsContainer[providerName].ConfigureClusterCleanupFlags(gcCmd.Flags()) + for _, provider := range vm.Providers { + if provider.Active() { + provider.ConfigureClusterCleanupFlags(gcCmd.Flags()) } } } diff --git a/pkg/cmd/roachtest/roachtestflags/BUILD.bazel b/pkg/cmd/roachtest/roachtestflags/BUILD.bazel index d457d3f15eeb..f60b0d80b98c 100644 --- a/pkg/cmd/roachtest/roachtestflags/BUILD.bazel +++ b/pkg/cmd/roachtest/roachtestflags/BUILD.bazel @@ -10,6 +10,7 @@ go_library( visibility = ["//visibility:public"], deps = [ "//pkg/cmd/roachtest/spec", + "//pkg/roachprod/vm", "//pkg/util/randutil", "@com_github_cockroachdb_errors//:errors", "@com_github_spf13_pflag//:pflag", diff --git a/pkg/cmd/roachtest/roachtestflags/flags.go b/pkg/cmd/roachtest/roachtestflags/flags.go index e1f2f7101eb7..458892363276 100644 --- a/pkg/cmd/roachtest/roachtestflags/flags.go +++ b/pkg/cmd/roachtest/roachtestflags/flags.go @@ -11,6 +11,7 @@ import ( "time" "github.com/cockroachdb/cockroach/pkg/cmd/roachtest/spec" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/cockroach/pkg/util/randutil" "github.com/spf13/pflag" ) @@ -584,6 +585,9 @@ func AddListFlags(cmdFlags *pflag.FlagSet) { // command flag set. func AddRunFlags(cmdFlags *pflag.FlagSet) { globalMan.AddFlagsToCommand(runCmdID, cmdFlags) + for _, provider := range vm.Providers { + provider.ConfigureProviderFlags(cmdFlags, vm.SingleProject) + } } // AddRunOpsFlags adds all flags registered for the run-operations command to diff --git a/pkg/roachprod/roachprod.go b/pkg/roachprod/roachprod.go index 3badf53fb38e..45389007b312 100644 --- a/pkg/roachprod/roachprod.go +++ b/pkg/roachprod/roachprod.go @@ -335,7 +335,7 @@ func Sync(l *logger.Logger, options vm.ListOptions) (*cloud.Cloud, error) { if !config.Quiet { l.Printf("Refreshing DNS entries...") } - if err := gce.SyncDNS(l, vms); err != nil { + if err := gce.Infrastructure.SyncDNS(l, vms); err != nil { l.Errorf("failed to update DNS: %v", err) } } else { @@ -708,7 +708,7 @@ func SetupSSH(ctx context.Context, l *logger.Logger, clusterName string, sync bo } // Fetch public keys from gcloud to set up ssh access for all users into the // shared ubuntu user. - authorizedKeys, err := gce.GetUserAuthorizedKeys() + authorizedKeys, err := gce.Infrastructure.GetUserAuthorizedKeys() if err != nil { return errors.Wrap(err, "failed to retrieve authorized keys from gcloud") } @@ -1132,7 +1132,7 @@ func urlGenerator( ) ([]string, error) { var urls []string for i, node := range nodes { - host := vm.Name(c.Name, int(node)) + "." + gce.DNSDomain() + host := vm.Name(c.Name, int(node)) + "." + gce.Infrastructure.DNSDomain() // There are no DNS entries for local clusters. if c.IsLocal() { diff --git a/pkg/roachprod/vm/aws/aws.go b/pkg/roachprod/vm/aws/aws.go index c9945c3595fa..9cbfc221cf0e 100644 --- a/pkg/roachprod/vm/aws/aws.go +++ b/pkg/roachprod/vm/aws/aws.go @@ -35,9 +35,6 @@ import ( // ProviderName is aws. const ProviderName = "aws" -// providerInstance is the instance to be registered into vm.Providers by Init. -var providerInstance = &Provider{} - //go:embed config.json var configJson []byte @@ -57,9 +54,8 @@ func Init() error { "(https://docs.aws.amazon.com/cli/latest/userguide/installing.html)" const noCredentials = "missing AWS credentials, expected ~/.aws/credentials file or AWS_ACCESS_KEY_ID env var" - configVal := awsConfigValue{awsConfig: *DefaultConfig} - providerInstance.Config = &configVal.awsConfig - providerInstance.IAMProfile = "roachprod-testing" + providerInstance := &Provider{} + providerInstance.Config.awsConfig = *DefaultConfig haveRequiredVersion := func() bool { // `aws --version` takes around 400ms on my machine. @@ -232,6 +228,7 @@ func DefaultProviderOpts() *ProviderOpts { RemoteUserName: "ubuntu", DefaultEBSVolume: defaultEBSVolumeValue, CreateRateLimit: 2, + IAMProfile: "roachprod-testing", } } @@ -250,6 +247,10 @@ type ProviderOpts struct { EBSVolumes ebsVolumeList UseMultipleDisks bool + // IAMProfile designates the name of the instance profile to use for created + // EC2 instances if non-empty. + IAMProfile string + // Use specified ImageAMI when provisioning. // Overrides config.json AMI. ImageAMI string @@ -274,11 +275,7 @@ type Provider struct { Profile string // Path to json for aws configuration, defaults to predefined configuration - Config *awsConfig - - // IAMProfile designates the name of the instance profile to use for created - // EC2 instances if non-empty. - IAMProfile string + Config awsConfigValue // aws accounts to perform action in, used by gcCmd only as it clean ups multiple aws accounts AccountIDs []string @@ -335,7 +332,8 @@ func (p *Provider) GetPreemptedSpotVMs( // // Sample error message: // -// ‹An error occurred (InvalidInstanceID.NotFound) when calling the DescribeInstances operation: The instance IDs 'i-02e9adfac0e5fa18f, i-0bc7869fda0299caa' do not exist› +// ‹An error occurred (InvalidInstanceID.NotFound) when calling the DescribeInstances operation: The instance IDs 'i-02e9adfac0e5fa18f, i-0bc7869fda0299caa' +// do not exist› func getInstanceIDsNotFound(errorMsg string) []string { // Regular expression pattern to find instance IDs between single quotes re := regexp.MustCompile(`'([^']*)'`) @@ -512,26 +510,24 @@ func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) { " created. Try lowering this limit when hitting 'Request limit exceeded' errors.") flags.BoolVar(&o.UseSpot, ProviderName+"-use-spot", false, "use AWS Spot VMs, which are significantly cheaper, but can be preempted by AWS.") - flags.StringVar(&providerInstance.IAMProfile, ProviderName+"-iam-profile", providerInstance.IAMProfile, + flags.StringVar(&o.IAMProfile, ProviderName+"-iam-profile", o.IAMProfile, "the IAM instance profile to associate with created VMs if non-empty") +} +// ConfigureClusterCleanupFlags implements ProviderOpts. +func (p *Provider) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { + flags.StringSliceVar(&p.AccountIDs, ProviderName+"-account-ids", []string{}, + "AWS account ids as a comma-separated string") } -// ConfigureClusterFlags implements vm.ProviderOpts. -func (o *ProviderOpts) ConfigureClusterFlags(flags *pflag.FlagSet, _ vm.MultipleProjectsOption) { - flags.StringVar(&providerInstance.Profile, ProviderName+"-profile", os.Getenv("AWS_PROFILE"), +// ConfigureProviderFlags is part of the vm.Provider interface. +func (p *Provider) ConfigureProviderFlags(flags *pflag.FlagSet, _ vm.MultipleProjectsOption) { + flags.StringVar(&p.Profile, ProviderName+"-profile", os.Getenv("AWS_PROFILE"), "Profile to manage cluster in") - configFlagVal := awsConfigValue{awsConfig: *DefaultConfig} - flags.Var(&configFlagVal, ProviderName+"-config", + flags.Var(&p.Config, ProviderName+"-config", "Path to json for aws configuration, defaults to predefined configuration") } -// ConfigureClusterCleanupFlags implements ProviderOpts. -func (o *ProviderOpts) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { - flags.StringSliceVar(&providerInstance.AccountIDs, ProviderName+"-account-ids", []string{}, - "AWS account ids as a comma-separated string") -} - // CleanSSH is part of vm.Provider. This implementation is a no-op, // since we depend on the user's local identity file. func (p *Provider) CleanSSH(l *logger.Logger) error { @@ -1365,8 +1361,8 @@ func (p *Provider) runInstance( args = append(args, "--cpu-options", cpuOptions) } - if p.IAMProfile != "" { - args = append(args, "--iam-instance-profile", "Name="+p.IAMProfile) + if providerOpts.IAMProfile != "" { + args = append(args, "--iam-instance-profile", "Name="+providerOpts.IAMProfile) } ebsVolumes := assignEBSVolumes(&opts, providerOpts) args, err = genDeviceMapping(ebsVolumes, args) diff --git a/pkg/roachprod/vm/aws/config.go b/pkg/roachprod/vm/aws/config.go index ee470b49db40..a0da9353a766 100644 --- a/pkg/roachprod/vm/aws/config.go +++ b/pkg/roachprod/vm/aws/config.go @@ -174,8 +174,6 @@ func (c *awsConfigValue) Set(path string) (err error) { if err != nil { return err } - // Update the provider's config with the user-specified config. - providerInstance.Config = &c.awsConfig return nil } diff --git a/pkg/roachprod/vm/azure/flags.go b/pkg/roachprod/vm/azure/flags.go index 5b14e6c2b686..37322da49f5b 100644 --- a/pkg/roachprod/vm/azure/flags.go +++ b/pkg/roachprod/vm/azure/flags.go @@ -55,6 +55,16 @@ func (p *Provider) CreateProviderOpts() vm.ProviderOpts { return DefaultProviderOpts() } +// ConfigureProviderFlags implements vm.ProviderFlags and is a no-op. +func (p *Provider) ConfigureProviderFlags(*pflag.FlagSet, vm.MultipleProjectsOption) { +} + +// ConfigureClusterCleanupFlags is part of ProviderOpts. +func (o *Provider) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { + flags.StringSliceVar(&providerInstance.SubscriptionNames, ProviderName+"-subscription-names", []string{}, + "Azure subscription names as a comma-separated string") +} + // ConfigureCreateFlags implements vm.ProviderFlags. func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) { flags.DurationVar(&providerInstance.OperationTimeout, ProviderName+"-timeout", providerInstance.OperationTimeout, @@ -81,13 +91,3 @@ func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) { flags.StringVar(&o.DiskCaching, ProviderName+"-disk-caching", "none", "Disk caching behavior for attached storage. Valid values are: none, read-only, read-write. Not applicable to Ultra disks.") } - -// ConfigureClusterFlags implements vm.ProviderFlags and is a no-op. -func (o *ProviderOpts) ConfigureClusterFlags(*pflag.FlagSet, vm.MultipleProjectsOption) { -} - -// ConfigureClusterCleanupFlags is part of ProviderOpts. -func (o *ProviderOpts) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { - flags.StringSliceVar(&providerInstance.SubscriptionNames, ProviderName+"-subscription-names", []string{}, - "Azure subscription names as a comma-separated string") -} diff --git a/pkg/roachprod/vm/flagstub/BUILD.bazel b/pkg/roachprod/vm/flagstub/BUILD.bazel index b123b806cd9c..55fd2594e760 100644 --- a/pkg/roachprod/vm/flagstub/BUILD.bazel +++ b/pkg/roachprod/vm/flagstub/BUILD.bazel @@ -9,5 +9,6 @@ go_library( "//pkg/roachprod/logger", "//pkg/roachprod/vm", "@com_github_cockroachdb_errors//:errors", + "@com_github_spf13_pflag//:pflag", ], ) diff --git a/pkg/roachprod/vm/flagstub/flagstub.go b/pkg/roachprod/vm/flagstub/flagstub.go index ba2461f460db..0b9588cfeb56 100644 --- a/pkg/roachprod/vm/flagstub/flagstub.go +++ b/pkg/roachprod/vm/flagstub/flagstub.go @@ -11,6 +11,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/roachprod/logger" "github.com/cockroachdb/cockroach/pkg/roachprod/vm" "github.com/cockroachdb/errors" + "github.com/spf13/pflag" ) // New wraps a delegate vm.Provider to only return its name and @@ -27,6 +28,14 @@ type provider struct { unimplemented string } +// ConfigureProviderFlags implements vm.Provider. +func (p *provider) ConfigureProviderFlags(*pflag.FlagSet, vm.MultipleProjectsOption) { +} + +func (p *provider) ConfigureClusterCleanupFlags(*pflag.FlagSet) { + +} + func (p *provider) SupportsSpotVMs() bool { return false } diff --git a/pkg/roachprod/vm/gce/BUILD.bazel b/pkg/roachprod/vm/gce/BUILD.bazel index 96b39f2e481d..6aa0ba4ef0c2 100644 --- a/pkg/roachprod/vm/gce/BUILD.bazel +++ b/pkg/roachprod/vm/gce/BUILD.bazel @@ -6,6 +6,7 @@ go_library( "dns.go", "fast_dns.go", "gcloud.go", + "infra_provider.go", "utils.go", ], importpath = "github.com/cockroachdb/cockroach/pkg/roachprod/vm/gce", diff --git a/pkg/roachprod/vm/gce/gcloud.go b/pkg/roachprod/vm/gce/gcloud.go index 92d1e9dbb992..b51c99011b2a 100644 --- a/pkg/roachprod/vm/gce/gcloud.go +++ b/pkg/roachprod/vm/gce/gcloud.go @@ -60,9 +60,6 @@ const ( VolumeTypePersistent VolumeType = "persistent" ) -// providerInstance is the instance to be registered into vm.Providers by Init. -var providerInstance = &Provider{} - var ( defaultDefaultProject, defaultMetadataProject, defaultDNSProject, defaultDefaultServiceAccount string // projects for which a cron GC job exists. @@ -114,13 +111,13 @@ func Init() error { initGCEProjectDefaults() initDNSDefault() + providerInstance := &Provider{} providerInstance.Projects = []string{defaultDefaultProject} projectFromEnv := os.Getenv("GCE_PROJECT") if projectFromEnv != "" { fmt.Printf("WARNING: `GCE_PROJECT` is deprecated; please, use `ROACHPROD_GCE_DEFAULT_PROJECT` instead") providerInstance.Projects = []string{projectFromEnv} } - providerInstance.ServiceAccount = os.Getenv("GCE_SERVICE_ACCOUNT") if _, err := exec.LookPath("gcloud"); err != nil { vm.Providers[ProviderName] = flagstub.New(&Provider{}, "please install the gcloud CLI utilities "+ "(https://cloud.google.com/sdk/downloads)") @@ -130,10 +127,11 @@ func Init() error { providerInstance.defaultProject = defaultDefaultProject providerInstance.metadataProject = defaultMetadataProject - providerInstance.defaultServiceAccount = defaultDefaultServiceAccount initialized = true vm.Providers[ProviderName] = providerInstance + Infrastructure = providerInstance + return nil } @@ -312,6 +310,9 @@ func DefaultProviderOpts() *ProviderOpts { TerminateOnMigration: false, UseSpot: false, preemptible: false, + + defaultServiceAccount: defaultDefaultServiceAccount, + ServiceAccount: os.Getenv("GCE_SERVICE_ACCOUNT"), } } @@ -351,13 +352,18 @@ type ProviderOpts struct { TerminateOnMigration bool // use preemptible instances preemptible bool + + ServiceAccount string + + // The service account to use if the default project is in use and no + // ServiceAccount was specified. + defaultServiceAccount string } // Provider is the GCE implementation of the vm.Provider interface. type Provider struct { *dnsProvider - Projects []string - ServiceAccount string + Projects []string // The project to use for looking up metadata. In particular, this includes // user keys. @@ -365,10 +371,6 @@ type Provider struct { // The project that provides the core roachprod services. defaultProject string - - // The service account to use if the default project is in use and no - // ServiceAccount was specified. - defaultServiceAccount string } // LogEntry represents a single log entry from the gcloud logging(stack driver) @@ -981,6 +983,7 @@ func (p *Provider) AttachVolume(l *logger.Logger, volume vm.Volume, vm *vm.VM) ( // (Provider.Projects). type ProjectsVal struct { AcceptMultipleProjects bool + Provider *Provider } // DefaultZones is the list of zones used by default for cluster creation. @@ -1023,7 +1026,7 @@ func (v ProjectsVal) Set(projects string) error { if !v.AcceptMultipleProjects && len(prj) > 1 { return fmt.Errorf("multiple GCE projects not supported for command") } - providerInstance.Projects = prj + v.Provider.Projects = prj return nil } @@ -1037,7 +1040,7 @@ func (v ProjectsVal) Type() string { // String is part of the pflag.Value interface. func (v ProjectsVal) String() string { - return strings.Join(providerInstance.Projects, ",") + return strings.Join(v.Provider.Projects, ",") } // GetProject returns the GCE project on which we're configured to operate. @@ -1065,9 +1068,9 @@ func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) { flags.StringSliceVar(&o.ManagedSpotZones, ProviderName+"-managed-spot-zones", nil, "subset of zones in managed instance groups that will use spot instances") - flags.StringVar(&providerInstance.ServiceAccount, ProviderName+"-service-account", - providerInstance.ServiceAccount, "Service account to use") - flags.StringVar(&providerInstance.defaultServiceAccount, + flags.StringVar(&o.ServiceAccount, ProviderName+"-service-account", + o.ServiceAccount, "Service account to use") + flags.StringVar(&o.defaultServiceAccount, ProviderName+"-default-service-account", defaultDefaultServiceAccount, "Service account to use if the default project is in use and no "+ "--gce-service-account was specified") @@ -1110,8 +1113,8 @@ func (o *ProviderOpts) ConfigureCreateFlags(flags *pflag.FlagSet) { false, "Enables the cron service (it is disabled by default)") } -// ConfigureClusterFlags implements vm.ProviderFlags. -func (o *ProviderOpts) ConfigureClusterFlags(flags *pflag.FlagSet, opt vm.MultipleProjectsOption) { +// ConfigureProviderFlags implements Provider +func (p *Provider) ConfigureProviderFlags(flags *pflag.FlagSet, opt vm.MultipleProjectsOption) { var usage string if opt == vm.SingleProject { usage = "GCE project to manage" @@ -1122,14 +1125,14 @@ func (o *ProviderOpts) ConfigureClusterFlags(flags *pflag.FlagSet, opt vm.Multip flags.Var( ProjectsVal{ AcceptMultipleProjects: opt == vm.AcceptMultipleProjects, + Provider: p, }, ProviderName+"-project", /* name */ usage) // Flags about DNS override the default values in - // providerInstance.dnsProvider. - - dnsProviderInstance := providerInstance.dnsProvider + // dnsProvider. + dnsProviderInstance := p.dnsProvider flags.StringVar( &dnsProviderInstance.dnsProject, ProviderName+"-dns-project", dnsProviderInstance.dnsProject, @@ -1161,16 +1164,15 @@ func (o *ProviderOpts) ConfigureClusterFlags(flags *pflag.FlagSet, opt vm.Multip ) // Flags about the GCE project to use override the defaults in - // providerInstance. - + // the provider. flags.StringVar( - &providerInstance.metadataProject, ProviderName+"-metadata-project", - providerInstance.metadataProject, + &p.metadataProject, ProviderName+"-metadata-project", + p.metadataProject, "google cloud project to use to store and fetch SSH keys", ) flags.StringVar( - &providerInstance.defaultProject, ProviderName+"-default-project", - providerInstance.defaultProject, + &p.defaultProject, ProviderName+"-default-project", + p.defaultProject, "google cloud project to use to run core roachprod services", ) } @@ -1181,7 +1183,7 @@ func (o *ProviderOpts) useArmAMI() bool { } // ConfigureClusterCleanupFlags is part of ProviderOpts. This implementation is a no-op. -func (o *ProviderOpts) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { +func (p *Provider) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { } // CleanSSH TODO(peter): document @@ -1372,11 +1374,11 @@ func (p *Provider) computeInstanceArgs( "--boot-disk-type", "pd-ssd", } - if project == p.defaultProject && p.ServiceAccount == "" { - p.ServiceAccount = p.defaultServiceAccount + if project == p.defaultProject && providerOpts.ServiceAccount == "" { + providerOpts.ServiceAccount = providerOpts.defaultServiceAccount } - if p.ServiceAccount != "" { - args = append(args, "--service-account", p.ServiceAccount) + if providerOpts.ServiceAccount != "" { + args = append(args, "--service-account", providerOpts.ServiceAccount) } if providerOpts.preemptible { diff --git a/pkg/roachprod/vm/gce/infra_provider.go b/pkg/roachprod/vm/gce/infra_provider.go new file mode 100644 index 000000000000..d8e2244c3069 --- /dev/null +++ b/pkg/roachprod/vm/gce/infra_provider.go @@ -0,0 +1,27 @@ +// Copyright 2025 The Cockroach Authors. +// +// Use of this software is governed by the CockroachDB Software License +// included in the /LICENSE file. + +package gce + +import ( + "github.com/cockroachdb/cockroach/pkg/roachprod/logger" + "github.com/cockroachdb/cockroach/pkg/roachprod/vm" +) + +// InfraProvider is the API for GCP resources that are used as shared +// infrastructure by clusters on other other clouds. +type InfraProvider interface { + // GetUserAuthorizedKeys retrieves reads a list of user public keys from the + // gcloud cockroach-ephemeral project and returns them formatted for use in + // an authorized_keys file. + GetUserAuthorizedKeys() (AuthorizedKeys, error) + // SyncDNS replaces the configured DNS zone with the supplied hosts. + SyncDNS(l *logger.Logger, vms vm.List) error + // DNSDomain returns the configured DNS domain for public DNS A records. + DNSDomain() string +} + +// Infrastructure is the process level InfraProvider. +var Infrastructure InfraProvider diff --git a/pkg/roachprod/vm/gce/utils.go b/pkg/roachprod/vm/gce/utils.go index 2ce7ed19cc1e..50d493fa9232 100644 --- a/pkg/roachprod/vm/gce/utils.go +++ b/pkg/roachprod/vm/gce/utils.go @@ -348,14 +348,14 @@ func writeStartupScript( return tmpfile.Name(), nil } -// SyncDNS replaces the configured DNS zone with the supplied hosts. -func SyncDNS(l *logger.Logger, vms vm.List) error { - return providerInstance.dnsProvider.syncPublicDNS(l, vms) +// SyncDNS implements the InfraProvider interface. +func (p *Provider) SyncDNS(l *logger.Logger, vms vm.List) error { + return p.dnsProvider.syncPublicDNS(l, vms) } -// DNSDomain returns the configured DNS domain for public DNS A records. -func DNSDomain() string { - return providerInstance.dnsProvider.publicDomain +// DNSDomain implements the InfraProvider interface. +func (p *Provider) DNSDomain() string { + return p.dnsProvider.publicDomain } type AuthorizedKey struct { @@ -417,14 +417,12 @@ func (ak AuthorizedKeys) AsProjectMetadata() []byte { return buf.Bytes() } -// GetUserAuthorizedKeys retrieves reads a list of user public keys from the -// gcloud cockroach-ephemeral project and returns them formatted for use in -// an authorized_keys file. -func GetUserAuthorizedKeys() (AuthorizedKeys, error) { +// GetUserAuthorizedKeys implements the InfraProvider interface. +func (p *Provider) GetUserAuthorizedKeys() (AuthorizedKeys, error) { var outBuf bytes.Buffer // The below command will return a stream of user:pubkey as text. cmd := exec.Command("gcloud", "compute", "project-info", "describe", - "--project="+providerInstance.metadataProject, + "--project="+p.metadataProject, "--format=value(commonInstanceMetadata.ssh-keys)") cmd.Stderr = os.Stderr cmd.Stdout = &outBuf @@ -482,7 +480,7 @@ func GetUserAuthorizedKeys() (AuthorizedKeys, error) { // keys are stored in the project metadata for the roachprod's // `DefaultProject`. func AddUserAuthorizedKey(ak AuthorizedKey) error { - existingKeys, err := GetUserAuthorizedKeys() + existingKeys, err := Infrastructure.GetUserAuthorizedKeys() if err != nil { return err } diff --git a/pkg/roachprod/vm/local/local.go b/pkg/roachprod/vm/local/local.go index 61b677d276b3..3647425264b4 100644 --- a/pkg/roachprod/vm/local/local.go +++ b/pkg/roachprod/vm/local/local.go @@ -120,6 +120,9 @@ type Provider struct { vm.DNSProvider } +func (p *Provider) ConfigureProviderFlags(*pflag.FlagSet, vm.MultipleProjectsOption) { +} + func (p *Provider) SupportsSpotVMs() bool { return false } @@ -181,12 +184,8 @@ type providerOpts struct{} func (o *providerOpts) ConfigureCreateFlags(flags *pflag.FlagSet) { } -// ConfigureClusterFlags is part of ProviderOpts. This implementation is a no-op. -func (o *providerOpts) ConfigureClusterFlags(*pflag.FlagSet, vm.MultipleProjectsOption) { -} - -// ConfigureClusterCleanupFlags is part of ProviderOpts. This implementation is a no-op. -func (o *providerOpts) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { +// ConfigureClusterCleanupFlags is part of the vm.Provider interface. This implementation is a no-op. +func (p *Provider) ConfigureClusterCleanupFlags(flags *pflag.FlagSet) { } // CleanSSH is part of the vm.Provider interface. This implementation is a no-op. diff --git a/pkg/roachprod/vm/vm.go b/pkg/roachprod/vm/vm.go index 3dfd77ebe658..0f4a2b84eb93 100644 --- a/pkg/roachprod/vm/vm.go +++ b/pkg/roachprod/vm/vm.go @@ -353,13 +353,6 @@ type ProviderOpts interface { // ConfigureCreateFlags configures a FlagSet with any options relevant to the // `create` command. ConfigureCreateFlags(*pflag.FlagSet) - // ConfigureClusterFlags configures a FlagSet with any options relevant to - // cluster manipulation commands (`create`, `destroy`, `list`, `sync` and - // `gc`). - ConfigureClusterFlags(*pflag.FlagSet, MultipleProjectsOption) - // ConfigureClusterCleanupFlags configures a FlagSet with any options relevant to - // commands (`gc`) - ConfigureClusterCleanupFlags(*pflag.FlagSet) } // VolumeSnapshot is an abstract representation of a specific volume snapshot. @@ -468,6 +461,14 @@ type ServiceAddress struct { // A Provider is a source of virtual machines running on some hosting platform. type Provider interface { + // ConfigureProviderFlags is used to specify flags that apply to the provider + // instance and should be used for all clusters managed by the provider. + ConfigureProviderFlags(*pflag.FlagSet, MultipleProjectsOption) + + // ConfigureClusterCleanupFlags configures a FlagSet with any options + // relevant to commands (`gc`) + ConfigureClusterCleanupFlags(*pflag.FlagSet) + CreateProviderOpts() ProviderOpts CleanSSH(l *logger.Logger) error diff --git a/pkg/sql/logictest/testdata/logic_test/pg_builtins b/pkg/sql/logictest/testdata/logic_test/pg_builtins index abe69b1e3225..75a31c5155a7 100644 --- a/pkg/sql/logictest/testdata/logic_test/pg_builtins +++ b/pkg/sql/logictest/testdata/logic_test/pg_builtins @@ -709,6 +709,16 @@ SELECT to_regclass('4294967230') ---- NULL +query T +SELECT to_regclass('0 ') +---- +NULL + +query T +SELECT to_regclass(' -123 ') +---- +NULL + query T SELECT to_regclass('pg_policy') ---- @@ -739,6 +749,16 @@ SELECT to_regnamespace(' 1330834471') ---- NULL +query T +SELECT to_regnamespace('0 ') +---- +NULL + +query T +SELECT to_regnamespace('-1234 ') +---- +NULL + query T SELECT to_regproc('_st_contains') ---- @@ -779,6 +799,16 @@ SELECT to_regprocedure('961893967') ---- NULL +query T +SELECT to_regprocedure('0') +---- +NULL + +query T +SELECT to_regprocedure('-2') +---- +NULL + query T SELECT to_regrole('admin') ---- @@ -799,6 +829,16 @@ SELECT to_regrole('1546506610') ---- NULL +query T +SELECT to_regrole('0') +---- +NULL + +query T +SELECT to_regrole('-2') +---- +NULL + query T SELECT to_regtype('interval') ---- @@ -824,6 +864,16 @@ SELECT to_regtype('1186') ---- NULL +query T +SELECT to_regtype('0') +---- +NULL + +query T +SELECT to_regtype('-3') +---- +NULL + query T SELECT to_regtype('test_type') ---- diff --git a/pkg/sql/sem/builtins/pg_builtins.go b/pkg/sql/sem/builtins/pg_builtins.go index d327a0c770e2..1354b23115ed 100644 --- a/pkg/sql/sem/builtins/pg_builtins.go +++ b/pkg/sql/sem/builtins/pg_builtins.go @@ -573,8 +573,9 @@ func makeToRegOverload(typ *types.T, helpText string) builtinDefinition { ReturnType: tree.FixedReturnType(types.RegType), Fn: func(ctx context.Context, evalCtx *eval.Context, args tree.Datums) (tree.Datum, error) { typName := tree.MustBeDString(args[0]) - int, _ := strconv.Atoi(strings.TrimSpace(string(typName))) - if int > 0 { + _, err := strconv.Atoi(strings.TrimSpace(string(typName))) + if err == nil { + // If a number was passed in, return NULL. return tree.DNull, nil } typOid, err := eval.ParseDOid(ctx, evalCtx, string(typName), typ)