From 910229a9be126cf40c239af683b170e29e242e04 Mon Sep 17 00:00:00 2001 From: Laimonas Rastenis Date: Tue, 26 Nov 2024 12:30:13 +0200 Subject: [PATCH] feat: anywhere provider --- cmd/dump/run.go | 2 +- internal/castai/types.go | 6 + internal/config/config.go | 5 + .../services/controller/mock/workqueue.go | 125 ++++++++++++++++++ internal/services/discovery/discovery.go | 64 +++++++-- internal/services/discovery/discovery_test.go | 4 +- internal/services/discovery/mock/discovery.go | 45 ++++--- .../services/providers/anywhere/anywhere.go | 79 +++++++++++ .../providers/anywhere/anywhere_test.go | 116 ++++++++++++++++ .../providers/anywhere/client/client.go | 65 +++++++++ .../providers/anywhere/client/client_test.go | 121 +++++++++++++++++ .../providers/anywhere/client/mock/client.go | 50 +++++++ internal/services/providers/kops/kops.go | 2 +- .../services/providers/openshift/openshift.go | 2 +- .../providers/openshift/openshift_test.go | 4 +- internal/services/providers/providers.go | 9 ++ internal/services/providers/providers_test.go | 17 +++ 17 files changed, 682 insertions(+), 34 deletions(-) create mode 100644 internal/services/controller/mock/workqueue.go create mode 100644 internal/services/providers/anywhere/anywhere.go create mode 100644 internal/services/providers/anywhere/anywhere_test.go create mode 100644 internal/services/providers/anywhere/client/client.go create mode 100644 internal/services/providers/anywhere/client/client_test.go create mode 100644 internal/services/providers/anywhere/client/mock/client.go diff --git a/cmd/dump/run.go b/cmd/dump/run.go index e485f7ef..7594a09f 100644 --- a/cmd/dump/run.go +++ b/cmd/dump/run.go @@ -52,7 +52,7 @@ func run(ctx context.Context) error { if cfg.Static != nil && cfg.Static.ClusterID != "" { clusterID = cfg.Static.ClusterID } else { - c, err := discoveryService.GetClusterID(ctx) + c, err := discoveryService.GetKubeSystemNamespaceID(ctx) if err != nil { return fmt.Errorf("getting cluster ID: %w", err) } diff --git a/internal/castai/types.go b/internal/castai/types.go index f79bffb0..48e35228 100644 --- a/internal/castai/types.go +++ b/internal/castai/types.go @@ -41,6 +41,11 @@ type OpenshiftParams struct { InternalID string `json:"internalId"` } +type AnywhereParams struct { + ClusterName string `json:"clusterName"` + KubeSystemNamespaceID uuid.UUID `json:"kubeSystemNamespaceId"` +} + type RegisterClusterRequest struct { ID uuid.UUID `json:"id"` Name string `json:"name"` @@ -49,6 +54,7 @@ type RegisterClusterRequest struct { KOPS *KOPSParams `json:"kops"` AKS *AKSParams `json:"aks"` Openshift *OpenshiftParams `json:"openshift"` + Anywhere *AnywhereParams `json:"anywhere"` } type Cluster struct { diff --git a/internal/config/config.go b/internal/config/config.go index 1ec8133a..1347a2e4 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -27,6 +27,7 @@ type Config struct { KOPS *KOPS `mapstructure:"kops"` AKS *AKS `mapstructure:"aks"` OpenShift *OpenShift `mapstructure:"openshift"` + Anywhere *Anywhere `mapstructure:"anywhere"` Static *Static `mapstructure:"static"` Controller *Controller `mapstructure:"controller"` @@ -109,6 +110,10 @@ type Static struct { ClusterID string `mapstructure:"cluster_id"` } +type Anywhere struct { + ClusterName string `mapstructure:"cluster_name"` +} + type Controller struct { Interval time.Duration `mapstructure:"interval"` MemoryPressureInterval time.Duration `mapstructure:"memory_pressure_interval"` diff --git a/internal/services/controller/mock/workqueue.go b/internal/services/controller/mock/workqueue.go new file mode 100644 index 00000000..849c2540 --- /dev/null +++ b/internal/services/controller/mock/workqueue.go @@ -0,0 +1,125 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: k8s.io/client-go/util/workqueue (interfaces: Interface) + +// Package mock_workqueue is a generated GoMock package. +package mock_workqueue + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockInterface is a mock of Interface interface. +type MockInterface struct { + ctrl *gomock.Controller + recorder *MockInterfaceMockRecorder +} + +// MockInterfaceMockRecorder is the mock recorder for MockInterface. +type MockInterfaceMockRecorder struct { + mock *MockInterface +} + +// NewMockInterface creates a new mock instance. +func NewMockInterface(ctrl *gomock.Controller) *MockInterface { + mock := &MockInterface{ctrl: ctrl} + mock.recorder = &MockInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockInterface) Add(arg0 interface{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Add", arg0) +} + +// Add indicates an expected call of Add. +func (mr *MockInterfaceMockRecorder) Add(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockInterface)(nil).Add), arg0) +} + +// Done mocks base method. +func (m *MockInterface) Done(arg0 interface{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Done", arg0) +} + +// Done indicates an expected call of Done. +func (mr *MockInterfaceMockRecorder) Done(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Done", reflect.TypeOf((*MockInterface)(nil).Done), arg0) +} + +// Get mocks base method. +func (m *MockInterface) Get() (interface{}, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get") + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockInterfaceMockRecorder) Get() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockInterface)(nil).Get)) +} + +// Len mocks base method. +func (m *MockInterface) Len() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Len") + ret0, _ := ret[0].(int) + return ret0 +} + +// Len indicates an expected call of Len. +func (mr *MockInterfaceMockRecorder) Len() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockInterface)(nil).Len)) +} + +// ShutDown mocks base method. +func (m *MockInterface) ShutDown() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ShutDown") +} + +// ShutDown indicates an expected call of ShutDown. +func (mr *MockInterfaceMockRecorder) ShutDown() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShutDown", reflect.TypeOf((*MockInterface)(nil).ShutDown)) +} + +// ShutDownWithDrain mocks base method. +func (m *MockInterface) ShutDownWithDrain() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ShutDownWithDrain") +} + +// ShutDownWithDrain indicates an expected call of ShutDownWithDrain. +func (mr *MockInterfaceMockRecorder) ShutDownWithDrain() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShutDownWithDrain", reflect.TypeOf((*MockInterface)(nil).ShutDownWithDrain)) +} + +// ShuttingDown mocks base method. +func (m *MockInterface) ShuttingDown() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ShuttingDown") + ret0, _ := ret[0].(bool) + return ret0 +} + +// ShuttingDown indicates an expected call of ShuttingDown. +func (mr *MockInterfaceMockRecorder) ShuttingDown() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ShuttingDown", reflect.TypeOf((*MockInterface)(nil).ShuttingDown)) +} diff --git a/internal/services/discovery/discovery.go b/internal/services/discovery/discovery.go index c24af7ec..c915e2e1 100644 --- a/internal/services/discovery/discovery.go +++ b/internal/services/discovery/discovery.go @@ -18,13 +18,17 @@ import ( ) type Service interface { + // GetCSP discovers the cluster cloud service provider (CSP) by listing the cluster nodes and inspecting their labels. + // CSP is retrieved by parsing the Node.Spec.ProviderID property. + GetCSP(ctx context.Context) (csp cloud.Cloud, reterr error) + // GetCSPAndRegion discovers the cluster cloud service provider (CSP) and the region the cluster is deployed in by // listing the cluster nodes and inspecting their labels. CSP is retrieved by parsing the Node.Spec.ProviderID property. // Whereas the region is read from the well-known node region labels. GetCSPAndRegion(ctx context.Context) (csp cloud.Cloud, region string, reterr error) - // GetClusterID retrieves the cluster ID by reading the UID of the kube-system namespace. - GetClusterID(ctx context.Context) (*uuid.UUID, error) + // GetKubeSystemNamespaceID retrieves the UID of the kube-system namespace. + GetKubeSystemNamespaceID(ctx context.Context) (*uuid.UUID, error) // GetKOPSClusterNameAndStateStore discovers the cluster name and kOps state store bucket from the kube-system namespace // annotation. kOps annotates the kube-system namespace with annotations such as this: @@ -57,22 +61,26 @@ func New(clientset kubernetes.Interface, dyno dynamic.Interface) *ServiceImpl { } } +func (s *ServiceImpl) GetCSP(ctx context.Context) (cloud.Cloud, error) { + return s.getCSP(ctx, "") +} + func (s *ServiceImpl) GetCSPAndRegion(ctx context.Context) (csp cloud.Cloud, region string, reterr error) { return s.getCSPAndRegion(ctx, "") } -func (s *ServiceImpl) GetClusterID(ctx context.Context) (*uuid.UUID, error) { +func (s *ServiceImpl) GetKubeSystemNamespaceID(ctx context.Context) (*uuid.UUID, error) { ns, err := s.getKubeSystemNamespace(ctx) if err != nil { return nil, err } - clusterID, err := uuid.Parse(string(ns.UID)) + namespaceID, err := uuid.Parse(string(ns.UID)) if err != nil { return nil, fmt.Errorf("parsing namespace %q uid: %w", metav1.NamespaceSystem, err) } - return &clusterID, nil + return &namespaceID, nil } func (s *ServiceImpl) getKubeSystemNamespace(ctx context.Context) (*v1.Namespace, error) { @@ -93,10 +101,35 @@ func (s *ServiceImpl) getKubeSystemNamespace(ctx context.Context) (*v1.Namespace return ns, nil } +func (s *ServiceImpl) getCSP(ctx context.Context, next string) (cloud.Cloud, error) { + nodes, err := s.listNodes(ctx, next) + if err != nil { + return "", err + } + + for i := range nodes.Items { + node := &nodes.Items[i] + + if !isNodeReady(node) { + continue + } + + if nodeCSP, ok := getNodeCSP(node); ok { + return nodeCSP, nil + } + } + + if nodes.Continue != "" { + return s.getCSP(ctx, nodes.Continue) + } + + return "", fmt.Errorf("failed to discover csp") +} + func (s *ServiceImpl) getCSPAndRegion(ctx context.Context, next string) (csp cloud.Cloud, region string, reterr error) { - nodes, err := s.clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{Limit: 10, Continue: next}) + nodes, err := s.listNodes(ctx, next) if err != nil { - return "", "", fmt.Errorf("listing nodes: %w", err) + return "", "", err } for i := range nodes.Items { @@ -106,13 +139,11 @@ func (s *ServiceImpl) getCSPAndRegion(ctx context.Context, next string) (csp clo continue } - nodeCSP, ok := getCSP(node) - if ok { + if nodeCSP, ok := getNodeCSP(node); ok { csp = nodeCSP } - nodeRegion, ok := getRegion(node) - if ok { + if nodeRegion, ok := getRegion(node); ok { region = nodeRegion } @@ -128,6 +159,15 @@ func (s *ServiceImpl) getCSPAndRegion(ctx context.Context, next string) (csp clo return "", "", fmt.Errorf("failed discovering properties: csp=%q, region=%q", csp, region) } +func (s *ServiceImpl) listNodes(ctx context.Context, next string) (*v1.NodeList, error) { + nodes, err := s.clientset.CoreV1().Nodes().List(ctx, metav1.ListOptions{Limit: 10, Continue: next}) + if err != nil { + return nil, fmt.Errorf("listing nodes: %w", err) + } + + return nodes, nil +} + func isNodeReady(n *v1.Node) bool { for _, cond := range n.Status.Conditions { if cond.Type == v1.NodeReady && cond.Status == v1.ConditionTrue { @@ -150,7 +190,7 @@ func getRegion(n *v1.Node) (string, bool) { return "", false } -func getCSP(n *v1.Node) (cloud.Cloud, bool) { +func getNodeCSP(n *v1.Node) (cloud.Cloud, bool) { providerID := n.Spec.ProviderID if strings.HasPrefix(providerID, "gce://") { diff --git a/internal/services/discovery/discovery_test.go b/internal/services/discovery/discovery_test.go index 6bcbffd5..e1183dca 100644 --- a/internal/services/discovery/discovery_test.go +++ b/internal/services/discovery/discovery_test.go @@ -17,7 +17,7 @@ import ( "castai-agent/pkg/cloud" ) -func TestServiceImpl_GetClusterID(t *testing.T) { +func TestServiceImpl_GetKubeSystemNamespaceID(t *testing.T) { namespaceID := uuid.New() objects := []runtime.Object{ &v1.Namespace{ @@ -33,7 +33,7 @@ func TestServiceImpl_GetClusterID(t *testing.T) { s := New(clientset, dyno) - id, err := s.GetClusterID(context.Background()) + id, err := s.GetKubeSystemNamespaceID(context.Background()) require.NoError(t, err) require.Equal(t, namespaceID, *id) diff --git a/internal/services/discovery/mock/discovery.go b/internal/services/discovery/mock/discovery.go index 15d56983..1f408900 100644 --- a/internal/services/discovery/mock/discovery.go +++ b/internal/services/discovery/mock/discovery.go @@ -37,6 +37,21 @@ func (m *MockService) EXPECT() *MockServiceMockRecorder { return m.recorder } +// GetCSP mocks base method. +func (m *MockService) GetCSP(ctx context.Context) (cloud.Cloud, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetCSP", ctx) + ret0, _ := ret[0].(cloud.Cloud) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetCSP indicates an expected call of GetCSP. +func (mr *MockServiceMockRecorder) GetCSP(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCSP", reflect.TypeOf((*MockService)(nil).GetCSP), ctx) +} + // GetCSPAndRegion mocks base method. func (m *MockService) GetCSPAndRegion(ctx context.Context) (cloud.Cloud, string, error) { m.ctrl.T.Helper() @@ -53,21 +68,6 @@ func (mr *MockServiceMockRecorder) GetCSPAndRegion(ctx interface{}) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCSPAndRegion", reflect.TypeOf((*MockService)(nil).GetCSPAndRegion), ctx) } -// GetClusterID mocks base method. -func (m *MockService) GetClusterID(ctx context.Context) (*uuid.UUID, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetClusterID", ctx) - ret0, _ := ret[0].(*uuid.UUID) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetClusterID indicates an expected call of GetClusterID. -func (mr *MockServiceMockRecorder) GetClusterID(ctx interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterID", reflect.TypeOf((*MockService)(nil).GetClusterID), ctx) -} - // GetKOPSClusterNameAndStateStore mocks base method. func (m *MockService) GetKOPSClusterNameAndStateStore(ctx context.Context, log logrus.FieldLogger) (string, string, error) { m.ctrl.T.Helper() @@ -84,6 +84,21 @@ func (mr *MockServiceMockRecorder) GetKOPSClusterNameAndStateStore(ctx, log inte return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKOPSClusterNameAndStateStore", reflect.TypeOf((*MockService)(nil).GetKOPSClusterNameAndStateStore), ctx, log) } +// GetKubeSystemNamespaceID mocks base method. +func (m *MockService) GetKubeSystemNamespaceID(ctx context.Context) (*uuid.UUID, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetKubeSystemNamespaceID", ctx) + ret0, _ := ret[0].(*uuid.UUID) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetKubeSystemNamespaceID indicates an expected call of GetKubeSystemNamespaceID. +func (mr *MockServiceMockRecorder) GetKubeSystemNamespaceID(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetKubeSystemNamespaceID", reflect.TypeOf((*MockService)(nil).GetKubeSystemNamespaceID), ctx) +} + // GetOpenshiftClusterID mocks base method. func (m *MockService) GetOpenshiftClusterID(ctx context.Context) (string, error) { m.ctrl.T.Helper() diff --git a/internal/services/providers/anywhere/anywhere.go b/internal/services/providers/anywhere/anywhere.go new file mode 100644 index 00000000..ad6cd58a --- /dev/null +++ b/internal/services/providers/anywhere/anywhere.go @@ -0,0 +1,79 @@ +package anywhere + +import ( + "context" + "fmt" + + "github.com/sirupsen/logrus" + v1 "k8s.io/api/core/v1" + + "castai-agent/internal/castai" + "castai-agent/internal/config" + "castai-agent/internal/services/discovery" + "castai-agent/internal/services/providers/anywhere/client" + "castai-agent/internal/services/providers/types" +) + +const Name = "anywhere" + +var _ types.Provider = (*Provider)(nil) + +type Provider struct { + discoveryService discovery.Service + client client.Client + log logrus.FieldLogger +} + +func New(discoveryService discovery.Service, client client.Client, log logrus.FieldLogger) *Provider { + return &Provider{ + discoveryService: discoveryService, + client: client, + log: log, + } +} + +func (p *Provider) RegisterCluster(ctx context.Context, client castai.Client) (*types.ClusterRegistration, error) { + kubeSystemNamespaceId, err := p.discoveryService.GetKubeSystemNamespaceID(ctx) + if err != nil { + return nil, fmt.Errorf("getting kube-system namespace id: %w", err) + } + + clusterName := "" + if cfg := config.Get().Anywhere; cfg != nil { + clusterName = cfg.ClusterName + } + + if clusterName == "" { + discoveredClusterName, err := p.client.GetClusterName(ctx) + if err != nil { + p.log.Errorf("discovering cluster name: %v", err) + } else if discoveredClusterName != "" { + clusterName = discoveredClusterName + } + } + + req := &castai.RegisterClusterRequest{ + Name: clusterName, + Anywhere: &castai.AnywhereParams{ + ClusterName: clusterName, + KubeSystemNamespaceID: *kubeSystemNamespaceId, + }, + } + resp, err := client.RegisterCluster(ctx, req) + if err != nil { + return nil, err + } + + return &types.ClusterRegistration{ + ClusterID: resp.ID, + OrganizationID: resp.OrganizationID, + }, nil +} + +func (p *Provider) Name() string { + return Name +} + +func (p *Provider) FilterSpot(_ context.Context, _ []*v1.Node) ([]*v1.Node, error) { + return nil, nil +} diff --git a/internal/services/providers/anywhere/anywhere_test.go b/internal/services/providers/anywhere/anywhere_test.go new file mode 100644 index 00000000..67cc0de0 --- /dev/null +++ b/internal/services/providers/anywhere/anywhere_test.go @@ -0,0 +1,116 @@ +package anywhere + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/uuid" + "github.com/samber/lo" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "castai-agent/internal/castai" + mock_castai "castai-agent/internal/castai/mock" + "castai-agent/internal/config" + discovery_mock "castai-agent/internal/services/discovery/mock" + anywhere_client_mock "castai-agent/internal/services/providers/anywhere/client/mock" +) + +func Test_RegisterCluster(t *testing.T) { + clusterID := uuid.New() + orgID := uuid.New() + kubeNsId := uuid.New() + + tests := map[string]struct { + expectedClusterName string + expectedKubeNamespaceId uuid.UUID + expectedErr *string + setup func(*require.Assertions, *gomock.Controller) (*discovery_mock.MockService, *anywhere_client_mock.MockClient) + }{ + "should fail cluster registration when there's an error retrieving kube-system namespace id": { + setup: func(_ *require.Assertions, ctrl *gomock.Controller) (*discovery_mock.MockService, *anywhere_client_mock.MockClient) { + discoveryService := discovery_mock.NewMockService(ctrl) + anywhereClient := anywhere_client_mock.NewMockClient(ctrl) + + discoveryService.EXPECT().GetKubeSystemNamespaceID(gomock.Any()).Return(nil, fmt.Errorf("some error")).Times(1) + + return discoveryService, anywhereClient + }, + expectedErr: lo.ToPtr("getting kube-system namespace id"), + }, + "should use cluster name from config when it's provided through env variables": { + setup: func(r *require.Assertions, ctrl *gomock.Controller) (*discovery_mock.MockService, *anywhere_client_mock.MockClient) { + r.NoError(os.Setenv("ANYWHERE_CLUSTER_NAME", "env-cluster-name")) + + discoveryService := discovery_mock.NewMockService(ctrl) + anywhereClient := anywhere_client_mock.NewMockClient(ctrl) + + discoveryService.EXPECT().GetKubeSystemNamespaceID(gomock.Any()).Return(&kubeNsId, nil).Times(1) + anywhereClient.EXPECT().GetClusterName(gomock.Any()).Times(0) + + return discoveryService, anywhereClient + }, + expectedClusterName: "env-cluster-name", + expectedKubeNamespaceId: kubeNsId, + }, + "should use cluster name from client when it's provided through env variables": { + setup: func(_ *require.Assertions, ctrl *gomock.Controller) (*discovery_mock.MockService, *anywhere_client_mock.MockClient) { + discoveryService := discovery_mock.NewMockService(ctrl) + anywhereClient := anywhere_client_mock.NewMockClient(ctrl) + + discoveryService.EXPECT().GetKubeSystemNamespaceID(gomock.Any()).Return(&kubeNsId, nil).Times(1) + anywhereClient.EXPECT().GetClusterName(gomock.Any()).Return("client-cluster-name", nil).Times(1) + + return discoveryService, anywhereClient + }, + expectedClusterName: "client-cluster-name", + expectedKubeNamespaceId: kubeNsId, + }, + } + + for name, test := range tests { + test := test + t.Run(name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + t.Cleanup(config.Reset) + t.Cleanup(os.Clearenv) + + ctx := context.Background() + r := require.New(t) + + discoveryService, anywhereClient := test.setup(r, ctrl) + castaiClient := mock_castai.NewMockClient(ctrl) + + provider := New(discoveryService, anywhereClient, logrus.New()) + + if test.expectedErr == nil { + castaiClient.EXPECT().RegisterCluster(gomock.Any(), &castai.RegisterClusterRequest{ + Name: test.expectedClusterName, + Anywhere: &castai.AnywhereParams{ + ClusterName: test.expectedClusterName, + KubeSystemNamespaceID: test.expectedKubeNamespaceId, + }, + }).Return(&castai.RegisterClusterResponse{ + Cluster: castai.Cluster{ + ID: clusterID.String(), + OrganizationID: orgID.String(), + }, + }, nil).Times(1) + } + + result, err := provider.RegisterCluster(ctx, castaiClient) + + if test.expectedErr != nil { + r.ErrorContains(err, *test.expectedErr) + } else { + r.Equal(clusterID.String(), result.ClusterID) + r.Equal(orgID.String(), result.OrganizationID) + } + }) + } +} diff --git a/internal/services/providers/anywhere/client/client.go b/internal/services/providers/anywhere/client/client.go new file mode 100644 index 00000000..2cdab778 --- /dev/null +++ b/internal/services/providers/anywhere/client/client.go @@ -0,0 +1,65 @@ +//go:generate mockgen -destination ./mock/client.go . Client +package client + +import ( + "context" + "fmt" + + "github.com/samber/lo" + "github.com/sirupsen/logrus" + + "castai-agent/internal/services/discovery" + eks_client "castai-agent/internal/services/providers/eks/client" + gke_client "castai-agent/internal/services/providers/gke/client" + "castai-agent/pkg/cloud" +) + +type Client interface { + // GetClusterName attempts to discover the name of the cluster. + GetClusterName(ctx context.Context) (string, error) +} + +type client struct { + log logrus.FieldLogger + discoveryService discovery.Service + eksClient eks_client.Client + gkeMetadataClient gke_client.Metadata +} + +func New(log logrus.FieldLogger, discoveryService discovery.Service) Client { + return &client{ + log: log, + discoveryService: discoveryService, + } +} + +func (c *client) GetClusterName(ctx context.Context) (string, error) { + csp, _ := c.discoveryService.GetCSP(ctx) + if csp == cloud.AWS { + if c.eksClient == nil { + client, err := eks_client.New(ctx, c.log, eks_client.WithEC2Client(), eks_client.WithMetadataDiscovery()) + + if err != nil { + return "", err + } + + c.eksClient = client + } + + if c.eksClient != nil { + clusterName, err := c.eksClient.GetClusterName(ctx) + + return lo.FromPtrOr(clusterName, ""), err + } + } else if csp == cloud.GCP { + if c.gkeMetadataClient == nil { + c.gkeMetadataClient = gke_client.NewMetadataClient() + } + + if c.gkeMetadataClient != nil { + return c.gkeMetadataClient.GetClusterName() + } + } + + return "", fmt.Errorf("cluster name could not be determined automatically") +} diff --git a/internal/services/providers/anywhere/client/client_test.go b/internal/services/providers/anywhere/client/client_test.go new file mode 100644 index 00000000..1f82093f --- /dev/null +++ b/internal/services/providers/anywhere/client/client_test.go @@ -0,0 +1,121 @@ +package client + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/samber/lo" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + + "castai-agent/internal/services/discovery" + discovery_mock "castai-agent/internal/services/discovery/mock" + eks_client_mock "castai-agent/internal/services/providers/eks/client/mock" + gke_client_mock "castai-agent/internal/services/providers/gke/client/mock" + "castai-agent/pkg/cloud" +) + +func Test_GetClusterName(t *testing.T) { + + tests := map[string]struct { + csp cloud.Cloud + eksClusterName *string + gkeClusterName string + expectedClusterName string + expectedErr *string + getClient func(context.Context, *gomock.Controller, logrus.FieldLogger, discovery.Service) *client + }{ + "should fail to determine cluster name when CSP is not determined": { + getClient: func(_ context.Context, _ *gomock.Controller, log logrus.FieldLogger, discoveryService discovery.Service) *client { + return &client{ + log: logrus.New(), + discoveryService: discoveryService, + } + }, + expectedErr: lo.ToPtr("cluster name could not be determined automatically"), + }, + "should use EKS client to determine cluster name when CSP is AWS": { + csp: cloud.AWS, + getClient: func(ctx context.Context, ctrl *gomock.Controller, log logrus.FieldLogger, discoveryService discovery.Service) *client { + eksClient := eks_client_mock.NewMockClient(ctrl) + eksClient.EXPECT().GetClusterName(ctx).Return(lo.ToPtr("eks-cluster-name"), nil).Times(1) + + return &client{ + log: logrus.New(), + discoveryService: discoveryService, + eksClient: eksClient, + } + }, + expectedClusterName: "eks-cluster-name", + }, + "should return EKS client's error when CSP is AWS and underlying client returns an error": { + csp: cloud.AWS, + getClient: func(ctx context.Context, ctrl *gomock.Controller, log logrus.FieldLogger, discoveryService discovery.Service) *client { + eksClient := eks_client_mock.NewMockClient(ctrl) + eksClient.EXPECT().GetClusterName(ctx).Return(nil, fmt.Errorf("eks error")).Times(1) + + return &client{ + log: logrus.New(), + discoveryService: discoveryService, + eksClient: eksClient, + } + }, + expectedErr: lo.ToPtr("eks error"), + }, + "should use GKE client to determine cluster name when CSP is GCP": { + csp: cloud.GCP, + getClient: func(ctx context.Context, ctrl *gomock.Controller, log logrus.FieldLogger, discoveryService discovery.Service) *client { + gkeClient := gke_client_mock.NewMockMetadata(ctrl) + gkeClient.EXPECT().GetClusterName().Return("gke-cluster-name", nil).Times(1) + + return &client{ + log: logrus.New(), + discoveryService: discoveryService, + gkeMetadataClient: gkeClient, + } + }, + expectedClusterName: "gke-cluster-name", + }, + "should return GKE client's error when CSP is GCP and underlying client returns an error": { + csp: cloud.GCP, + getClient: func(ctx context.Context, ctrl *gomock.Controller, log logrus.FieldLogger, discoveryService discovery.Service) *client { + gkeClient := gke_client_mock.NewMockMetadata(ctrl) + gkeClient.EXPECT().GetClusterName().Return("", fmt.Errorf("gke error")).Times(1) + + return &client{ + log: logrus.New(), + discoveryService: discoveryService, + gkeMetadataClient: gkeClient, + } + }, + expectedErr: lo.ToPtr("gke error"), + }, + } + + for name, test := range tests { + test := test + t.Run(name, func(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + r := require.New(t) + + discoveryService := discovery_mock.NewMockService(ctrl) + + discoveryService.EXPECT().GetCSP(ctx).Return(test.csp, nil).Times(1) + + client := test.getClient(ctx, ctrl, logrus.New(), discoveryService) + + clusterName, err := client.GetClusterName(ctx) + + if test.expectedErr != nil { + r.ErrorContains(err, *test.expectedErr) + } else { + r.Equal(clusterName, test.expectedClusterName) + } + }) + } +} diff --git a/internal/services/providers/anywhere/client/mock/client.go b/internal/services/providers/anywhere/client/mock/client.go new file mode 100644 index 00000000..e7b77150 --- /dev/null +++ b/internal/services/providers/anywhere/client/mock/client.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: castai-agent/internal/services/providers/anywhere/client (interfaces: Client) + +// Package mock_client is a generated GoMock package. +package mock_client + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// GetClusterName mocks base method. +func (m *MockClient) GetClusterName(arg0 context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetClusterName", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetClusterName indicates an expected call of GetClusterName. +func (mr *MockClientMockRecorder) GetClusterName(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterName", reflect.TypeOf((*MockClient)(nil).GetClusterName), arg0) +} diff --git a/internal/services/providers/kops/kops.go b/internal/services/providers/kops/kops.go index 5948e84e..ea93f693 100644 --- a/internal/services/providers/kops/kops.go +++ b/internal/services/providers/kops/kops.go @@ -35,7 +35,7 @@ type Provider struct { } func (p *Provider) RegisterCluster(ctx context.Context, client castai.Client) (*types.ClusterRegistration, error) { - clusterID, err := p.discoveryService.GetClusterID(ctx) + clusterID, err := p.discoveryService.GetKubeSystemNamespaceID(ctx) if err != nil { return nil, fmt.Errorf("getting cluster ID: %w", err) } diff --git a/internal/services/providers/openshift/openshift.go b/internal/services/providers/openshift/openshift.go index ad076a18..25250c7b 100644 --- a/internal/services/providers/openshift/openshift.go +++ b/internal/services/providers/openshift/openshift.go @@ -33,7 +33,7 @@ func New(discoveryService discovery.Service, dyno dynamic.Interface) *Provider { } func (p *Provider) RegisterCluster(ctx context.Context, client castai.Client) (*types.ClusterRegistration, error) { - clusterID, err := p.discoveryService.GetClusterID(ctx) + clusterID, err := p.discoveryService.GetKubeSystemNamespaceID(ctx) if err != nil { return nil, fmt.Errorf("getting cluster id: %w", err) } diff --git a/internal/services/providers/openshift/openshift_test.go b/internal/services/providers/openshift/openshift_test.go index 636ffe17..c372a36e 100644 --- a/internal/services/providers/openshift/openshift_test.go +++ b/internal/services/providers/openshift/openshift_test.go @@ -48,7 +48,7 @@ func TestProvider_RegisterCluster(t *testing.T) { r.NoError(os.Setenv("API_URL", "test")) }, discoveryServiceMock: func(t *testing.T, discoveryService *mock_discovery.MockService) { - discoveryService.EXPECT().GetClusterID(gomock.Any()).Return(®Req.ID, nil) + discoveryService.EXPECT().GetKubeSystemNamespaceID(gomock.Any()).Return(®Req.ID, nil) discoveryService.EXPECT().GetCSPAndRegion(gomock.Any()).Return(cloud.Cloud(regReq.Openshift.CSP), regReq.Openshift.Region, nil) discoveryService.EXPECT().GetOpenshiftClusterID(gomock.Any()).Return(regReq.Openshift.InternalID, nil) discoveryService.EXPECT().GetOpenshiftClusterName(gomock.Any()).Return(regReq.Name, nil) @@ -66,7 +66,7 @@ func TestProvider_RegisterCluster(t *testing.T) { r.NoError(os.Setenv("OPENSHIFT_INTERNAL_ID", regReq.Openshift.InternalID)) }, discoveryServiceMock: func(t *testing.T, discoveryService *mock_discovery.MockService) { - discoveryService.EXPECT().GetClusterID(gomock.Any()).Return(®Req.ID, nil) + discoveryService.EXPECT().GetKubeSystemNamespaceID(gomock.Any()).Return(®Req.ID, nil) }, }, } diff --git a/internal/services/providers/providers.go b/internal/services/providers/providers.go index 68398e18..ad3764bd 100644 --- a/internal/services/providers/providers.go +++ b/internal/services/providers/providers.go @@ -10,6 +10,8 @@ import ( "castai-agent/internal/config" "castai-agent/internal/services/discovery" "castai-agent/internal/services/providers/aks" + "castai-agent/internal/services/providers/anywhere" + anywhere_client "castai-agent/internal/services/providers/anywhere/client" "castai-agent/internal/services/providers/eks" "castai-agent/internal/services/providers/gke" "castai-agent/internal/services/providers/kops" @@ -43,6 +45,13 @@ func GetProvider(ctx context.Context, log logrus.FieldLogger, discoveryService d return aks.New(log.WithField("provider", aks.Name)) } + if cfg.Provider == anywhere.Name || cfg.Anywhere != nil { + logger := log.WithField("provider", cfg.Provider) + client := anywhere_client.New(log, discoveryService) + + return anywhere.New(discoveryService, client, logger), nil + } + if cfg.Provider == openshift.Name { return openshift.New(discoveryService, dyno), nil } diff --git a/internal/services/providers/providers_test.go b/internal/services/providers/providers_test.go index f7579c33..ddace68c 100644 --- a/internal/services/providers/providers_test.go +++ b/internal/services/providers/providers_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "castai-agent/internal/config" + "castai-agent/internal/services/providers/anywhere" "castai-agent/internal/services/providers/eks" "castai-agent/internal/services/providers/gke" "castai-agent/internal/services/providers/kops" @@ -86,6 +87,22 @@ func TestGetProvider(t *testing.T) { r.NoError(err) r.IsType(&openshift.Provider{}, got) }) + + t.Run("should return anywhere", func(t *testing.T) { + r := require.New(t) + + t.Cleanup(config.Reset) + t.Cleanup(os.Clearenv) + + r.NoError(os.Setenv("API_KEY", "api-key")) + r.NoError(os.Setenv("API_URL", "test")) + r.NoError(os.Setenv("PROVIDER", "anywhere")) + + got, err := GetProvider(context.Background(), logrus.New(), nil, nil) + + r.NoError(err) + r.IsType(&anywhere.Provider{}, got) + }) } func Test_isAPINodeLifecycleDiscoveryEnabled(t *testing.T) {