From 73edbcb5d209de93587bcae74edbbdfbc858b60b Mon Sep 17 00:00:00 2001 From: nadimabdelaziz Date: Sun, 12 Jan 2025 16:52:46 +0200 Subject: [PATCH] wip --- core/provider/digitalocean/client.go | 86 ++ .../digitalocean/digitalocean_provider.go | 165 ---- core/provider/digitalocean/docker.go | 94 ++ core/provider/digitalocean/droplet.go | 70 +- core/provider/digitalocean/firewall.go | 2 +- .../digitalocean/mocks/do_client_mock.go | 414 +++++++++ .../digitalocean/mocks/docker_client_mock.go | 377 ++++++++ core/provider/digitalocean/provider.go | 328 +++++++ core/provider/digitalocean/provider_test.go | 552 ++++++++++++ core/provider/digitalocean/ssh.go | 2 +- core/provider/digitalocean/tag.go | 2 +- core/provider/digitalocean/task.go | 285 +++--- core/provider/digitalocean/task_test.go | 828 ++++++++++++++++++ 13 files changed, 2822 insertions(+), 383 deletions(-) create mode 100644 core/provider/digitalocean/client.go delete mode 100644 core/provider/digitalocean/digitalocean_provider.go create mode 100644 core/provider/digitalocean/docker.go create mode 100644 core/provider/digitalocean/mocks/do_client_mock.go create mode 100644 core/provider/digitalocean/mocks/docker_client_mock.go create mode 100644 core/provider/digitalocean/provider.go create mode 100644 core/provider/digitalocean/provider_test.go create mode 100644 core/provider/digitalocean/task_test.go diff --git a/core/provider/digitalocean/client.go b/core/provider/digitalocean/client.go new file mode 100644 index 00000000..753abc6f --- /dev/null +++ b/core/provider/digitalocean/client.go @@ -0,0 +1,86 @@ +package digitalocean + +import ( + "context" + + "github.com/digitalocean/godo" +) + +// DoClient defines the interface for DigitalOcean API operations used by the provider +type DoClient interface { + // Droplet operations + CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) + GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) + DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) + DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) + + // Firewall operations + CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) + DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) + + // SSH Key operations + CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) + DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) + GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) + + // Tag operations + CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) + DeleteTag(ctx context.Context, tag string) (*godo.Response, error) +} + +// godoClient implements the DoClient interface using the actual godo.Client +type godoClient struct { + *godo.Client +} + +func NewGodoClient(token string) DoClient { + return &godoClient{Client: godo.NewFromToken(token)} +} + +// Droplet operations +func (c *godoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) { + return c.Droplets.Create(ctx, req) +} + +func (c *godoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) { + return c.Droplets.Get(ctx, dropletID) +} + +func (c *godoClient) DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) { + return c.Droplets.DeleteByTag(ctx, tag) +} + +func (c *godoClient) DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) { + return c.Droplets.Delete(ctx, id) +} + +// Firewall operations +func (c *godoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) { + return c.Firewalls.Create(ctx, req) +} + +func (c *godoClient) DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) { + return c.Firewalls.Delete(ctx, firewallID) +} + +// SSH Key operations +func (c *godoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) { + return c.Keys.Create(ctx, req) +} + +func (c *godoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) { + return c.Keys.DeleteByFingerprint(ctx, fingerprint) +} + +func (c *godoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) { + return c.Keys.GetByFingerprint(ctx, fingerprint) +} + +// Tag operations +func (c *godoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) { + return c.Tags.Create(ctx, req) +} + +func (c *godoClient) DeleteTag(ctx context.Context, tag string) (*godo.Response, error) { + return c.Tags.Delete(ctx, tag) +} diff --git a/core/provider/digitalocean/digitalocean_provider.go b/core/provider/digitalocean/digitalocean_provider.go deleted file mode 100644 index c3540d26..00000000 --- a/core/provider/digitalocean/digitalocean_provider.go +++ /dev/null @@ -1,165 +0,0 @@ -package digitalocean - -import ( - "context" - "fmt" - "strings" - - "github.com/digitalocean/godo" - "go.uber.org/zap" - - xsync "github.com/puzpuzpuz/xsync/v3" - - "github.com/skip-mev/petri/core/v2/provider" - "github.com/skip-mev/petri/core/v2/util" - "golang.org/x/crypto/ssh" -) - -var _ provider.Provider = (*Provider)(nil) - -const ( - providerLabelName = "petri-provider" -) - -type Provider struct { - logger *zap.Logger - name string - doClient *godo.Client - petriTag string - - userIPs []string - - sshKeyPair *SSHKeyPair - - sshClients *xsync.MapOf[string, *ssh.Client] - - firewallID string -} - -// NewDigitalOceanProvider creates a provider that implements the Provider interface for DigitalOcean. -// Token is the DigitalOcean API token -func NewDigitalOceanProvider(ctx context.Context, logger *zap.Logger, providerName string, token string, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { - doClient := godo.NewFromToken(token) - - if sshKeyPair == nil { - newSshKeyPair, err := MakeSSHKeyPair() - if err != nil { - return nil, err - } - sshKeyPair = newSshKeyPair - } - - userIPs, err := getUserIPs(ctx) - if err != nil { - return nil, err - } - - userIPs = append(userIPs, additionalUserIPS...) - - digitalOceanProvider := &Provider{ - logger: logger.Named("digitalocean_provider"), - name: providerName, - doClient: doClient, - petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)), - - userIPs: userIPs, - - sshClients: xsync.NewMapOf[string, *ssh.Client](), - sshKeyPair: sshKeyPair, - } - - _, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag) - if err != nil { - return nil, err - } - - firewall, err := digitalOceanProvider.createFirewall(ctx, userIPs) - if err != nil { - return nil, fmt.Errorf("failed to create firewall: %w", err) - } - - digitalOceanProvider.firewallID = firewall.ID - - //TODO(Zygimantass): TOCTOU issue - if key, _, err := doClient.Keys.GetByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil { - _, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey) - if err != nil { - if !strings.Contains(err.Error(), "422") { - return nil, err - } - } - } - - return digitalOceanProvider, nil -} - -func (p *Provider) Teardown(ctx context.Context) error { - p.logger.Info("tearing down DigitalOcean provider") - - if err := p.teardownTasks(ctx); err != nil { - return err - } - if err := p.teardownFirewall(ctx); err != nil { - return err - } - if err := p.teardownSSHKey(ctx); err != nil { - return err - } - if err := p.teardownTag(ctx); err != nil { - return err - } - - return nil -} - -func (p *Provider) teardownTasks(ctx context.Context) error { - res, err := p.doClient.Droplets.DeleteByTag(ctx, p.petriTag) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil -} - -func (p *Provider) teardownFirewall(ctx context.Context) error { - res, err := p.doClient.Firewalls.Delete(ctx, p.firewallID) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil -} - -func (p *Provider) teardownSSHKey(ctx context.Context) error { - res, err := p.doClient.Keys.DeleteByFingerprint(ctx, p.sshKeyPair.Fingerprint) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil -} - -func (p *Provider) teardownTag(ctx context.Context) error { - res, err := p.doClient.Tags.Delete(ctx, p.petriTag) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil -} diff --git a/core/provider/digitalocean/docker.go b/core/provider/digitalocean/docker.go new file mode 100644 index 00000000..10a6597a --- /dev/null +++ b/core/provider/digitalocean/docker.go @@ -0,0 +1,94 @@ +package digitalocean + +import ( + "context" + "io" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/image" + "github.com/docker/docker/api/types/network" + dockerclient "github.com/docker/docker/client" + specs "github.com/opencontainers/image-spec/specs-go/v1" +) + +// DockerClient is an interface that abstracts Docker functionality +type DockerClient interface { + Ping(ctx context.Context) (types.Ping, error) + ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error) + ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) + ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) + ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) + ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error + ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error + ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) + ContainerExecCreate(ctx context.Context, container string, config types.ExecConfig) (types.IDResponse, error) + ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error) + ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error) + ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error + Close() error +} + +type defaultDockerClient struct { + client *dockerclient.Client +} + +func NewDockerClient(host string) (DockerClient, error) { + client, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(host)) + if err != nil { + return nil, err + } + return &defaultDockerClient{client: client}, nil +} + +func (d *defaultDockerClient) Ping(ctx context.Context) (types.Ping, error) { + return d.client.Ping(ctx) +} + +func (d *defaultDockerClient) ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error) { + return d.client.ImageInspectWithRaw(ctx, image) +} + +func (d *defaultDockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) { + return d.client.ImagePull(ctx, ref, options) +} + +func (d *defaultDockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) { + return d.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, containerName) +} + +func (d *defaultDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) { + return d.client.ContainerList(ctx, options) +} + +func (d *defaultDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error { + return d.client.ContainerStart(ctx, containerID, options) +} + +func (d *defaultDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error { + return d.client.ContainerStop(ctx, containerID, options) +} + +func (d *defaultDockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) { + return d.client.ContainerInspect(ctx, containerID) +} + +func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, container string, options container.ExecOptions) (types.IDResponse, error) { + return d.client.ContainerExecCreate(ctx, container, options) +} + +func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, config container.ExecAttachOptions) (types.HijackedResponse, error) { + return d.client.ContainerExecAttach(ctx, execID, config) +} + +func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) { + return d.client.ContainerExecInspect(ctx, execID) +} + +func (d *defaultDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error { + return d.client.ContainerRemove(ctx, containerID, options) +} + +func (d *defaultDockerClient) Close() error { + return d.client.Close() +} diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 769b87f9..0f0a2e35 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -3,11 +3,11 @@ package digitalocean import ( "context" "fmt" - "github.com/pkg/errors" "time" + "github.com/pkg/errors" + "github.com/digitalocean/godo" - dockerclient "github.com/docker/docker/client" "go.uber.org/zap" "golang.org/x/crypto/ssh" @@ -56,7 +56,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe Tags: []string{p.petriTag}, } - droplet, res, err := p.doClient.Droplets.Create(ctx, req) + droplet, res, err := p.doClient.CreateDroplet(ctx, req) if err != nil { return nil, err } @@ -68,7 +68,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe start := time.Now() err = util.WaitForCondition(ctx, time.Second*600, time.Millisecond*300, func() (bool, error) { - d, _, err := p.doClient.Droplets.Get(ctx, droplet.ID) + d, _, err := p.doClient.GetDroplet(ctx, droplet.ID) if err != nil { return false, err } @@ -82,13 +82,16 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe return false, nil } - dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(fmt.Sprintf("tcp://%s:2375", ip))) - if err != nil { - p.logger.Error("failed to create docker client", zap.Error(err)) - return false, err + if p.dockerClients[ip] == nil { + dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort)) + if err != nil { + p.logger.Error("failed to create docker client", zap.Error(err)) + return false, err + } + p.dockerClients[ip] = dockerClient } - _, err = dockerClient.Ping(ctx) + _, err = p.dockerClients[ip].Ping(ctx) if err != nil { return false, nil } @@ -108,14 +111,14 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe return droplet, nil } -func (p *Provider) deleteDroplet(ctx context.Context, name string) error { - droplet, err := p.getDroplet(ctx, name) +func (t *Task) deleteDroplet(ctx context.Context) error { + droplet, err := t.getDroplet(ctx) if err != nil { return err } - res, err := p.doClient.Droplets.Delete(ctx, droplet.ID) + res, err := t.doClient.DeleteDropletByID(ctx, droplet.ID) if err != nil { return err } @@ -127,11 +130,8 @@ func (p *Provider) deleteDroplet(ctx context.Context, name string) error { return nil } -func (p *Provider) getDroplet(ctx context.Context, name string) (*godo.Droplet, error) { - // TODO(Zygimantass): this change assumes that all Petri droplets are unique by name - // which should be technically true, but there might be edge cases where it's not. - droplets, res, err := p.doClient.Droplets.ListByName(ctx, name, nil) - +func (t *Task) getDroplet(ctx context.Context) (*godo.Droplet, error) { + droplet, res, err := t.doClient.GetDroplet(ctx, t.state.ID) if err != nil { return nil, err } @@ -140,46 +140,28 @@ func (p *Provider) getDroplet(ctx context.Context, name string) (*godo.Droplet, return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) } - if len(droplets) == 0 { - return nil, fmt.Errorf("could not find droplet") - } - - return &droplets[0], nil -} - -func (p *Provider) getDropletDockerClient(ctx context.Context, taskName string) (*dockerclient.Client, error) { - ip, err := p.GetIP(ctx, taskName) - if err != nil { - return nil, err - } - - dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(fmt.Sprintf("tcp://%s:2375", ip))) - if err != nil { - return nil, err - } - - return dockerClient, nil + return droplet, nil } -func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) { - if _, err := p.getDroplet(ctx, taskName); err != nil { +func (t *Task) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) { + if _, err := t.getDroplet(ctx); err != nil { return nil, fmt.Errorf("droplet %s does not exist", taskName) } - if client, ok := p.sshClients.Load(taskName); ok { - status, _, err := client.SendRequest("ping", true, []byte("ping")) + if t.sshClient != nil { + status, _, err := t.sshClient.SendRequest("ping", true, []byte("ping")) if err == nil && status { - return client, nil + return t.sshClient, nil } } - ip, err := p.GetIP(ctx, taskName) + ip, err := t.GetIP(ctx) if err != nil { return nil, err } - parsedSSHKey, err := ssh.ParsePrivateKey([]byte(p.sshKeyPair.PrivateKey)) + parsedSSHKey, err := ssh.ParsePrivateKey([]byte(t.sshKeyPair.PrivateKey)) if err != nil { return nil, err } @@ -201,7 +183,5 @@ func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*s return nil, err } - p.sshClients.Store(taskName, client) - return client, nil } diff --git a/core/provider/digitalocean/firewall.go b/core/provider/digitalocean/firewall.go index 82917b5a..f8c73d3c 100644 --- a/core/provider/digitalocean/firewall.go +++ b/core/provider/digitalocean/firewall.go @@ -54,7 +54,7 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go }, } - firewall, res, err := p.doClient.Firewalls.Create(ctx, req) + firewall, res, err := p.doClient.CreateFirewall(ctx, req) if err != nil { return nil, err } diff --git a/core/provider/digitalocean/mocks/do_client_mock.go b/core/provider/digitalocean/mocks/do_client_mock.go new file mode 100644 index 00000000..a94deb7c --- /dev/null +++ b/core/provider/digitalocean/mocks/do_client_mock.go @@ -0,0 +1,414 @@ +// Code generated by mockery v2.47.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + godo "github.com/digitalocean/godo" + + mock "github.com/stretchr/testify/mock" +) + +// DoClient is an autogenerated mock type for the DoClient type +type DoClient struct { + mock.Mock +} + +// CreateDroplet provides a mock function with given fields: ctx, req +func (_m *DoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateDroplet") + } + + var r0 *godo.Droplet + var r1 *godo.Response + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) *godo.Droplet); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Droplet) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *godo.DropletCreateRequest) *godo.Response); ok { + r1 = rf(ctx, req) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*godo.Response) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, *godo.DropletCreateRequest) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// CreateFirewall provides a mock function with given fields: ctx, req +func (_m *DoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateFirewall") + } + + var r0 *godo.Firewall + var r1 *godo.Response + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) *godo.Firewall); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Firewall) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *godo.FirewallRequest) *godo.Response); ok { + r1 = rf(ctx, req) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*godo.Response) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, *godo.FirewallRequest) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// CreateKey provides a mock function with given fields: ctx, req +func (_m *DoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateKey") + } + + var r0 *godo.Key + var r1 *godo.Response + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) *godo.Key); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Key) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *godo.KeyCreateRequest) *godo.Response); ok { + r1 = rf(ctx, req) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*godo.Response) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, *godo.KeyCreateRequest) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// CreateTag provides a mock function with given fields: ctx, req +func (_m *DoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) { + ret := _m.Called(ctx, req) + + if len(ret) == 0 { + panic("no return value specified for CreateTag") + } + + var r0 *godo.Tag + var r1 *godo.Response + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error)); ok { + return rf(ctx, req) + } + if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) *godo.Tag); ok { + r0 = rf(ctx, req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Tag) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, *godo.TagCreateRequest) *godo.Response); ok { + r1 = rf(ctx, req) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*godo.Response) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, *godo.TagCreateRequest) error); ok { + r2 = rf(ctx, req) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// DeleteDropletByID provides a mock function with given fields: ctx, id +func (_m *DoClient) DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) { + ret := _m.Called(ctx, id) + + if len(ret) == 0 { + panic("no return value specified for DeleteDropletByID") + } + + var r0 *godo.Response + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Response, error)); ok { + return rf(ctx, id) + } + if rf, ok := ret.Get(0).(func(context.Context, int) *godo.Response); ok { + r0 = rf(ctx, id) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Response) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { + r1 = rf(ctx, id) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteDropletByTag provides a mock function with given fields: ctx, tag +func (_m *DoClient) DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) { + ret := _m.Called(ctx, tag) + + if len(ret) == 0 { + panic("no return value specified for DeleteDropletByTag") + } + + var r0 *godo.Response + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { + return rf(ctx, tag) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + r0 = rf(ctx, tag) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Response) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, tag) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteFirewall provides a mock function with given fields: ctx, firewallID +func (_m *DoClient) DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) { + ret := _m.Called(ctx, firewallID) + + if len(ret) == 0 { + panic("no return value specified for DeleteFirewall") + } + + var r0 *godo.Response + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { + return rf(ctx, firewallID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + r0 = rf(ctx, firewallID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Response) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, firewallID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteKeyByFingerprint provides a mock function with given fields: ctx, fingerprint +func (_m *DoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) { + ret := _m.Called(ctx, fingerprint) + + if len(ret) == 0 { + panic("no return value specified for DeleteKeyByFingerprint") + } + + var r0 *godo.Response + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { + return rf(ctx, fingerprint) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + r0 = rf(ctx, fingerprint) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Response) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, fingerprint) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteTag provides a mock function with given fields: ctx, tag +func (_m *DoClient) DeleteTag(ctx context.Context, tag string) (*godo.Response, error) { + ret := _m.Called(ctx, tag) + + if len(ret) == 0 { + panic("no return value specified for DeleteTag") + } + + var r0 *godo.Response + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { + return rf(ctx, tag) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + r0 = rf(ctx, tag) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Response) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, tag) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetDroplet provides a mock function with given fields: ctx, dropletID +func (_m *DoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) { + ret := _m.Called(ctx, dropletID) + + if len(ret) == 0 { + panic("no return value specified for GetDroplet") + } + + var r0 *godo.Droplet + var r1 *godo.Response + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Droplet, *godo.Response, error)); ok { + return rf(ctx, dropletID) + } + if rf, ok := ret.Get(0).(func(context.Context, int) *godo.Droplet); ok { + r0 = rf(ctx, dropletID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Droplet) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, int) *godo.Response); ok { + r1 = rf(ctx, dropletID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*godo.Response) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, int) error); ok { + r2 = rf(ctx, dropletID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// GetKeyByFingerprint provides a mock function with given fields: ctx, fingerprint +func (_m *DoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) { + ret := _m.Called(ctx, fingerprint) + + if len(ret) == 0 { + panic("no return value specified for GetKeyByFingerprint") + } + + var r0 *godo.Key + var r1 *godo.Response + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Key, *godo.Response, error)); ok { + return rf(ctx, fingerprint) + } + if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Key); ok { + r0 = rf(ctx, fingerprint) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*godo.Key) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string) *godo.Response); ok { + r1 = rf(ctx, fingerprint) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*godo.Response) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + r2 = rf(ctx, fingerprint) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// NewDoClient creates a new instance of DoClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDoClient(t interface { + mock.TestingT + Cleanup(func()) +}) *DoClient { + mock := &DoClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/provider/digitalocean/mocks/docker_client_mock.go b/core/provider/digitalocean/mocks/docker_client_mock.go new file mode 100644 index 00000000..e27c4705 --- /dev/null +++ b/core/provider/digitalocean/mocks/docker_client_mock.go @@ -0,0 +1,377 @@ +// Code generated by mockery v2.47.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + container "github.com/docker/docker/api/types/container" + + image "github.com/docker/docker/api/types/image" + + io "io" + + mock "github.com/stretchr/testify/mock" + + network "github.com/docker/docker/api/types/network" + + types "github.com/docker/docker/api/types" + + v1 "github.com/opencontainers/image-spec/specs-go/v1" +) + +// DockerClient is an autogenerated mock type for the DockerClient type +type DockerClient struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *DockerClient) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerCreate provides a mock function with given fields: ctx, config, hostConfig, networkingConfig, platform, containerName +func (_m *DockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *v1.Platform, containerName string) (container.CreateResponse, error) { + ret := _m.Called(ctx, config, hostConfig, networkingConfig, platform, containerName) + + if len(ret) == 0 { + panic("no return value specified for ContainerCreate") + } + + var r0 container.CreateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) (container.CreateResponse, error)); ok { + return rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } + if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) container.CreateResponse); ok { + r0 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } else { + r0 = ret.Get(0).(container.CreateResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) error); ok { + r1 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecAttach provides a mock function with given fields: ctx, execID, config +func (_m *DockerClient) ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error) { + ret := _m.Called(ctx, execID, config) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecAttach") + } + + var r0 types.HijackedResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecStartCheck) (types.HijackedResponse, error)); ok { + return rf(ctx, execID, config) + } + if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecStartCheck) types.HijackedResponse); ok { + r0 = rf(ctx, execID, config) + } else { + r0 = ret.Get(0).(types.HijackedResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, types.ExecStartCheck) error); ok { + r1 = rf(ctx, execID, config) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecCreate provides a mock function with given fields: ctx, _a1, config +func (_m *DockerClient) ContainerExecCreate(ctx context.Context, _a1 string, config types.ExecConfig) (types.IDResponse, error) { + ret := _m.Called(ctx, _a1, config) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecCreate") + } + + var r0 types.IDResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecConfig) (types.IDResponse, error)); ok { + return rf(ctx, _a1, config) + } + if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecConfig) types.IDResponse); ok { + r0 = rf(ctx, _a1, config) + } else { + r0 = ret.Get(0).(types.IDResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, types.ExecConfig) error); ok { + r1 = rf(ctx, _a1, config) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecInspect provides a mock function with given fields: ctx, execID +func (_m *DockerClient) ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error) { + ret := _m.Called(ctx, execID) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecInspect") + } + + var r0 types.ContainerExecInspect + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.ContainerExecInspect, error)); ok { + return rf(ctx, execID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerExecInspect); ok { + r0 = rf(ctx, execID) + } else { + r0 = ret.Get(0).(types.ContainerExecInspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, execID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerInspect provides a mock function with given fields: ctx, containerID +func (_m *DockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) { + ret := _m.Called(ctx, containerID) + + if len(ret) == 0 { + panic("no return value specified for ContainerInspect") + } + + var r0 types.ContainerJSON + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.ContainerJSON, error)); ok { + return rf(ctx, containerID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerJSON); ok { + r0 = rf(ctx, containerID) + } else { + r0 = ret.Get(0).(types.ContainerJSON) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, containerID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerList provides a mock function with given fields: ctx, options +func (_m *DockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerList") + } + + var r0 []types.Container + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) ([]types.Container, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) []types.Container); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Container) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, container.ListOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerRemove provides a mock function with given fields: ctx, containerID, options +func (_m *DockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error { + ret := _m.Called(ctx, containerID, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.RemoveOptions) error); ok { + r0 = rf(ctx, containerID, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerStart provides a mock function with given fields: ctx, containerID, options +func (_m *DockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error { + ret := _m.Called(ctx, containerID, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerStart") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.StartOptions) error); ok { + r0 = rf(ctx, containerID, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerStop provides a mock function with given fields: ctx, containerID, options +func (_m *DockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error { + ret := _m.Called(ctx, containerID, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerStop") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.StopOptions) error); ok { + r0 = rf(ctx, containerID, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ImageInspectWithRaw provides a mock function with given fields: ctx, _a1 +func (_m *DockerClient) ImageInspectWithRaw(ctx context.Context, _a1 string) (types.ImageInspect, []byte, error) { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for ImageInspectWithRaw") + } + + var r0 types.ImageInspect + var r1 []byte + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.ImageInspect, []byte, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.ImageInspect); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(types.ImageInspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) []byte); ok { + r1 = rf(ctx, _a1) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]byte) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + r2 = rf(ctx, _a1) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// ImagePull provides a mock function with given fields: ctx, ref, options +func (_m *DockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) { + ret := _m.Called(ctx, ref, options) + + if len(ret) == 0 { + panic("no return value specified for ImagePull") + } + + var r0 io.ReadCloser + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) (io.ReadCloser, error)); ok { + return rf(ctx, ref, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) io.ReadCloser); ok { + r0 = rf(ctx, ref, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, image.PullOptions) error); ok { + r1 = rf(ctx, ref, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Ping provides a mock function with given fields: ctx +func (_m *DockerClient) Ping(ctx context.Context) (types.Ping, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Ping") + } + + var r0 types.Ping + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (types.Ping, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) types.Ping); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(types.Ping) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NewDockerClient creates a new instance of DockerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewDockerClient(t interface { + mock.TestingT + Cleanup(func()) +}) *DockerClient { + mock := &DockerClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go new file mode 100644 index 00000000..713f465e --- /dev/null +++ b/core/provider/digitalocean/provider.go @@ -0,0 +1,328 @@ +package digitalocean + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "sync" + + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + + "go.uber.org/zap" + + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/util" +) + +var _ provider.ProviderI = (*Provider)(nil) + +const ( + providerLabelName = "petri-provider" + sshPort = "2375" +) + +type ProviderState struct { + TaskStates map[int]*TaskState `json:"task_states"` // map of task ids to the corresponding task state + Name string `json:"name"` +} + +type Provider struct { + state *ProviderState + stateMu sync.Mutex + + logger *zap.Logger + name string + doClient DoClient + petriTag string + userIPs []string + sshKeyPair *SSHKeyPair + firewallID string + dockerClients map[string]DockerClient // map of droplet ip address to docker clients +} + +// NewProvider creates a provider that implements the Provider interface for DigitalOcean. +// Token is the DigitalOcean API token +func NewProvider(ctx context.Context, logger *zap.Logger, providerName string, token string, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { + doClient := NewGodoClient(token) + return NewProviderWithClient(ctx, logger, providerName, doClient, nil, additionalUserIPS, sshKeyPair) +} + +// NewProviderWithClient creates a provider with custom digitalocean/docker client implementation. +// This is primarily used for testing. +func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { + if sshKeyPair == nil { + newSshKeyPair, err := MakeSSHKeyPair() + if err != nil { + return nil, err + } + sshKeyPair = newSshKeyPair + } + + userIPs, err := getUserIPs(ctx) + if err != nil { + return nil, err + } + + userIPs = append(userIPs, additionalUserIPS...) + + digitalOceanProvider := &Provider{ + logger: logger.Named("digitalocean_provider"), + name: providerName, + doClient: doClient, + petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)), + userIPs: userIPs, + sshKeyPair: sshKeyPair, + dockerClients: dockerClients, + state: &ProviderState{TaskStates: make(map[int]*TaskState)}, + } + + _, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag) + if err != nil { + return nil, err + } + + firewall, err := digitalOceanProvider.createFirewall(ctx, userIPs) + if err != nil { + return nil, fmt.Errorf("failed to create firewall: %w", err) + } + + digitalOceanProvider.firewallID = firewall.ID + + //TODO(Zygimantass): TOCTOU issue + if key, _, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil { + _, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey) + if err != nil { + if !strings.Contains(err.Error(), "422") { + return nil, err + } + } + } + + return digitalOceanProvider, nil +} + +func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefinition) (provider.TaskI, error) { + if err := definition.ValidateBasic(); err != nil { + return nil, fmt.Errorf("failed to validate task definition: %w", err) + } + + if definition.ProviderSpecificConfig == nil { + return nil, fmt.Errorf("digitalocean specific config is nil for %s", definition.Name) + } + + var doConfig DigitalOceanTaskConfig + doConfig, ok := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig) + if !ok { + return nil, fmt.Errorf("invalid provider specific config type for %s", definition.Name) + } + + if err := doConfig.ValidateBasic(); err != nil { + return nil, fmt.Errorf("could not cast digitalocean specific config: %w", err) + } + + p.logger.Info("creating droplet", zap.String("name", definition.Name)) + + droplet, err := p.CreateDroplet(ctx, definition) + if err != nil { + return nil, err + } + + ip, err := droplet.PublicIPv4() + if err != nil { + return nil, err + } + + p.logger.Info("droplet created", zap.String("name", droplet.Name), zap.String("ip", ip)) + + if p.dockerClients[ip] == nil { + p.dockerClients[ip], err = NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort)) + if err != nil { + return nil, err + } + } + + _, _, err = p.dockerClients[ip].ImageInspectWithRaw(ctx, definition.Image.Image) + if err != nil { + p.logger.Info("image not found, pulling", zap.String("image", definition.Image.Image)) + err = pullImage(ctx, p.dockerClients[ip], p.logger, definition.Image.Image) + if err != nil { + return nil, err + } + } + + _, err = p.dockerClients[ip].ContainerCreate(ctx, &container.Config{ + Image: definition.Image.Image, + Entrypoint: definition.Entrypoint, + Cmd: definition.Command, + Tty: false, + Hostname: definition.Name, + Labels: map[string]string{ + providerLabelName: p.name, + }, + Env: convertEnvMapToList(definition.Environment), + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: definition.DataDir, + }, + }, + NetworkMode: container.NetworkMode("host"), + }, nil, nil, definition.ContainerName) + if err != nil { + return nil, err + } + + taskState := &TaskState{ + ID: droplet.ID, + Name: definition.Name, + Definition: definition, + Status: provider.TASK_STOPPED, + ProviderName: p.name, + } + + p.stateMu.Lock() + defer p.stateMu.Unlock() + + p.state.TaskStates[taskState.ID] = taskState + + return &Task{ + state: taskState, + provider: p, + sshKeyPair: p.sshKeyPair, + logger: p.logger.With(zap.String("task", definition.Name)), + doClient: p.doClient, + dockerClient: p.dockerClients[ip], + }, nil +} + +func (p *Provider) SerializeProvider(context.Context) ([]byte, error) { + p.stateMu.Lock() + defer p.stateMu.Unlock() + + bz, err := json.Marshal(p.state) + + return bz, err +} + +func (p *Provider) DeserializeProvider(context.Context) ([]byte, error) { + p.stateMu.Lock() + defer p.stateMu.Unlock() + + bz, err := json.Marshal(p.state) + + return bz, err +} + +func (p *Provider) SerializeTask(ctx context.Context, task provider.TaskI) ([]byte, error) { + if _, ok := task.(*Task); !ok { + return nil, fmt.Errorf("task is not a Docker task") + } + + dockerTask := task.(*Task) + + bz, err := json.Marshal(dockerTask.state) + + if err != nil { + return nil, err + } + + return bz, nil +} + +func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.TaskI, error) { + var taskState TaskState + + err := json.Unmarshal(bz, &taskState) + if err != nil { + return nil, err + } + + task := &Task{ + state: &taskState, + } + + return task, nil +} + +func (p *Provider) Teardown(ctx context.Context) error { + p.logger.Info("tearing down DigitalOcean provider") + + if err := p.teardownTasks(ctx); err != nil { + return err + } + if err := p.teardownFirewall(ctx); err != nil { + return err + } + if err := p.teardownSSHKey(ctx); err != nil { + return err + } + if err := p.teardownTag(ctx); err != nil { + return err + } + return nil +} + +func (p *Provider) teardownTasks(ctx context.Context) error { + res, err := p.doClient.DeleteDropletByTag(ctx, p.petriTag) + if err != nil { + return err + } + + if res.StatusCode > 299 || res.StatusCode < 200 { + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + return nil +} + +func (p *Provider) teardownFirewall(ctx context.Context) error { + res, err := p.doClient.DeleteFirewall(ctx, p.firewallID) + if err != nil { + return err + } + + if res.StatusCode > 299 || res.StatusCode < 200 { + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + return nil +} + +func (p *Provider) teardownSSHKey(ctx context.Context) error { + res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.sshKeyPair.Fingerprint) + if err != nil { + return err + } + + if res.StatusCode > 299 || res.StatusCode < 200 { + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + return nil +} + +func (p *Provider) teardownTag(ctx context.Context) error { + res, err := p.doClient.DeleteTag(ctx, p.petriTag) + if err != nil { + return err + } + + if res.StatusCode > 299 || res.StatusCode < 200 { + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + return nil +} + +func (p *Provider) removeTask(_ context.Context, taskID int) error { + p.stateMu.Lock() + defer p.stateMu.Unlock() + + delete(p.state.TaskStates, taskID) + + return nil +} diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go new file mode 100644 index 00000000..af84b079 --- /dev/null +++ b/core/provider/digitalocean/provider_test.go @@ -0,0 +1,552 @@ +package digitalocean + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + "testing" + "time" + + "github.com/digitalocean/godo" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" + "github.com/skip-mev/petri/core/v2/util" +) + +func setupMockClient(t *testing.T) *mocks.DoClient { + return mocks.NewDoClient(t) +} + +func setupMockDockerClient(t *testing.T) *mocks.DockerClient { + return mocks.NewDockerClient(t) +} + +func TestNewProvider(t *testing.T) { + logger := zap.NewExample() + ctx := context.Background() + + tests := []struct { + name string + token string + additionalIPs []string + sshKeyPair *SSHKeyPair + expectedError bool + mockSetup func(*mocks.DoClient) + }{ + { + name: "valid provider creation", + token: "test-token", + additionalIPs: []string{"1.2.3.4"}, + sshKeyPair: nil, + expectedError: false, + mockSetup: func(m *mocks.DoClient) { + m.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + m.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). + Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + m.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + m.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + }, + }, + { + name: "bad token", + token: "foobar", + additionalIPs: []string{}, + sshKeyPair: nil, + expectedError: true, + mockSetup: func(m *mocks.DoClient) { + m.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusUnauthorized}}, fmt.Errorf("unauthorized")) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockClient := setupMockClient(t) + mockDocker := setupMockDockerClient(t) + tc.mockSetup(mockClient) + mockDockerClients := map[string]DockerClient{ + "test-ip": mockDocker, + } + + provider, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, tc.additionalIPs, tc.sshKeyPair) + if tc.expectedError { + assert.Error(t, err) + assert.Nil(t, provider) + } else { + assert.NoError(t, err) + assert.NotNil(t, provider) + assert.NotEmpty(t, provider.petriTag) + assert.NotEmpty(t, provider.firewallID) + } + }) + } +} + +func TestCreateTask(t *testing.T) { + logger := zap.NewExample() + ctx := context.Background() + mockClient := setupMockClient(t) + mockDocker := setupMockDockerClient(t) + + mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil) + mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) + mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil) + mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: "test-container"}, nil) + + mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). + Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDockerClients := map[string]DockerClient{ + "10.0.0.1": mockDocker, + } + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + mockClient.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("GetDroplet", mock.Anything, mock.AnythingOfType("int")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + tests := []struct { + name string + taskDef provider.TaskDefinition + expectedError bool + }{ + { + name: "valid task creation", + taskDef: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + Entrypoint: []string{"/bin/bash"}, + Command: []string{"-c", "echo hello"}, + Environment: map[string]string{"TEST": "value"}, + DataDir: "/data", + ContainerName: "test-container", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", + }, + }, + expectedError: false, + }, + { + name: "missing provider specific config", + taskDef: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + ProviderSpecificConfig: nil, + }, + expectedError: true, + }, + { + name: "invalid provider specific config - missing region", + taskDef: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "image_id": "123456", + }, + }, + expectedError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + task, err := p.CreateTask(ctx, tc.taskDef) + if tc.expectedError { + assert.Error(t, err) + assert.Nil(t, task) + } else { + assert.NoError(t, err) + assert.NotNil(t, task) + } + }) + } +} + +func TestSerializeAndRestore(t *testing.T) { + logger := zap.NewExample() + ctx := context.Background() + mockClient := setupMockClient(t) + mockDocker := setupMockDockerClient(t) + mockDockerClients := map[string]DockerClient{ + "10.0.0.1": mockDocker, + } + mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil) + mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) + mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil) + mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: "test-container"}, nil) + + mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). + Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("GetDroplet", mock.Anything, mock.AnythingOfType("int")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + providerData, err := p.SerializeProvider(ctx) + assert.NoError(t, err) + assert.NotNil(t, providerData) + + taskDef := provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + Entrypoint: []string{"/bin/bash"}, + Command: []string{"-c", "echo hello"}, + Environment: map[string]string{"TEST": "value"}, + DataDir: "/data", + ContainerName: "test-container", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", + }, + } + + task, err := p.CreateTask(ctx, taskDef) + require.NoError(t, err) + + taskData, err := p.SerializeTask(ctx, task) + assert.NoError(t, err) + assert.NotNil(t, taskData) + + deserializedTask, err := p.DeserializeTask(ctx, taskData) + assert.NoError(t, err) + assert.NotNil(t, deserializedTask) +} + +func TestTeardown(t *testing.T) { + logger := zap.NewExample() + ctx := context.Background() + mockClient := setupMockClient(t) + mockDocker := setupMockDockerClient(t) + + mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). + Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("DeleteDropletByTag", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("DeleteFirewall", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("DeleteKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("DeleteTag", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, nil, []string{}, nil) + require.NoError(t, err) + + err = p.Teardown(ctx) + assert.NoError(t, err) + mockClient.AssertExpectations(t) + mockDocker.AssertExpectations(t) +} + +func TestConcurrentTaskCreationAndCleanup(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + logger, _ := zap.NewDevelopment() + mockDockerClients := make(map[string]DockerClient) + mockDO := mocks.NewDoClient(t) + + for i := 0; i < 10; i++ { + ip := fmt.Sprintf("10.0.0.%d", i+1) + mockDocker := mocks.NewDockerClient(t) + mockDockerClients[ip] = mockDocker + + mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil).Once() + mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once() + mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil).Once() + mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), + mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), + mock.AnythingOfType("string")).Return(container.CreateResponse{ID: fmt.Sprintf("container-%d", i)}, nil).Once() + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{ + { + ID: fmt.Sprintf("container-%d", i), + State: "running", + }, + }, nil).Times(3) + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil).Maybe() + mockDocker.On("Close").Return(nil).Once() + } + + mockDO.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). + Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockDO.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + numTasks := 10 + var wg sync.WaitGroup + errors := make(chan error, numTasks) + tasks := make(chan *Task, numTasks) + taskMutex := sync.Mutex{} + dropletIDs := make(map[int]bool) + ipAddresses := make(map[string]bool) + + for i := 0; i < numTasks; i++ { + dropletID := 1000 + i + droplet := &godo.Droplet{ + ID: dropletID, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: fmt.Sprintf("10.0.0.%d", i+1), + }, + }, + }, + } + + mockDO.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusCreated}}, nil).Once() + // we cant predict how many times GetDroplet will be called exactly as the provider polls waiting for its creation + mockDO.On("GetDroplet", mock.Anything, dropletID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe() + mockDO.On("DeleteDropletByID", mock.Anything, dropletID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusNoContent}}, nil).Once() + } + + // these are called once per provider, not per task + mockDO.On("DeleteDropletByTag", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() + mockDO.On("DeleteFirewall", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() + mockDO.On("DeleteKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() + mockDO.On("DeleteTag", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() + + for i := 0; i < numTasks; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + task, err := p.CreateTask(ctx, provider.TaskDefinition{ + Name: fmt.Sprintf("test-task-%d", index), + ContainerName: fmt.Sprintf("test-container-%d", index), + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + Ports: []string{"80"}, + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456789", // Using a string for image_id + }, + }) + + if err != nil { + errors <- fmt.Errorf("task creation error: %v", err) + return + } + + if err := task.Start(ctx); err != nil { + errors <- fmt.Errorf("task start error: %v", err) + return + } + + // thread-safe recording of task details + taskMutex.Lock() + doTask := task.(*Task) + state := doTask.GetState() + + // check for duplicate droplet IDs or IP addresses + if dropletIDs[state.ID] { + errors <- fmt.Errorf("duplicate droplet ID found: %d", state.ID) + } + dropletIDs[state.ID] = true + + ip, err := task.GetIP(ctx) + if err == nil { + if ipAddresses[ip] { + errors <- fmt.Errorf("duplicate IP address found: %s", ip) + } + ipAddresses[ip] = true + } + + taskMutex.Unlock() + + tasks <- doTask + }(i) + } + + wg.Wait() + close(errors) + + // Check for any task creation errors before proceeding + for err := range errors { + require.NoError(t, err) + } + + require.Equal(t, numTasks, len(p.state.TaskStates), "Provider state should contain all tasks") + + // Collect all tasks in a slice before closing the channel + var tasksToCleanup []*Task + close(tasks) + for task := range tasks { + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status, "All tasks should be in running state") + + state := task.GetState() + require.NotEmpty(t, state.ID, "Task should have a droplet ID") + require.NotEmpty(t, state.Name, "Task should have a name") + require.NotEmpty(t, state.Definition.ContainerName, "Task should have a container name") + tasksToCleanup = append(tasksToCleanup, task) + } + + // test cleanup + var cleanupWg sync.WaitGroup + cleanupErrors := make(chan error, numTasks) + + for _, task := range tasksToCleanup { + cleanupWg.Add(1) + go func(t *Task) { + defer cleanupWg.Done() + if err := t.Destroy(ctx); err != nil { + cleanupErrors <- fmt.Errorf("cleanup error: %v", err) + return + } + + // Wait for the task to be removed from the provider state + err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) { + taskMutex.Lock() + defer taskMutex.Unlock() + _, exists := p.state.TaskStates[t.GetState().ID] + return !exists, nil + }) + if err != nil { + cleanupErrors <- fmt.Errorf("task state cleanup error: %v", err) + } + }(task) + } + + cleanupWg.Wait() + close(cleanupErrors) + + for err := range cleanupErrors { + require.NoError(t, err) + } + + // Wait for provider state to be empty + err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) { + return len(p.state.TaskStates) == 0, nil + }) + require.NoError(t, err, "Provider state should be empty after cleanup") + + // Teardown the provider to clean up resources + err = p.Teardown(ctx) + require.NoError(t, err) + + mockDO.AssertExpectations(t) + for _, client := range mockDockerClients { + client.(*mocks.DockerClient).AssertExpectations(t) + } +} diff --git a/core/provider/digitalocean/ssh.go b/core/provider/digitalocean/ssh.go index dbcdc0a6..0b604dee 100644 --- a/core/provider/digitalocean/ssh.go +++ b/core/provider/digitalocean/ssh.go @@ -98,7 +98,7 @@ func getUserIPs(ctx context.Context) (ips []string, err error) { func (p *Provider) createSSHKey(ctx context.Context, pubKey string) (*godo.Key, error) { req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.petriTag)} - key, res, err := p.doClient.Keys.Create(ctx, req) + key, res, err := p.doClient.CreateKey(ctx, req) if err != nil { return nil, err } diff --git a/core/provider/digitalocean/tag.go b/core/provider/digitalocean/tag.go index db1e9fae..9d8654fb 100644 --- a/core/provider/digitalocean/tag.go +++ b/core/provider/digitalocean/tag.go @@ -12,7 +12,7 @@ func (p *Provider) createTag(ctx context.Context, tagName string) (*godo.Tag, er Name: tagName, } - tag, res, err := p.doClient.Tags.Create(ctx, req) + tag, res, err := p.doClient.CreateTag(ctx, req) if err != nil { return nil, err } diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index e3044ba0..bb42f370 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -7,13 +7,17 @@ import ( "io" "net" "path" + "sync" "time" + "golang.org/x/crypto/ssh" + "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/mount" dockerclient "github.com/docker/docker/client" "github.com/docker/docker/pkg/stdcopy" + "github.com/pkg/errors" "github.com/pkg/sftp" "github.com/spf13/afero" "github.com/spf13/afero/sftpfs" @@ -23,88 +27,30 @@ import ( "github.com/skip-mev/petri/core/v2/util" ) -func (p *Provider) CreateTask(ctx context.Context, logger *zap.Logger, definition provider.TaskDefinition) (string, error) { - if err := definition.ValidateBasic(); err != nil { - return "", fmt.Errorf("failed to validate task definition: %w", err) - } - - if definition.ProviderSpecificConfig == nil { - return "", fmt.Errorf("digitalocean specific config is nil for %s", definition.Name) - } - - doConfig := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig) - - if err := doConfig.ValidateBasic(); err != nil { - return "", fmt.Errorf("could not cast digitalocean specific config: %w", err) - } - - logger = logger.Named("digitalocean_provider") - - logger.Info("creating droplet", zap.String("name", definition.Name)) - - droplet, err := p.CreateDroplet(ctx, definition) - if err != nil { - return "", err - } - - ip, err := p.GetIP(ctx, droplet.Name) - if err != nil { - return "", err - } - - logger.Info("droplet created", zap.String("name", droplet.Name), zap.String("ip", ip)) - - dockerClient, err := p.getDropletDockerClient(ctx, droplet.Name) - defer dockerClient.Close() // nolint - - if err != nil { - return "", err - } - - _, _, err = dockerClient.ImageInspectWithRaw(ctx, definition.Image.Image) - if err != nil { - logger.Info("image not found, pulling", zap.String("image", definition.Image.Image)) - if err := p.pullImage(ctx, dockerClient, definition.Image.Image); err != nil { - return "", err - } - } +type TaskState struct { + ID int `json:"id"` + Name string `json:"name"` + Definition provider.TaskDefinition `json:"definition"` + Status provider.TaskStatus `json:"status"` + ProviderName string `json:"provider_name"` +} - _, err = dockerClient.ContainerCreate(ctx, &container.Config{ - Image: definition.Image.Image, - Entrypoint: definition.Entrypoint, - Cmd: definition.Command, - Tty: false, - Hostname: definition.Name, - Labels: map[string]string{ - providerLabelName: p.name, - }, - Env: convertEnvMapToList(definition.Environment), - }, &container.HostConfig{ - Mounts: []mount.Mount{ - { - Type: mount.TypeBind, - Source: "/docker_volumes", - Target: definition.DataDir, - }, - }, - NetworkMode: container.NetworkMode("host"), - }, nil, nil, definition.ContainerName) - if err != nil { - return "", err - } +type Task struct { + state *TaskState + stateMu sync.Mutex - return droplet.Name, nil + provider *Provider + logger *zap.Logger + sshKeyPair *SSHKeyPair + sshClient *ssh.Client + doClient DoClient + dockerClient DockerClient } -func (p *Provider) StartTask(ctx context.Context, taskName string) error { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return err - } - - defer dockerClient.Close() // nolint +var _ provider.TaskI = (*Task)(nil) - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) Start(ctx context.Context) error { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -113,23 +59,27 @@ func (p *Provider) StartTask(ctx context.Context, taskName string) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", taskName) + return fmt.Errorf("could not find container for %s", t.state.Name) } containerID := containers[0].ID - err = dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}) + err = t.dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}) if err != nil { return err } err = util.WaitForCondition(ctx, time.Second*300, time.Millisecond*100, func() (bool, error) { - status, err := p.GetTaskStatus(ctx, taskName) + status, err := t.GetStatus(ctx) if err != nil { return false, err } if status == provider.TASK_RUNNING { + t.stateMu.Lock() + defer t.stateMu.Unlock() + + t.state.Status = provider.TASK_RUNNING return true, nil } @@ -139,14 +89,8 @@ func (p *Provider) StartTask(ctx context.Context, taskName string) error { return err } -func (p *Provider) StopTask(ctx context.Context, taskName string) error { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return err - } - - defer dockerClient.Close() // nolint - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) Stop(ctx context.Context) error { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -155,30 +99,49 @@ func (p *Provider) StopTask(ctx context.Context, taskName string) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", taskName) + return fmt.Errorf("could not find container for %s", t.state.Name) } - return dockerClient.ContainerStop(ctx, containers[0].ID, container.StopOptions{}) + t.stateMu.Lock() + defer t.stateMu.Unlock() + + t.state.Status = provider.TASK_STOPPED + return t.dockerClient.ContainerStop(ctx, containers[0].ID, container.StopOptions{}) } -func (p *Provider) ModifyTask(ctx context.Context, taskName string, definition provider.TaskDefinition) error { - return nil +func (t *Task) Initialize(ctx context.Context) error { + panic("implement me") +} + +func (t *Task) Modify(ctx context.Context, definition provider.TaskDefinition) error { + panic("implement me") } -func (p *Provider) DestroyTask(ctx context.Context, taskName string) error { - logger := p.logger.With(zap.String("task", taskName)) +func (t *Task) Destroy(ctx context.Context) error { + logger := t.logger.With(zap.String("task", t.state.Name)) logger.Info("deleting task") + defer t.dockerClient.Close() - err := p.deleteDroplet(ctx, taskName) + err := t.deleteDroplet(ctx) if err != nil { return err } + // TODO(nadim-az): remove reference to provider in Task struct + if err := t.provider.removeTask(ctx, t.state.ID); err != nil { + return err + } return nil } -func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider.TaskStatus, error) { - droplet, err := p.getDroplet(ctx, taskName) +func (t *Task) GetState() TaskState { + t.stateMu.Lock() + defer t.stateMu.Unlock() + return *t.state +} + +func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { + droplet, err := t.getDroplet(ctx) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } @@ -187,14 +150,7 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider return provider.TASK_STOPPED, nil } - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return provider.TASK_STATUS_UNDEFINED, err - } - - defer dockerClient.Close() - - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -203,23 +159,21 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider } if len(containers) != 1 { - return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", taskName) + return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", t.state.Name) } - container, err := dockerClient.ContainerInspect(ctx, containers[0].ID) + c, err := t.dockerClient.ContainerInspect(ctx, containers[0].ID) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } - switch state := container.State.Status; state { + switch state := c.State.Status; state { case "created": return provider.TASK_STOPPED, nil case "running": return provider.TASK_RUNNING, nil case "paused": return provider.TASK_PAUSED, nil - case "restarting": - return provider.TASK_RUNNING, nil // todo(zygimantass): is this sane? case "removing": return provider.TASK_STOPPED, nil case "exited": @@ -231,10 +185,10 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider return provider.TASK_STATUS_UNDEFINED, nil } -func (p *Provider) WriteFile(ctx context.Context, taskName string, relPath string, content []byte) error { +func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) error { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := p.getDropletSSHClient(ctx, taskName) + sshClient, err := t.getDropletSSHClient(ctx, t.state.Name) if err != nil { return err } @@ -266,10 +220,10 @@ func (p *Provider) WriteFile(ctx context.Context, taskName string, relPath strin return nil } -func (p *Provider) ReadFile(ctx context.Context, taskName string, relPath string) ([]byte, error) { +func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := p.getDropletSSHClient(ctx, taskName) + sshClient, err := t.getDropletSSHClient(ctx, t.state.Name) if err != nil { return nil, err } @@ -293,12 +247,12 @@ func (p *Provider) ReadFile(ctx context.Context, taskName string, relPath string return content, nil } -func (p *Provider) DownloadDir(ctx context.Context, s string, s2 string, s3 string) error { +func (t *Task) DownloadDir(ctx context.Context, s string, s2 string) error { panic("implement me") } -func (p *Provider) GetIP(ctx context.Context, taskName string) (string, error) { - droplet, err := p.getDroplet(ctx, taskName) +func (t *Task) GetIP(ctx context.Context) (string, error) { + droplet, err := t.getDroplet(ctx) if err != nil { return "", err @@ -307,8 +261,8 @@ func (p *Provider) GetIP(ctx context.Context, taskName string) (string, error) { return droplet.PublicIPv4() } -func (p *Provider) GetExternalAddress(ctx context.Context, taskName string, port string) (string, error) { - ip, err := p.GetIP(ctx, taskName) +func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) { + ip, err := t.GetIP(ctx) if err != nil { return "", err } @@ -316,14 +270,8 @@ func (p *Provider) GetExternalAddress(ctx context.Context, taskName string, port return net.JoinHostPort(ip, port), nil } -func (p *Provider) RunCommand(ctx context.Context, taskName string, command []string) (string, string, int, error) { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return "", "", 0, err - } - - defer dockerClient.Close() - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) RunCommand(ctx context.Context, command []string) (string, string, int, error) { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -332,14 +280,14 @@ func (p *Provider) RunCommand(ctx context.Context, taskName string, command []st } if len(containers) != 1 { - return "", "", 0, fmt.Errorf("could not find container for %s", taskName) + return "", "", 0, fmt.Errorf("could not find container for %s", t.state.Name) } id := containers[0].ID - p.logger.Debug("running command", zap.String("id", id), zap.Strings("command", command)) + t.logger.Debug("running command", zap.String("id", id), zap.Strings("command", command)) - exec, err := dockerClient.ContainerExecCreate(ctx, id, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, id, container.ExecOptions{ AttachStdout: true, AttachStderr: true, Cmd: command, @@ -348,7 +296,7 @@ func (p *Provider) RunCommand(ctx context.Context, taskName string, command []st return "", "", 0, err } - resp, err := dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) + resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { return "", "", 0, err } @@ -365,7 +313,7 @@ loop: case <-ctx.Done(): return "", "", lastExitCode, ctx.Err() case <-ticker.C: - execInspect, err := dockerClient.ContainerExecInspect(ctx, exec.ID) + execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID) if err != nil { return "", "", lastExitCode, err } @@ -390,77 +338,74 @@ loop: return stdout.String(), stderr.String(), lastExitCode, nil } -func (p *Provider) RunCommandWhileStopped(ctx context.Context, taskName string, definition provider.TaskDefinition, command []string) (string, string, int, error) { - if err := definition.ValidateBasic(); err != nil { +func (t *Task) RunCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { + if err := t.state.Definition.ValidateBasic(); err != nil { return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - p.logger.Error("failed to get docker client", zap.Error(err), zap.String("taskName", taskName)) - return "", "", 0, err - } + t.stateMu.Lock() + defer t.stateMu.Unlock() - definition.Entrypoint = []string{"sh", "-c"} - definition.Command = []string{"sleep 36000"} - definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", definition.Name, util.RandomString(5), time.Now().Unix()) - definition.Ports = []string{} + t.state.Definition.Entrypoint = []string{"sh", "-c"} + t.state.Definition.Command = []string{"sleep 36000"} + t.state.Definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", t.state.Definition.Name, util.RandomString(5), time.Now().Unix()) + t.state.Definition.Ports = []string{} - createdContainer, err := dockerClient.ContainerCreate(ctx, &container.Config{ - Image: definition.Image.Image, - Entrypoint: definition.Entrypoint, - Cmd: definition.Command, + createdContainer, err := t.dockerClient.ContainerCreate(ctx, &container.Config{ + Image: t.state.Definition.Image.Image, + Entrypoint: t.state.Definition.Entrypoint, + Cmd: t.state.Definition.Command, Tty: false, - Hostname: definition.Name, + Hostname: t.state.Definition.Name, Labels: map[string]string{ - providerLabelName: p.name, + providerLabelName: t.state.ProviderName, }, - Env: convertEnvMapToList(definition.Environment), + Env: convertEnvMapToList(t.state.Definition.Environment), }, &container.HostConfig{ Mounts: []mount.Mount{ { Type: mount.TypeBind, Source: "/docker_volumes", - Target: definition.DataDir, + Target: t.state.Definition.DataDir, }, }, NetworkMode: container.NetworkMode("host"), - }, nil, nil, definition.ContainerName) + }, nil, nil, t.state.Definition.ContainerName) if err != nil { - p.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", taskName)) + t.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", t.state.Name)) return "", "", 0, err } defer func() { - if _, err := dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { - // auto-removed, but not detected as autoremoved + if _, err := t.dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { + // container was auto-removed, no need to remove it again return } - if err := dockerClient.ContainerRemove(ctx, createdContainer.ID, container.RemoveOptions{Force: true}); err != nil { - p.logger.Error("failed to remove container", zap.Error(err), zap.String("taskName", taskName), zap.String("id", createdContainer.ID)) + if err := t.dockerClient.ContainerRemove(ctx, createdContainer.ID, container.RemoveOptions{Force: true}); err != nil { + t.logger.Error("failed to remove container", zap.Error(err), zap.String("taskName", t.state.Name), zap.String("id", createdContainer.ID)) } }() - if err := startContainerWithBlock(ctx, dockerClient, createdContainer.ID); err != nil { - p.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", taskName)) + if err := startContainerWithBlock(ctx, t.dockerClient, createdContainer.ID); err != nil { + t.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", t.state.Name)) return "", "", 0, err } // wait for container start - exec, err := dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{ AttachStdout: true, AttachStderr: true, - Cmd: command, + Cmd: cmd, }) if err != nil { - p.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", taskName)) + t.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", t.state.Name)) return "", "", 0, err } - resp, err := dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) + resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { - p.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", taskName)) + t.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", t.state.Name)) return "", "", 0, err } @@ -476,7 +421,7 @@ loop: case <-ctx.Done(): return "", "", lastExitCode, ctx.Err() case <-ticker.C: - execInspect, err := dockerClient.ContainerExecInspect(ctx, exec.ID) + execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID) if err != nil { return "", "", lastExitCode, err } @@ -500,7 +445,7 @@ loop: return stdout.String(), stderr.String(), lastExitCode, err } -func startContainerWithBlock(ctx context.Context, dockerClient *dockerclient.Client, containerID string) error { +func startContainerWithBlock(ctx context.Context, dockerClient DockerClient, containerID string) error { // start container if err := dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { return err @@ -532,18 +477,18 @@ func startContainerWithBlock(ctx context.Context, dockerClient *dockerclient.Cli } } -func (p *Provider) pullImage(ctx context.Context, dockerClient *dockerclient.Client, img string) error { - p.logger.Info("pulling image", zap.String("image", img)) +func pullImage(ctx context.Context, dockerClient DockerClient, logger *zap.Logger, img string) error { + logger.Info("pulling image", zap.String("image", img)) resp, err := dockerClient.ImagePull(ctx, img, image.PullOptions{}) if err != nil { - return err + return errors.Wrap(err, "failed to pull docker image") } defer resp.Close() // throw away the image pull stdout response _, err = io.Copy(io.Discard, resp) if err != nil { - return err + return errors.Wrap(err, "failed to pull docker image") } return nil } diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go new file mode 100644 index 00000000..9d0b94fc --- /dev/null +++ b/core/provider/digitalocean/task_test.go @@ -0,0 +1,828 @@ +package digitalocean + +import ( + "bufio" + "bytes" + "context" + "fmt" + "net" + "net/http" + "testing" + "time" + + "github.com/digitalocean/godo" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +// mockConn implements net.Conn interface for testing +type mockConn struct { + *bytes.Buffer +} + +// mockNotFoundError implements errdefs.ErrNotFound interface for testing +type mockNotFoundError struct { + error +} + +func (e mockNotFoundError) NotFound() {} + +func (m mockConn) Close() error { return nil } +func (m mockConn) LocalAddr() net.Addr { return nil } +func (m mockConn) RemoteAddr() net.Addr { return nil } +func (m mockConn) SetDeadline(t time.Time) error { return nil } +func (m mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestTaskLifecycle(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: "1.2.3.4", + }, + }, + }, + } + + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ + ID: "test-container-id", + } + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockDocker.On("ContainerStop", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + }, + Status: provider.TASK_STOPPED, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + err := task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + err = task.Stop(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_STOPPED, task.GetState().Status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestTaskRunCommand(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + testContainer := types.Container{ + ID: "test-testContainer-id", + } + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{testContainer}, nil) + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +func TestTaskRunCommandWhileStopped(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + createResp := container.CreateResponse{ID: "test-container-id"} + mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: false, + }, + }, + }, nil).Once() + + mockDocker.On("ContainerRemove", mock.Anything, mock.Anything, container.RemoveOptions{Force: true}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +func TestTaskGetIP(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + expectedIP := "1.2.3.4" + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: expectedIP, + }, + }, + }, + } + + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + ip, err := task.GetIP(ctx) + require.NoError(t, err) + require.Equal(t, expectedIP, ip) + + externalAddr, err := task.GetExternalAddress(ctx, "80") + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%s:80", expectedIP), externalAddr) + + mockDO.AssertExpectations(t) +} + +func TestTaskDestroy(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("DeleteDropletByID", mock.Anything, droplet.ID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("Close").Return(nil) + + provider := &Provider{ + state: &ProviderState{ + TaskStates: make(map[int]*TaskState), + }, + } + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + provider: provider, + } + + provider.state.TaskStates[task.state.ID] = task.state + + err := task.Destroy(ctx) + require.NoError(t, err) + require.Empty(t, provider.state.TaskStates) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + createResp := container.CreateResponse{ID: "test-container-id"} + mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // first ContainerInspect for startup check + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + // second ContainerInspect for cleanup check - container exists + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{}, nil).Once() + + // container should be removed since it exists + mockDocker.On("ContainerRemove", mock.Anything, mock.Anything, container.RemoveOptions{Force: true}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +// this tests the case where the docker container is auto removed before cleanup doesnt return an error +func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + createResp := container.CreateResponse{ID: "test-container-id"} + mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // first ContainerInspect for startup check + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + // second ContainerInspect for cleanup check - container not found, so ContainerRemove should not be called + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{}, mockNotFoundError{fmt.Errorf("Error: No such container: test-container-id")}).Once() + mockDocker.AssertNotCalled(t, "ContainerRemove", mock.Anything, mock.Anything, mock.Anything) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +func TestTaskExposingPort(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: "1.2.3.4", + }, + }, + }, + } + + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ + ID: "test-container-id", + Ports: []types.Port{ + { + PrivatePort: 80, + PublicPort: 80, + Type: "tcp", + }, + }, + } + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + Ports: []string{"80"}, + }, + Status: provider.TASK_STOPPED, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + err := task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + + externalAddr, err := task.GetExternalAddress(ctx, "80") + require.NoError(t, err) + require.Equal(t, "1.2.3.4:80", externalAddr) + + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s", externalAddr), nil) + require.NoError(t, err) + require.NotEmpty(t, req) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestGetStatus(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + tests := []struct { + name string + dropletStatus string + containerState string + setupMocks func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) + expectedStatus provider.TaskStatus + expectError bool + }{ + { + name: "droplet not active", + dropletStatus: "off", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "off", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container running", + dropletStatus: "active", + containerState: "running", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_RUNNING, + expectError: false, + }, + { + name: "container paused", + dropletStatus: "active", + containerState: "paused", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "paused", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_PAUSED, + expectError: false, + }, + { + name: "container stopped state", + dropletStatus: "active", + containerState: "exited", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "exited", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "no containers found", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{}, nil) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "container inspect error", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{}, fmt.Errorf("inspect error")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "container removing", + dropletStatus: "active", + containerState: "removing", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "removing", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container dead", + dropletStatus: "active", + containerState: "dead", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "dead", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container created", + dropletStatus: "active", + containerState: "created", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "created", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "unknown container status", + dropletStatus: "active", + containerState: "unknown_status", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "unknown_status", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: false, + }, + { + name: "getDroplet error", + dropletStatus: "", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("failed to get droplet")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "containerList error", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed to list containers")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + tt.setupMocks(mockDocker, mockDO) + + task := &Task{ + state: &TaskState{ + ID: 123, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + status, err := task.GetStatus(ctx) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.expectedStatus, status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) + }) + } +}