From baee8f5db205cdbddb3315305a351e7dcb5a5d1b Mon Sep 17 00:00:00 2001
From: Zygimantas <zygis@skip.build>
Date: Sat, 4 Jan 2025 01:06:02 +0200
Subject: [PATCH 01/50] feat: rewrite the docker provider

---
 core/provider/docker/network.go       |  2 ++
 core/provider/docker/provider.go      |  1 +
 core/provider/docker/provider_test.go |  4 ++++
 core/provider/docker/task.go          | 10 ++++++++++
 4 files changed, 17 insertions(+)

diff --git a/core/provider/docker/network.go b/core/provider/docker/network.go
index f18d42cf..8392380a 100644
--- a/core/provider/docker/network.go
+++ b/core/provider/docker/network.go
@@ -141,6 +141,8 @@ func (p *Provider) nextAvailableIP() (string, error) {
 		return "", err
 	}
 
+	p.state.AllocatedIPs = append(p.state.AllocatedIPs, ip.To4().String())
+
 	return ip.String(), nil
 }
 
diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go
index f0636305..80bba660 100644
--- a/core/provider/docker/provider.go
+++ b/core/provider/docker/provider.go
@@ -224,6 +224,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
 	logger.Debug("creating container", zap.String("name", definition.Name), zap.String("image", definition.Image.Image))
 
+	// the network map is volatile, so we need to update it under a mutex
 	ip, err := p.nextAvailableIP()
 	if err != nil {
 		return nil, err
diff --git a/core/provider/docker/provider_test.go b/core/provider/docker/provider_test.go
index e381a57f..837182d6 100644
--- a/core/provider/docker/provider_test.go
+++ b/core/provider/docker/provider_test.go
@@ -17,6 +17,7 @@ import (
 	"github.com/skip-mev/petri/core/v3/provider"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
 )
 
 const idAlphabet = "abcdefghijklqmnoqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
@@ -107,6 +108,7 @@ func TestCreateTask(t *testing.T) {
 
 	p, err := docker.CreateProvider(context.Background(), logger, providerName)
 	require.NoError(t, err)
+	defer p.Teardown(context.Background())
 
 	defer func(ctx context.Context, p provider.ProviderI) {
 		require.NoError(t, p.Teardown(ctx))
@@ -183,6 +185,7 @@ func TestConcurrentTaskCreation(t *testing.T) {
 
 	p, err := docker.CreateProvider(ctx, logger, providerName)
 	require.NoError(t, err)
+	defer p.Teardown(ctx)
 
 	defer func(ctx context.Context, p provider.ProviderI) {
 		require.NoError(t, p.Teardown(ctx))
@@ -250,6 +253,7 @@ func TestProviderSerialization(t *testing.T) {
 
 	p1, err := docker.CreateProvider(ctx, logger, providerName)
 	require.NoError(t, err)
+	defer p1.Teardown(ctx)
 
 	defer func(ctx context.Context, p provider.ProviderI) {
 		require.NoError(t, p.Teardown(ctx))
diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go
index ea569f8c..f1e30b0d 100644
--- a/core/provider/docker/task.go
+++ b/core/provider/docker/task.go
@@ -98,6 +98,10 @@ func (t *Task) Destroy(ctx context.Context) error {
 	return nil
 }
 
+func (t *Task) ensure(_ context.Context) error {
+	return nil
+}
+
 func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) {
 	t.provider.logger.Debug("getting external address", zap.String("id", t.state.Id))
 
@@ -154,6 +158,8 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) {
 		return provider.TASK_STATUS_UNDEFINED, err
 	}
 
+	fmt.Println(containerJSON.State.Status)
+
 	switch state := containerJSON.State.Status; state {
 	case "created":
 		return provider.TASK_STOPPED, nil
@@ -174,6 +180,10 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) {
 	return provider.TASK_STATUS_UNDEFINED, nil
 }
 
+func (t *Task) Initialize(ctx context.Context) error {
+	return nil
+}
+
 func (t *Task) Modify(ctx context.Context, td provider.TaskDefinition) error {
 	panic("unimplemented")
 }

From 753c6e35c6ea25820f9fdba9fb80a47ba6a32fc6 Mon Sep 17 00:00:00 2001
From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com>
Date: Sun, 12 Jan 2025 16:52:46 +0200
Subject: [PATCH 02/50] wip

---
 core/provider/digitalocean/client.go          |  86 ++
 .../digitalocean/digitalocean_provider.go     | 165 ----
 core/provider/digitalocean/docker.go          |  94 ++
 core/provider/digitalocean/droplet.go         |  67 +-
 core/provider/digitalocean/firewall.go        |   2 +-
 .../digitalocean/mocks/do_client_mock.go      | 414 +++++++++
 .../digitalocean/mocks/docker_client_mock.go  | 377 ++++++++
 core/provider/digitalocean/provider.go        | 328 +++++++
 core/provider/digitalocean/provider_test.go   | 552 ++++++++++++
 core/provider/digitalocean/ssh.go             |   2 +-
 core/provider/digitalocean/tag.go             |   2 +-
 core/provider/digitalocean/task.go            | 285 +++---
 core/provider/digitalocean/task_test.go       | 828 ++++++++++++++++++
 13 files changed, 2820 insertions(+), 382 deletions(-)
 create mode 100644 core/provider/digitalocean/client.go
 delete mode 100644 core/provider/digitalocean/digitalocean_provider.go
 create mode 100644 core/provider/digitalocean/docker.go
 create mode 100644 core/provider/digitalocean/mocks/do_client_mock.go
 create mode 100644 core/provider/digitalocean/mocks/docker_client_mock.go
 create mode 100644 core/provider/digitalocean/provider.go
 create mode 100644 core/provider/digitalocean/provider_test.go
 create mode 100644 core/provider/digitalocean/task_test.go

diff --git a/core/provider/digitalocean/client.go b/core/provider/digitalocean/client.go
new file mode 100644
index 00000000..753abc6f
--- /dev/null
+++ b/core/provider/digitalocean/client.go
@@ -0,0 +1,86 @@
+package digitalocean
+
+import (
+	"context"
+
+	"github.com/digitalocean/godo"
+)
+
+// DoClient defines the interface for DigitalOcean API operations used by the provider
+type DoClient interface {
+	// Droplet operations
+	CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error)
+	GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error)
+	DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error)
+	DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error)
+
+	// Firewall operations
+	CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error)
+	DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error)
+
+	// SSH Key operations
+	CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error)
+	DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error)
+	GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error)
+
+	// Tag operations
+	CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error)
+	DeleteTag(ctx context.Context, tag string) (*godo.Response, error)
+}
+
+// godoClient implements the DoClient interface using the actual godo.Client
+type godoClient struct {
+	*godo.Client
+}
+
+func NewGodoClient(token string) DoClient {
+	return &godoClient{Client: godo.NewFromToken(token)}
+}
+
+// Droplet operations
+func (c *godoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) {
+	return c.Droplets.Create(ctx, req)
+}
+
+func (c *godoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) {
+	return c.Droplets.Get(ctx, dropletID)
+}
+
+func (c *godoClient) DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) {
+	return c.Droplets.DeleteByTag(ctx, tag)
+}
+
+func (c *godoClient) DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) {
+	return c.Droplets.Delete(ctx, id)
+}
+
+// Firewall operations
+func (c *godoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) {
+	return c.Firewalls.Create(ctx, req)
+}
+
+func (c *godoClient) DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) {
+	return c.Firewalls.Delete(ctx, firewallID)
+}
+
+// SSH Key operations
+func (c *godoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) {
+	return c.Keys.Create(ctx, req)
+}
+
+func (c *godoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) {
+	return c.Keys.DeleteByFingerprint(ctx, fingerprint)
+}
+
+func (c *godoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) {
+	return c.Keys.GetByFingerprint(ctx, fingerprint)
+}
+
+// Tag operations
+func (c *godoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) {
+	return c.Tags.Create(ctx, req)
+}
+
+func (c *godoClient) DeleteTag(ctx context.Context, tag string) (*godo.Response, error) {
+	return c.Tags.Delete(ctx, tag)
+}
diff --git a/core/provider/digitalocean/digitalocean_provider.go b/core/provider/digitalocean/digitalocean_provider.go
deleted file mode 100644
index 887244cc..00000000
--- a/core/provider/digitalocean/digitalocean_provider.go
+++ /dev/null
@@ -1,165 +0,0 @@
-package digitalocean
-
-import (
-	"context"
-	"fmt"
-	"strings"
-
-	"github.com/digitalocean/godo"
-	"go.uber.org/zap"
-
-	xsync "github.com/puzpuzpuz/xsync/v3"
-
-	"github.com/skip-mev/petri/core/v3/provider"
-	"github.com/skip-mev/petri/core/v3/util"
-	"golang.org/x/crypto/ssh"
-)
-
-var _ provider.Provider = (*Provider)(nil)
-
-const (
-	providerLabelName = "petri-provider"
-)
-
-type Provider struct {
-	logger   *zap.Logger
-	name     string
-	doClient *godo.Client
-	petriTag string
-
-	userIPs []string
-
-	sshKeyPair *SSHKeyPair
-
-	sshClients *xsync.MapOf[string, *ssh.Client]
-
-	firewallID string
-}
-
-// NewDigitalOceanProvider creates a provider that implements the Provider interface for DigitalOcean.
-// Token is the DigitalOcean API token
-func NewDigitalOceanProvider(ctx context.Context, logger *zap.Logger, providerName string, token string, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) {
-	doClient := godo.NewFromToken(token)
-
-	if sshKeyPair == nil {
-		newSshKeyPair, err := MakeSSHKeyPair()
-		if err != nil {
-			return nil, err
-		}
-		sshKeyPair = newSshKeyPair
-	}
-
-	userIPs, err := getUserIPs(ctx)
-	if err != nil {
-		return nil, err
-	}
-
-	userIPs = append(userIPs, additionalUserIPS...)
-
-	digitalOceanProvider := &Provider{
-		logger:   logger.Named("digitalocean_provider"),
-		name:     providerName,
-		doClient: doClient,
-		petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)),
-
-		userIPs: userIPs,
-
-		sshClients: xsync.NewMapOf[string, *ssh.Client](),
-		sshKeyPair: sshKeyPair,
-	}
-
-	_, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag)
-	if err != nil {
-		return nil, err
-	}
-
-	firewall, err := digitalOceanProvider.createFirewall(ctx, userIPs)
-	if err != nil {
-		return nil, fmt.Errorf("failed to create firewall: %w", err)
-	}
-
-	digitalOceanProvider.firewallID = firewall.ID
-
-	//TODO(Zygimantass): TOCTOU issue
-	if key, _, err := doClient.Keys.GetByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
-		_, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey)
-		if err != nil {
-			if !strings.Contains(err.Error(), "422") {
-				return nil, err
-			}
-		}
-	}
-
-	return digitalOceanProvider, nil
-}
-
-func (p *Provider) Teardown(ctx context.Context) error {
-	p.logger.Info("tearing down DigitalOcean provider")
-
-	if err := p.teardownTasks(ctx); err != nil {
-		return err
-	}
-	if err := p.teardownFirewall(ctx); err != nil {
-		return err
-	}
-	if err := p.teardownSSHKey(ctx); err != nil {
-		return err
-	}
-	if err := p.teardownTag(ctx); err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (p *Provider) teardownTasks(ctx context.Context) error {
-	res, err := p.doClient.Droplets.DeleteByTag(ctx, p.petriTag)
-	if err != nil {
-		return err
-	}
-
-	if res.StatusCode > 299 || res.StatusCode < 200 {
-		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
-	return nil
-}
-
-func (p *Provider) teardownFirewall(ctx context.Context) error {
-	res, err := p.doClient.Firewalls.Delete(ctx, p.firewallID)
-	if err != nil {
-		return err
-	}
-
-	if res.StatusCode > 299 || res.StatusCode < 200 {
-		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
-	return nil
-}
-
-func (p *Provider) teardownSSHKey(ctx context.Context) error {
-	res, err := p.doClient.Keys.DeleteByFingerprint(ctx, p.sshKeyPair.Fingerprint)
-	if err != nil {
-		return err
-	}
-
-	if res.StatusCode > 299 || res.StatusCode < 200 {
-		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
-	return nil
-}
-
-func (p *Provider) teardownTag(ctx context.Context) error {
-	res, err := p.doClient.Tags.Delete(ctx, p.petriTag)
-	if err != nil {
-		return err
-	}
-
-	if res.StatusCode > 299 || res.StatusCode < 200 {
-		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
-	return nil
-}
diff --git a/core/provider/digitalocean/docker.go b/core/provider/digitalocean/docker.go
new file mode 100644
index 00000000..10a6597a
--- /dev/null
+++ b/core/provider/digitalocean/docker.go
@@ -0,0 +1,94 @@
+package digitalocean
+
+import (
+	"context"
+	"io"
+
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/container"
+	"github.com/docker/docker/api/types/image"
+	"github.com/docker/docker/api/types/network"
+	dockerclient "github.com/docker/docker/client"
+	specs "github.com/opencontainers/image-spec/specs-go/v1"
+)
+
+// DockerClient is an interface that abstracts Docker functionality
+type DockerClient interface {
+	Ping(ctx context.Context) (types.Ping, error)
+	ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error)
+	ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error)
+	ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error)
+	ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error)
+	ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
+	ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
+	ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error)
+	ContainerExecCreate(ctx context.Context, container string, config types.ExecConfig) (types.IDResponse, error)
+	ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error)
+	ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error)
+	ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
+	Close() error
+}
+
+type defaultDockerClient struct {
+	client *dockerclient.Client
+}
+
+func NewDockerClient(host string) (DockerClient, error) {
+	client, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(host))
+	if err != nil {
+		return nil, err
+	}
+	return &defaultDockerClient{client: client}, nil
+}
+
+func (d *defaultDockerClient) Ping(ctx context.Context) (types.Ping, error) {
+	return d.client.Ping(ctx)
+}
+
+func (d *defaultDockerClient) ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error) {
+	return d.client.ImageInspectWithRaw(ctx, image)
+}
+
+func (d *defaultDockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) {
+	return d.client.ImagePull(ctx, ref, options)
+}
+
+func (d *defaultDockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) {
+	return d.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, containerName)
+}
+
+func (d *defaultDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) {
+	return d.client.ContainerList(ctx, options)
+}
+
+func (d *defaultDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
+	return d.client.ContainerStart(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
+	return d.client.ContainerStop(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) {
+	return d.client.ContainerInspect(ctx, containerID)
+}
+
+// note: the exec methods must use the same parameter and return types as the
+// DockerClient interface above (types.ExecConfig, types.ExecStartCheck,
+// types.ContainerExecInspect); the container.* variants would not satisfy it
+func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, container string, config types.ExecConfig) (types.IDResponse, error) {
+	return d.client.ContainerExecCreate(ctx, container, config)
+}
+
+func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error) {
+	return d.client.ContainerExecAttach(ctx, execID, config)
+}
+
+func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error) {
+	return d.client.ContainerExecInspect(ctx, execID)
+}
+
+func (d *defaultDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
+	return d.client.ContainerRemove(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) Close() error {
+	return d.client.Close()
+}
diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go
index 8166baa1..b9c9423d 100644
--- a/core/provider/digitalocean/droplet.go
+++ b/core/provider/digitalocean/droplet.go
@@ -8,7 +8,6 @@ import (
 	"github.com/pkg/errors"
 
 	"github.com/digitalocean/godo"
-	dockerclient "github.com/docker/docker/client"
 	"go.uber.org/zap"
 	"golang.org/x/crypto/ssh"
 
@@ -57,7 +56,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
 		Tags:    []string{p.petriTag},
 	}
 
-	droplet, res, err := p.doClient.Droplets.Create(ctx, req)
+	droplet, res, err := p.doClient.CreateDroplet(ctx, req)
 	if err != nil {
 		return nil, err
 	}
@@ -69,7 +68,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
 	start := time.Now()
 
 	err = util.WaitForCondition(ctx, time.Second*600, time.Millisecond*300, func() (bool, error) {
-		d, _, err := p.doClient.Droplets.Get(ctx, droplet.ID)
+		d, _, err := p.doClient.GetDroplet(ctx, droplet.ID)
 		if err != nil {
 			return false, err
 		}
@@ -83,13 +82,16 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
 			return false, nil
 		}
 
-		dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(fmt.Sprintf("tcp://%s:2375", ip)))
-		if err != nil {
-			p.logger.Error("failed to create docker client", zap.Error(err))
-			return false, err
+		if p.dockerClients[ip] == nil {
+			dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort))
+			if err != nil {
+				p.logger.Error("failed to create docker client", zap.Error(err))
+				return false, err
+			}
+			p.dockerClients[ip] = dockerClient
 		}
 
-		_, err = dockerClient.Ping(ctx)
+		_, err = p.dockerClients[ip].Ping(ctx)
 		if err != nil {
 			return false, nil
 		}
@@ -109,14 +111,14 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
 	return droplet, nil
 }
 
-func (p *Provider) deleteDroplet(ctx context.Context, name string) error {
-	droplet, err := p.getDroplet(ctx, name)
+func (t *Task) deleteDroplet(ctx context.Context) error {
+	droplet, err := t.getDroplet(ctx)
 	if err != nil {
 		return err
 	}
 
-	res, err := p.doClient.Droplets.Delete(ctx, droplet.ID)
+	res, err := t.doClient.DeleteDropletByID(ctx, droplet.ID)
 	if err != nil {
 		return err
 	}
@@ -128,11 +130,8 @@ func (t *Task) deleteDroplet(ctx context.Context) error {
 	return nil
 }
 
-func (p *Provider) getDroplet(ctx context.Context, name string) (*godo.Droplet, error) {
-	// TODO(Zygimantass): this change assumes that all Petri droplets are unique by name
-	// which should be technically true, but there might be edge cases where it's not.
-	droplets, res, err := p.doClient.Droplets.ListByName(ctx, name, nil)
-
+func (t *Task) getDroplet(ctx context.Context) (*godo.Droplet, error) {
+	droplet, res, err := t.doClient.GetDroplet(ctx, t.state.ID)
 	if err != nil {
 		return nil, err
 	}
@@ -141,46 +140,28 @@ func (p *Provider) getDroplet(ctx context.Context, name string) (*godo.Droplet,
 		return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
 	}
 
-	if len(droplets) == 0 {
-		return nil, fmt.Errorf("could not find droplet")
-	}
-
-	return &droplets[0], nil
-}
-
-func (p *Provider) getDropletDockerClient(ctx context.Context, taskName string) (*dockerclient.Client, error) {
-	ip, err := p.GetIP(ctx, taskName)
-	if err != nil {
-		return nil, err
-	}
-
-	dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(fmt.Sprintf("tcp://%s:2375", ip)))
-	if err != nil {
-		return nil, err
-	}
-
-	return dockerClient, nil
+	return droplet, nil
 }
 
-func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) {
-	if _, err := p.getDroplet(ctx, taskName); err != nil {
+func (t *Task) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) {
+	if _, err := t.getDroplet(ctx); err != nil {
 		return nil, fmt.Errorf("droplet %s does not exist", taskName)
 	}
 
-	if client, ok := p.sshClients.Load(taskName); ok {
-		status, _, err := client.SendRequest("ping", true, []byte("ping"))
+	if t.sshClient != nil {
+		status, _, err := t.sshClient.SendRequest("ping", true, []byte("ping"))
 		if err == nil && status {
-			return client, nil
+			return t.sshClient, nil
 		}
 	}
 
-	ip, err := p.GetIP(ctx, taskName)
+	ip, err := t.GetIP(ctx)
 	if err != nil {
 		return nil, err
 	}
 
-	parsedSSHKey, err := ssh.ParsePrivateKey([]byte(p.sshKeyPair.PrivateKey))
+	parsedSSHKey, err := ssh.ParsePrivateKey([]byte(t.sshKeyPair.PrivateKey))
 	if err != nil {
 		return nil, err
 	}
@@ -202,7 +183,5 @@ func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*s
 		return nil, err
 	}
 
-	p.sshClients.Store(taskName, client)
-
 	return client, nil
 }
diff --git a/core/provider/digitalocean/firewall.go b/core/provider/digitalocean/firewall.go
index 82917b5a..f8c73d3c 100644
--- a/core/provider/digitalocean/firewall.go
+++ b/core/provider/digitalocean/firewall.go
@@ -54,7 +54,7 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go
 		},
 	}
 
-	firewall, res, err := p.doClient.Firewalls.Create(ctx, req)
+	firewall, res, err := p.doClient.CreateFirewall(ctx, req)
 	if err != nil {
 		return nil, err
 	}
diff --git a/core/provider/digitalocean/mocks/do_client_mock.go b/core/provider/digitalocean/mocks/do_client_mock.go
new file mode 100644
index 00000000..a94deb7c
--- /dev/null
+++ b/core/provider/digitalocean/mocks/do_client_mock.go
@@ -0,0 +1,414 @@
+// Code generated by mockery v2.47.0. DO NOT EDIT.
+
+package mocks
+
+import (
+	context "context"
+
+	godo "github.com/digitalocean/godo"
+
+	mock "github.com/stretchr/testify/mock"
+)
+
+// DoClient is an autogenerated mock type for the DoClient type
+type DoClient struct {
+	mock.Mock
+}
+
+// CreateDroplet provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) {
+	ret := _m.Called(ctx, req)
+
+	if len(ret) == 0 {
+		panic("no return value specified for CreateDroplet")
+	}
+
+	var r0 *godo.Droplet
+	var r1 *godo.Response
+	var r2 error
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error)); ok {
+		return rf(ctx, req)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) *godo.Droplet); ok {
+		r0 = rf(ctx, req)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Droplet)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *godo.DropletCreateRequest) *godo.Response); ok {
+		r1 = rf(ctx, req)
+	} else {
+		if ret.Get(1) != nil {
+			r1 = ret.Get(1).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(2).(func(context.Context, *godo.DropletCreateRequest) error); ok {
+		r2 = rf(ctx, req)
+	} else {
+		r2 = ret.Error(2)
+	}
+
+	return r0, r1, r2
+}
+
+// CreateFirewall provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) {
+	ret := _m.Called(ctx, req)
+
+	if len(ret) == 0 {
+		panic("no return value specified for CreateFirewall")
+	}
+
+	var r0 *godo.Firewall
+	var r1 *godo.Response
+	var r2 error
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error)); ok {
+		return rf(ctx, req)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) *godo.Firewall); ok {
+		r0 = rf(ctx, req)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Firewall)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *godo.FirewallRequest) *godo.Response); ok {
+		r1 = rf(ctx, req)
+	} else {
+		if ret.Get(1) != nil {
+			r1 = ret.Get(1).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(2).(func(context.Context, *godo.FirewallRequest) error); ok {
+		r2 = rf(ctx, req)
+	} else {
+		r2 = ret.Error(2)
+	}
+
+	return r0, r1, r2
+}
+
+// CreateKey provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) {
+	ret := _m.Called(ctx, req)
+
+	if len(ret) == 0 {
+		panic("no return value specified for CreateKey")
+	}
+
+	var r0 *godo.Key
+	var r1 *godo.Response
+	var r2 error
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error)); ok {
+		return rf(ctx, req)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) *godo.Key); ok {
+		r0 = rf(ctx, req)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Key)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *godo.KeyCreateRequest) *godo.Response); ok {
+		r1 = rf(ctx, req)
+	} else {
+		if ret.Get(1) != nil {
+			r1 = ret.Get(1).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(2).(func(context.Context, *godo.KeyCreateRequest) error); ok {
+		r2 = rf(ctx, req)
+	} else {
+		r2 = ret.Error(2)
+	}
+
+	return r0, r1, r2
+}
+
+// CreateTag provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) {
+	ret := _m.Called(ctx, req)
+
+	if len(ret) == 0 {
+		panic("no return value specified for CreateTag")
+	}
+
+	var r0 *godo.Tag
+	var r1 *godo.Response
+	var r2 error
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error)); ok {
+		return rf(ctx, req)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) *godo.Tag); ok {
+		r0 = rf(ctx, req)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Tag)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *godo.TagCreateRequest) *godo.Response); ok {
+		r1 = rf(ctx, req)
+	} else {
+		if ret.Get(1) != nil {
+			r1 = ret.Get(1).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(2).(func(context.Context, *godo.TagCreateRequest) error); ok {
+		r2 = rf(ctx, req)
+	} else {
+		r2 = ret.Error(2)
+	}
+
+	return r0, r1, r2
+}
+
+// DeleteDropletByID provides a mock function with given fields: ctx, id
+func (_m *DoClient) DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) {
+	ret := _m.Called(ctx, id)
+
+	if len(ret) == 0 {
+		panic("no return value specified for DeleteDropletByID")
+	}
+
+	var r0 *godo.Response
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Response, error)); ok {
+		return rf(ctx, id)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, int) *godo.Response); ok {
+		r0 = rf(ctx, id)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+		r1 = rf(ctx, id)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// DeleteDropletByTag provides a mock function with given fields: ctx, tag
+func (_m *DoClient) DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) {
+	ret := _m.Called(ctx, tag)
+
+	if len(ret) == 0 {
+		panic("no return value specified for DeleteDropletByTag")
+	}
+
+	var r0 *godo.Response
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok {
+		return rf(ctx, tag)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok {
+		r0 = rf(ctx, tag)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+		r1 = rf(ctx, tag)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// DeleteFirewall provides a mock function with given fields: ctx, firewallID
+func (_m *DoClient) DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) {
+	ret := _m.Called(ctx, firewallID)
+
+	if len(ret) == 0 {
+		panic("no return value specified for DeleteFirewall")
+	}
+
+	var r0 *godo.Response
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok {
+		return rf(ctx, firewallID)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok {
+		r0 = rf(ctx, firewallID)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+		r1 = rf(ctx, firewallID)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// DeleteKeyByFingerprint provides a mock function with given fields: ctx, fingerprint
+func (_m *DoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) {
+	ret := _m.Called(ctx, fingerprint)
+
+	if len(ret) == 0 {
+		panic("no return value specified for DeleteKeyByFingerprint")
+	}
+
+	var r0 *godo.Response
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok {
+		return rf(ctx, fingerprint)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok {
+		r0 = rf(ctx, fingerprint)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+		r1 = rf(ctx, fingerprint)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// DeleteTag provides a mock function with given fields: ctx, tag
+func (_m *DoClient) DeleteTag(ctx context.Context, tag string) (*godo.Response, error) {
+	ret := _m.Called(ctx, tag)
+
+	if len(ret) == 0 {
+		panic("no return value specified for DeleteTag")
+	}
+
+	var r0 *godo.Response
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok {
+		return rf(ctx, tag)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok {
+		r0 = rf(ctx, tag)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+		r1 = rf(ctx, tag)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// GetDroplet provides a mock function with given fields: ctx, dropletID
+func (_m *DoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) {
+	ret := _m.Called(ctx, dropletID)
+
+	if len(ret) == 0 {
+		panic("no return value specified for GetDroplet")
+	}
+
+	var r0 *godo.Droplet
+	var r1 *godo.Response
+	var r2 error
+	if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Droplet, *godo.Response, error)); ok {
+		return rf(ctx, dropletID)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, int) *godo.Droplet); ok {
+		r0 = rf(ctx, dropletID)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Droplet)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, int) *godo.Response); ok {
+		r1 = rf(ctx, dropletID)
+	} else {
+		if ret.Get(1) != nil {
+			r1 = ret.Get(1).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(2).(func(context.Context, int) error); ok {
+		r2 = rf(ctx, dropletID)
+	} else {
+		r2 = ret.Error(2)
+	}
+
+	return r0, r1, r2
+}
+
+// GetKeyByFingerprint provides a mock function with given fields: ctx, fingerprint
+func (_m *DoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) {
+	ret := _m.Called(ctx, fingerprint)
+
+	if len(ret) == 0 {
+		panic("no return value specified for GetKeyByFingerprint")
+	}
+
+	var r0 *godo.Key
+	var r1 *godo.Response
+	var r2 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Key, *godo.Response, error)); ok {
+		return rf(ctx, fingerprint)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Key); ok {
+		r0 = rf(ctx, fingerprint)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(*godo.Key)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) *godo.Response); ok {
+		r1 = rf(ctx, fingerprint)
+	} else {
+		if ret.Get(1) != nil {
+			r1 = ret.Get(1).(*godo.Response)
+		}
+	}
+
+	if rf, ok := ret.Get(2).(func(context.Context, string) error); ok {
+		r2 = rf(ctx, fingerprint)
+	} else {
+		r2 = ret.Error(2)
+	}
+
+	return r0, r1, r2
+}
+
+// NewDoClient creates a new instance of DoClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewDoClient(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *DoClient {
+	mock := &DoClient{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/core/provider/digitalocean/mocks/docker_client_mock.go b/core/provider/digitalocean/mocks/docker_client_mock.go
new file mode 100644
index 00000000..e27c4705
--- /dev/null
+++ b/core/provider/digitalocean/mocks/docker_client_mock.go
@@ -0,0 +1,377 @@
+// Code generated by mockery v2.47.0. DO NOT EDIT.
+
+package mocks
+
+import (
+	context "context"
+
+	container "github.com/docker/docker/api/types/container"
+
+	image "github.com/docker/docker/api/types/image"
+
+	io "io"
+
+	mock "github.com/stretchr/testify/mock"
+
+	network "github.com/docker/docker/api/types/network"
+
+	types "github.com/docker/docker/api/types"
+
+	v1 "github.com/opencontainers/image-spec/specs-go/v1"
+)
+
+// DockerClient is an autogenerated mock type for the DockerClient type
+type DockerClient struct {
+	mock.Mock
+}
+
+// Close provides a mock function with given fields:
+func (_m *DockerClient) Close() error {
+	ret := _m.Called()
+
+	if len(ret) == 0 {
+		panic("no return value specified for Close")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func() error); ok {
+		r0 = rf()
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// ContainerCreate provides a mock function with given fields: ctx, config, hostConfig, networkingConfig, platform, containerName
+func (_m *DockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *v1.Platform, containerName string) (container.CreateResponse, error) {
+	ret := _m.Called(ctx, config, hostConfig, networkingConfig, platform, containerName)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerCreate")
+	}
+
+	var r0 container.CreateResponse
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) (container.CreateResponse, error)); ok {
+		return rf(ctx, config, hostConfig, networkingConfig, platform, containerName)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) container.CreateResponse); ok {
+		r0 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName)
+	} else {
+		r0 = ret.Get(0).(container.CreateResponse)
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) error); ok {
+		r1 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ContainerExecAttach provides a mock function with given fields: ctx, execID, config
+func (_m *DockerClient) ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error) {
+	ret := _m.Called(ctx, execID, config)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerExecAttach")
+	}
+
+	var r0 types.HijackedResponse
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecStartCheck) (types.HijackedResponse, error)); ok {
+		return rf(ctx, execID, config)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecStartCheck) types.HijackedResponse); ok {
+		r0 = rf(ctx, execID, config)
+	} else {
+		r0 = ret.Get(0).(types.HijackedResponse)
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string, types.ExecStartCheck) error); ok {
+		r1 = rf(ctx, execID, config)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ContainerExecCreate provides a mock function with given fields: ctx, _a1, config
+func (_m *DockerClient) ContainerExecCreate(ctx context.Context, _a1 string, config types.ExecConfig) (types.IDResponse, error) {
+	ret := _m.Called(ctx, _a1, config)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerExecCreate")
+	}
+
+	var r0 types.IDResponse
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecConfig) (types.IDResponse, error)); ok {
+		return rf(ctx, _a1, config)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecConfig) types.IDResponse); ok {
+		r0 = rf(ctx, _a1, config)
+	} else {
+		r0 = ret.Get(0).(types.IDResponse)
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string, types.ExecConfig) error); ok {
+		r1 = rf(ctx, _a1, config)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ContainerExecInspect provides a mock function with given fields: ctx, execID
+func (_m *DockerClient) ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error) {
+	ret := _m.Called(ctx, execID)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerExecInspect")
+	}
+
+	var r0 types.ContainerExecInspect
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (types.ContainerExecInspect, error)); ok {
+		return rf(ctx, execID)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerExecInspect); ok {
+		r0 = rf(ctx, execID)
+	} else {
+		r0 = ret.Get(0).(types.ContainerExecInspect)
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+		r1 = rf(ctx, execID)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ContainerInspect provides a mock function with given fields: ctx, containerID
+func (_m *DockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) {
+	ret := _m.Called(ctx, containerID)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerInspect")
+	}
+
+	var r0 types.ContainerJSON
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (types.ContainerJSON, error)); ok {
+		return rf(ctx, containerID)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerJSON); ok {
+		r0 = rf(ctx, containerID)
+	} else {
+		r0 = ret.Get(0).(types.ContainerJSON)
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+		r1 = rf(ctx, containerID)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ContainerList provides a mock function with given fields: ctx, options
+func (_m *DockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) {
+	ret := _m.Called(ctx, options)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerList")
+	}
+
+	var r0 []types.Container
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) ([]types.Container, error)); ok {
+		return rf(ctx, options)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) []types.Container); ok {
+		r0 = rf(ctx, options)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).([]types.Container)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, container.ListOptions) error); ok {
+		r1 = rf(ctx, options)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// ContainerRemove provides a mock function with given fields: ctx, containerID, options
+func (_m *DockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
+	ret := _m.Called(ctx, containerID, options)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerRemove")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, string, container.RemoveOptions) error); ok {
+		r0 = rf(ctx, containerID, options)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// ContainerStart provides a mock function with given fields: ctx, containerID, options
+func (_m *DockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
+	ret := _m.Called(ctx, containerID, options)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerStart")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, string, container.StartOptions) error); ok {
+		r0 = rf(ctx, containerID, options)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// ContainerStop provides a mock function with given fields: ctx, containerID, options
+func (_m *DockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
+	ret := _m.Called(ctx, containerID, options)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ContainerStop")
+	}
+
+	var r0 error
+	if rf, ok := ret.Get(0).(func(context.Context, string, container.StopOptions) error); ok {
+		r0 = rf(ctx, containerID, options)
+	} else {
+		r0 = ret.Error(0)
+	}
+
+	return r0
+}
+
+// ImageInspectWithRaw provides a mock function with given fields: ctx, _a1
+func (_m *DockerClient) ImageInspectWithRaw(ctx context.Context, _a1 string) (types.ImageInspect, []byte, error) {
+	ret := _m.Called(ctx, _a1)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ImageInspectWithRaw")
+	}
+
+	var r0 types.ImageInspect
+	var r1 []byte
+	var r2 error
+	if rf, ok := ret.Get(0).(func(context.Context, string) (types.ImageInspect, []byte, error)); ok {
+		return rf(ctx, _a1)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string) types.ImageInspect); ok {
+		r0 = rf(ctx, _a1)
+	} else {
+		r0 = ret.Get(0).(types.ImageInspect)
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string) []byte); ok {
+		r1 = rf(ctx, _a1)
+	} else {
+		if ret.Get(1) != nil {
+			r1 = ret.Get(1).([]byte)
+		}
+	}
+
+	if rf, ok := ret.Get(2).(func(context.Context, string) error); ok {
+		r2 = rf(ctx, _a1)
+	} else {
+		r2 = ret.Error(2)
+	}
+
+	return r0, r1, r2
+}
+
+// ImagePull provides a mock function with given fields: ctx, ref, options
+func (_m *DockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) {
+	ret := _m.Called(ctx, ref, options)
+
+	if len(ret) == 0 {
+		panic("no return value specified for ImagePull")
+	}
+
+	var r0 io.ReadCloser
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) (io.ReadCloser, error)); ok {
+		return rf(ctx, ref, options)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) io.ReadCloser); ok {
+		r0 = rf(ctx, ref, options)
+	} else {
+		if ret.Get(0) != nil {
+			r0 = ret.Get(0).(io.ReadCloser)
+		}
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context, string, image.PullOptions) error); ok {
+		r1 = rf(ctx, ref, options)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
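+
+// Example usage (assumed, mirroring the test setup in provider_test.go later
+// in this patch): construct the mock with mocks.NewDockerClient(t) and stub
+// calls before exercising the provider, e.g.
+//
+//	m := mocks.NewDockerClient(t)
+//	m.On("Ping", mock.Anything).Return(types.Ping{}, nil)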
+
+// Ping provides a mock function with given fields: ctx
+func (_m *DockerClient) Ping(ctx context.Context) (types.Ping, error) {
+	ret := _m.Called(ctx)
+
+	if len(ret) == 0 {
+		panic("no return value specified for Ping")
+	}
+
+	var r0 types.Ping
+	var r1 error
+	if rf, ok := ret.Get(0).(func(context.Context) (types.Ping, error)); ok {
+		return rf(ctx)
+	}
+	if rf, ok := ret.Get(0).(func(context.Context) types.Ping); ok {
+		r0 = rf(ctx)
+	} else {
+		r0 = ret.Get(0).(types.Ping)
+	}
+
+	if rf, ok := ret.Get(1).(func(context.Context) error); ok {
+		r1 = rf(ctx)
+	} else {
+		r1 = ret.Error(1)
+	}
+
+	return r0, r1
+}
+
+// NewDockerClient creates a new instance of DockerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewDockerClient(t interface {
+	mock.TestingT
+	Cleanup(func())
+}) *DockerClient {
+	mock := &DockerClient{}
+	mock.Mock.Test(t)
+
+	t.Cleanup(func() { mock.AssertExpectations(t) })
+
+	return mock
+}
diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go
new file mode 100644
index 00000000..713f465e
--- /dev/null
+++ b/core/provider/digitalocean/provider.go
@@ -0,0 +1,328 @@
+package digitalocean
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"strings"
+	"sync"
+
+	"github.com/docker/docker/api/types/container"
+	"github.com/docker/docker/api/types/mount"
+
+	"go.uber.org/zap"
+
+	"github.com/skip-mev/petri/core/v2/provider"
+	"github.com/skip-mev/petri/core/v2/util"
+)
+
+var _ provider.ProviderI = (*Provider)(nil)
+
+const (
+	providerLabelName = "petri-provider"
+	sshPort           = "2375" // note: despite the name, 2375 is the Docker daemon's TCP port, not SSH
+)
+
+type ProviderState struct {
+	TaskStates map[int]*TaskState `json:"task_states"` // map of task ids to the corresponding task state
+	Name       string             `json:"name"`
+}
+
+type Provider struct {
+	state   *ProviderState
+	stateMu sync.Mutex
+
+	logger        *zap.Logger
+	name          string
+	doClient      DoClient
+	petriTag      string
+	userIPs       []string
+	sshKeyPair    *SSHKeyPair
+	firewallID    string
+	dockerClients map[string]DockerClient // map of droplet ip address to docker clients
+}
+
+// NewProvider creates a provider that implements the Provider interface for DigitalOcean.
+// Token is the DigitalOcean API token
+func NewProvider(ctx context.Context, logger *zap.Logger, providerName string, token string, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) {
+	doClient := NewGodoClient(token)
+	return NewProviderWithClient(ctx, logger, providerName, doClient, nil, additionalUserIPS, sshKeyPair)
+}
+
+// NewProviderWithClient creates a provider with custom digitalocean/docker client implementation.
+// This is primarily used for testing.
+func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) {
+	if sshKeyPair == nil {
+		newSshKeyPair, err := MakeSSHKeyPair()
+		if err != nil {
+			return nil, err
+		}
+		sshKeyPair = newSshKeyPair
+	}
+
+	// guard against a nil map (NewProvider passes nil): assigning into a nil
+	// map later in CreateTask or CreateDroplet would panic
+	if dockerClients == nil {
+		dockerClients = make(map[string]DockerClient)
+	}
+
+	userIPs, err := getUserIPs(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	userIPs = append(userIPs, additionalUserIPS...)
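+
+	// only these addresses (the caller's detected public IPs plus any
+	// explicitly supplied extras) are admitted by the firewall created below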
+
+	digitalOceanProvider := &Provider{
+		logger:        logger.Named("digitalocean_provider"),
+		name:          providerName,
+		doClient:      doClient,
+		petriTag:      fmt.Sprintf("petri-droplet-%s", util.RandomString(5)),
+		userIPs:       userIPs,
+		sshKeyPair:    sshKeyPair,
+		dockerClients: dockerClients,
+		state:         &ProviderState{TaskStates: make(map[int]*TaskState)},
+	}
+
+	_, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag)
+	if err != nil {
+		return nil, err
+	}
+
+	firewall, err := digitalOceanProvider.createFirewall(ctx, userIPs)
+	if err != nil {
+		return nil, fmt.Errorf("failed to create firewall: %w", err)
+	}
+
+	digitalOceanProvider.firewallID = firewall.ID
+
+	//TODO(Zygimantass): TOCTOU issue
+	if key, _, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
+		_, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey)
+		if err != nil {
+			if !strings.Contains(err.Error(), "422") {
+				return nil, err
+			}
+		}
+	}
+
+	return digitalOceanProvider, nil
+}
+
+func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefinition) (provider.TaskI, error) {
+	if err := definition.ValidateBasic(); err != nil {
+		return nil, fmt.Errorf("failed to validate task definition: %w", err)
+	}
+
+	if definition.ProviderSpecificConfig == nil {
+		return nil, fmt.Errorf("digitalocean specific config is nil for %s", definition.Name)
+	}
+
+	doConfig, ok := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig)
+	if !ok {
+		return nil, fmt.Errorf("invalid provider specific config type for %s", definition.Name)
+	}
+
+	if err := doConfig.ValidateBasic(); err != nil {
+		return nil, fmt.Errorf("failed to validate digitalocean specific config: %w", err)
+	}
+
+	p.logger.Info("creating droplet", zap.String("name", definition.Name))
+
+	droplet, err := p.CreateDroplet(ctx, definition)
+	if err != nil {
+		return nil, err
+	}
+
+	ip, err := droplet.PublicIPv4()
+	if err != nil {
+		return nil, err
+	}
+
+	p.logger.Info("droplet created", zap.String("name", droplet.Name), zap.String("ip", ip))
+
+	if p.dockerClients[ip] == nil {
+		p.dockerClients[ip], err = NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort))
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	_, _, err = p.dockerClients[ip].ImageInspectWithRaw(ctx, definition.Image.Image)
+	if err != nil {
+		p.logger.Info("image not found, pulling", zap.String("image", definition.Image.Image))
+		err = pullImage(ctx, p.dockerClients[ip], p.logger, definition.Image.Image)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	_, err = p.dockerClients[ip].ContainerCreate(ctx, &container.Config{
+		Image:      definition.Image.Image,
+		Entrypoint: definition.Entrypoint,
+		Cmd:        definition.Command,
+		Tty:        false,
+		Hostname:   definition.Name,
+		Labels: map[string]string{
+			providerLabelName: p.name,
+		},
+		Env: convertEnvMapToList(definition.Environment),
+	}, &container.HostConfig{
+		Mounts: []mount.Mount{
+			{
+				Type:   mount.TypeBind,
+				Source: "/docker_volumes",
+				Target: definition.DataDir,
+			},
+		},
+		NetworkMode: container.NetworkMode("host"),
+	}, nil, nil, definition.ContainerName)
+	if err != nil {
+		return nil, err
+	}
+
+	taskState := &TaskState{
+		ID:           droplet.ID,
+		Name:         definition.Name,
+		Definition:   definition,
+		Status:       provider.TASK_STOPPED,
+		ProviderName: p.name,
+	}
+
+	p.stateMu.Lock()
+	defer p.stateMu.Unlock()
+
+	p.state.TaskStates[taskState.ID] = taskState
+
+	return &Task{
+		state:        taskState,
+		provider:     p,
+		sshKeyPair:   p.sshKeyPair,
+		logger:       p.logger.With(zap.String("task", definition.Name)),
+		doClient:     p.doClient,
+		dockerClient: p.dockerClients[ip],
+	}, nil
+}
+
+func (p *Provider) SerializeProvider(context.Context) ([]byte, error) {
+	p.stateMu.Lock()
+	defer p.stateMu.Unlock()
+
+	bz, err := json.Marshal(p.state)
+
+	return bz, err
+}
+
+func (p *Provider) DeserializeProvider(context.Context) ([]byte, error) {
+	p.stateMu.Lock()
+	defer p.stateMu.Unlock()
+
+	bz, err := json.Marshal(p.state)
+
+	return bz, err
+}
+
+func (p *Provider) SerializeTask(ctx context.Context, task provider.TaskI) ([]byte, error) {
+	if _, ok := task.(*Task); !ok {
+		return nil, fmt.Errorf("task is not a DigitalOcean task")
+	}
+
+	doTask := task.(*Task)
+
+	bz, err := json.Marshal(doTask.state)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return bz, nil
+}
+
+func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.TaskI, error) {
+	var taskState TaskState
+
+	err := json.Unmarshal(bz, &taskState)
+	if err != nil {
+		return nil, err
+	}
+
+	task := &Task{
+		state: &taskState,
+	}
+
+	return task, nil
+}
+
+func (p *Provider) Teardown(ctx context.Context) error {
+	p.logger.Info("tearing down DigitalOcean provider")
+
+	if err := p.teardownTasks(ctx); err != nil {
+		return err
+	}
+	if err := p.teardownFirewall(ctx); err != nil {
+		return err
+	}
+	if err := p.teardownSSHKey(ctx); err != nil {
+		return err
+	}
+	if err := p.teardownTag(ctx); err != nil {
+		return err
+	}
+	return nil
+}
+
+func (p *Provider) teardownTasks(ctx context.Context) error {
+	res, err := p.doClient.DeleteDropletByTag(ctx, p.petriTag)
+	if err != nil {
+		return err
+	}
+
+	if res.StatusCode > 299 || res.StatusCode < 200 {
+		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
+	}
+
+	return nil
+}
+
+func (p *Provider) teardownFirewall(ctx context.Context) error {
+	res, err := p.doClient.DeleteFirewall(ctx, p.firewallID)
+	if err != nil {
+		return err
+	}
+
+	if res.StatusCode > 299 || res.StatusCode < 200 {
+		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
+	}
+
+	return nil
+}
+
+func (p *Provider) teardownSSHKey(ctx context.Context) error {
+	res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.sshKeyPair.Fingerprint)
+	if err != nil {
+		return err
+	}
+
+	if res.StatusCode > 299 || res.StatusCode < 200 {
+		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
+	}
+
+	return nil
+}
+
+func (p *Provider) teardownTag(ctx context.Context) error {
+	res, err := p.doClient.DeleteTag(ctx, p.petriTag)
+	if err != nil {
+		return err
+	}
+
+	if res.StatusCode > 299 || res.StatusCode < 200 {
+		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
+	}
+
+	return nil
+}
+
+func (p *Provider) removeTask(_ context.Context, taskID int) error {
+	p.stateMu.Lock()
+	defer p.stateMu.Unlock()
+
+	delete(p.state.TaskStates, taskID)
+
+	return nil
+}
diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go
new file mode 100644
index 00000000..af84b079
--- /dev/null
+++ b/core/provider/digitalocean/provider_test.go
@@ -0,0 +1,552 @@
+package digitalocean
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/digitalocean/godo"
+	"github.com/docker/docker/api/types"
+	"github.com/docker/docker/api/types/container"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
+
+	"github.com/skip-mev/petri/core/v2/provider"
+	"github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks"
+	"github.com/skip-mev/petri/core/v2/util"
+)
+
+func setupMockClient(t *testing.T) *mocks.DoClient {
+	return mocks.NewDoClient(t)
+}
+
+func setupMockDockerClient(t *testing.T) *mocks.DockerClient {
+	return mocks.NewDockerClient(t)
+}
+
+func TestNewProvider(t *testing.T) {
+	logger := zap.NewExample()
+	ctx := context.Background()
+
+	tests := []struct {
+		name          string
+		token         string
+		additionalIPs []string
+		sshKeyPair    *SSHKeyPair
+		expectedError bool
+		mockSetup     func(*mocks.DoClient)
+	}{
+		{
+			name:          "valid provider creation",
+			token:         "test-token",
+			additionalIPs: []string{"1.2.3.4"},
+			sshKeyPair:    nil,
+			expectedError: false,
+			mockSetup: func(m *mocks.DoClient) {
+				m.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")).
+					Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+
+				m.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")).
+					Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+
+				m.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")).
+					Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil)
+				m.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")).
+					Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+			},
+		},
+		{
+			name:          "bad token",
+			token:         "foobar",
+			additionalIPs: []string{},
+			sshKeyPair:    nil,
+			expectedError: true,
+			mockSetup: func(m *mocks.DoClient) {
+				m.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")).
+					Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusUnauthorized}}, fmt.Errorf("unauthorized"))
+			},
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			mockClient := setupMockClient(t)
+			mockDocker := setupMockDockerClient(t)
+			tc.mockSetup(mockClient)
+			mockDockerClients := map[string]DockerClient{
+				"test-ip": mockDocker,
+			}
+
+			provider, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, tc.additionalIPs, tc.sshKeyPair)
+			if tc.expectedError {
+				assert.Error(t, err)
+				assert.Nil(t, provider)
+			} else {
+				assert.NoError(t, err)
+				assert.NotNil(t, provider)
+				assert.NotEmpty(t, provider.petriTag)
+				assert.NotEmpty(t, provider.firewallID)
+			}
+		})
+	}
+}
+
+func TestCreateTask(t *testing.T) {
+	logger := zap.NewExample()
+	ctx := context.Background()
+	mockClient := setupMockClient(t)
+	mockDocker := setupMockDockerClient(t)
+
+	mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil)
+	mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found"))
+	mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil)
+	mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: "test-container"}, nil)
+
+	mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")).
+ Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). + Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDockerClients := map[string]DockerClient{ + "10.0.0.1": mockDocker, + } + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + mockClient.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("GetDroplet", mock.Anything, mock.AnythingOfType("int")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + tests := []struct { + name string + taskDef provider.TaskDefinition + expectedError bool + }{ + { + name: "valid task creation", + taskDef: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + Entrypoint: []string{"/bin/bash"}, + Command: []string{"-c", "echo hello"}, + Environment: map[string]string{"TEST": "value"}, + DataDir: "/data", + ContainerName: "test-container", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", + }, + }, + expectedError: false, + }, + { + name: "missing provider specific config", + taskDef: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + ProviderSpecificConfig: nil, + }, + expectedError: true, + }, + { + name: "invalid provider specific config - missing region", + taskDef: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "image_id": "123456", + }, + }, + expectedError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + task, err := p.CreateTask(ctx, tc.taskDef) + if tc.expectedError { + assert.Error(t, err) + assert.Nil(t, task) + } else { + assert.NoError(t, err) + assert.NotNil(t, task) + } + }) + } +} + +func TestSerializeAndRestore(t *testing.T) { + logger := zap.NewExample() + ctx := context.Background() + mockClient := setupMockClient(t) + mockDocker := setupMockDockerClient(t) + mockDockerClients := map[string]DockerClient{ + "10.0.0.1": mockDocker, + } + mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil) + mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, 
[]byte{}, fmt.Errorf("image not found")) + mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil) + mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: "test-container"}, nil) + + mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). + Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("GetDroplet", mock.Anything, mock.AnythingOfType("int")). + Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + providerData, err := p.SerializeProvider(ctx) + assert.NoError(t, err) + assert.NotNil(t, providerData) + + taskDef := provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + Entrypoint: []string{"/bin/bash"}, + Command: []string{"-c", "echo hello"}, + Environment: map[string]string{"TEST": "value"}, + DataDir: "/data", + ContainerName: "test-container", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", + }, + } + + task, err := p.CreateTask(ctx, taskDef) + require.NoError(t, err) + + taskData, err := p.SerializeTask(ctx, task) + assert.NoError(t, err) + assert.NotNil(t, taskData) + + deserializedTask, err := p.DeserializeTask(ctx, taskData) + assert.NoError(t, err) + assert.NotNil(t, deserializedTask) +} + +func TestTeardown(t *testing.T) { + logger := zap.NewExample() + ctx := context.Background() + mockClient := setupMockClient(t) + mockDocker := setupMockDockerClient(t) + + mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). 
+ Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). + Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockClient.On("DeleteDropletByTag", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("DeleteFirewall", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("DeleteKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockClient.On("DeleteTag", mock.Anything, mock.AnythingOfType("string")). + Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, nil, []string{}, nil) + require.NoError(t, err) + + err = p.Teardown(ctx) + assert.NoError(t, err) + mockClient.AssertExpectations(t) + mockDocker.AssertExpectations(t) +} + +func TestConcurrentTaskCreationAndCleanup(t *testing.T) { + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + logger, _ := zap.NewDevelopment() + mockDockerClients := make(map[string]DockerClient) + mockDO := mocks.NewDoClient(t) + + for i := 0; i < 10; i++ { + ip := fmt.Sprintf("10.0.0.%d", i+1) + mockDocker := mocks.NewDockerClient(t) + mockDockerClients[ip] = mockDocker + + mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil).Once() + mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once() + mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil).Once() + mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), + mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), + mock.AnythingOfType("string")).Return(container.CreateResponse{ID: fmt.Sprintf("container-%d", i)}, nil).Once() + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{ + { + ID: fmt.Sprintf("container-%d", i), + State: "running", + }, + }, nil).Times(3) + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil).Maybe() + mockDocker.On("Close").Return(nil).Once() + } + + mockDO.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). + Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). 
+		Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+	mockDO.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")).
+		Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil)
+	mockDO.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")).
+		Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+
+	p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil)
+	require.NoError(t, err)
+
+	numTasks := 10
+	var wg sync.WaitGroup
+	errors := make(chan error, numTasks)
+	tasks := make(chan *Task, numTasks)
+	taskMutex := sync.Mutex{}
+	dropletIDs := make(map[int]bool)
+	ipAddresses := make(map[string]bool)
+
+	for i := 0; i < numTasks; i++ {
+		dropletID := 1000 + i
+		droplet := &godo.Droplet{
+			ID:     dropletID,
+			Status: "active",
+			Networks: &godo.Networks{
+				V4: []godo.NetworkV4{
+					{
+						Type:      "public",
+						IPAddress: fmt.Sprintf("10.0.0.%d", i+1),
+					},
+				},
+			},
+		}
+
+		mockDO.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusCreated}}, nil).Once()
+		// we can't predict exactly how many times GetDroplet will be called, since the provider polls it while waiting for the droplet to come up
+		mockDO.On("GetDroplet", mock.Anything, dropletID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe()
+		mockDO.On("DeleteDropletByID", mock.Anything, dropletID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusNoContent}}, nil).Once()
+	}
+
+	// these are called once per provider, not per task
+	mockDO.On("DeleteDropletByTag", mock.Anything, mock.AnythingOfType("string")).
+		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()
+	mockDO.On("DeleteFirewall", mock.Anything, mock.AnythingOfType("string")).
+		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()
+	mockDO.On("DeleteKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")).
+		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()
+	mockDO.On("DeleteTag", mock.Anything, mock.AnythingOfType("string")).
+		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()
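+
+	// each worker below creates one task, starts it, and records the droplet ID
+	// and public IP it was assigned so that duplicate allocations are caught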
+	for i := 0; i < numTasks; i++ {
+		wg.Add(1)
+		go func(index int) {
+			defer wg.Done()
+
+			task, err := p.CreateTask(ctx, provider.TaskDefinition{
+				Name:          fmt.Sprintf("test-task-%d", index),
+				ContainerName: fmt.Sprintf("test-container-%d", index),
+				Image: provider.ImageDefinition{
+					Image: "nginx:latest",
+					UID:   "1000",
+					GID:   "1000",
+				},
+				Ports: []string{"80"},
+				ProviderSpecificConfig: DigitalOceanTaskConfig{
+					"size":     "s-1vcpu-1gb",
+					"region":   "nyc1",
+					"image_id": "123456789", // Using a string for image_id
+				},
+			})
+
+			if err != nil {
+				errors <- fmt.Errorf("task creation error: %v", err)
+				return
+			}
+
+			if err := task.Start(ctx); err != nil {
+				errors <- fmt.Errorf("task start error: %v", err)
+				return
+			}
+
+			// thread-safe recording of task details
+			taskMutex.Lock()
+			doTask := task.(*Task)
+			state := doTask.GetState()
+
+			// check for duplicate droplet IDs or IP addresses
+			if dropletIDs[state.ID] {
+				errors <- fmt.Errorf("duplicate droplet ID found: %d", state.ID)
+			}
+			dropletIDs[state.ID] = true
+
+			ip, err := task.GetIP(ctx)
+			if err == nil {
+				if ipAddresses[ip] {
+					errors <- fmt.Errorf("duplicate IP address found: %s", ip)
+				}
+				ipAddresses[ip] = true
+			}
+
+			taskMutex.Unlock()
+
+			tasks <- doTask
+		}(i)
+	}
+
+	wg.Wait()
+	close(errors)
+
+	// Check for any task creation errors before proceeding
+	for err := range errors {
+		require.NoError(t, err)
+	}
+
+	require.Equal(t, numTasks, len(p.state.TaskStates), "Provider state should contain all tasks")
+
+	// Close the tasks channel, then drain it to validate each task and collect them for cleanup
+	var tasksToCleanup []*Task
+	close(tasks)
+	for task := range tasks {
+		status, err := task.GetStatus(ctx)
+		require.NoError(t, err)
+		require.Equal(t, provider.TASK_RUNNING, status, "All tasks should be in running state")
+
+		state := task.GetState()
+		require.NotEmpty(t, state.ID, "Task should have a droplet ID")
+		require.NotEmpty(t, state.Name, "Task should have a name")
+		require.NotEmpty(t, state.Definition.ContainerName, "Task should have a container name")
+		tasksToCleanup = append(tasksToCleanup, task)
+	}
+
+	// test cleanup
+	var cleanupWg sync.WaitGroup
+	cleanupErrors := make(chan error, numTasks)
+
+	for _, task := range tasksToCleanup {
+		cleanupWg.Add(1)
+		go func(t *Task) {
+			defer cleanupWg.Done()
+			if err := t.Destroy(ctx); err != nil {
+				cleanupErrors <- fmt.Errorf("cleanup error: %v", err)
+				return
+			}
+
+			// Wait for the task to be removed from the provider state
+			err := util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) {
+				taskMutex.Lock()
+				defer taskMutex.Unlock()
+				_, exists := p.state.TaskStates[t.GetState().ID]
+				return !exists, nil
+			})
+			if err != nil {
+				cleanupErrors <- fmt.Errorf("task state cleanup error: %v", err)
+			}
+		}(task)
+	}
+
+	cleanupWg.Wait()
+	close(cleanupErrors)
+
+	for err := range cleanupErrors {
+		require.NoError(t, err)
+	}
+
+	// Wait for provider state to be empty
+	err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) {
+		return len(p.state.TaskStates) == 0, nil
+	})
+	require.NoError(t, err, "Provider state should be empty after cleanup")
+
+	// Teardown the provider to clean up resources
+	err = p.Teardown(ctx)
+	require.NoError(t, err)
+
+	mockDO.AssertExpectations(t)
+	for _, client := range mockDockerClients {
+		client.(*mocks.DockerClient).AssertExpectations(t)
+	}
+}
diff --git a/core/provider/digitalocean/ssh.go
b/core/provider/digitalocean/ssh.go index dbcdc0a6..0b604dee 100644 --- a/core/provider/digitalocean/ssh.go +++ b/core/provider/digitalocean/ssh.go @@ -98,7 +98,7 @@ func getUserIPs(ctx context.Context) (ips []string, err error) { func (p *Provider) createSSHKey(ctx context.Context, pubKey string) (*godo.Key, error) { req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.petriTag)} - key, res, err := p.doClient.Keys.Create(ctx, req) + key, res, err := p.doClient.CreateKey(ctx, req) if err != nil { return nil, err } diff --git a/core/provider/digitalocean/tag.go b/core/provider/digitalocean/tag.go index db1e9fae..9d8654fb 100644 --- a/core/provider/digitalocean/tag.go +++ b/core/provider/digitalocean/tag.go @@ -12,7 +12,7 @@ func (p *Provider) createTag(ctx context.Context, tagName string) (*godo.Tag, er Name: tagName, } - tag, res, err := p.doClient.Tags.Create(ctx, req) + tag, res, err := p.doClient.CreateTag(ctx, req) if err != nil { return nil, err } diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index ac7eb39b..884a8610 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -7,13 +7,17 @@ import ( "io" "net" "path" + "sync" "time" + "golang.org/x/crypto/ssh" + "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/mount" dockerclient "github.com/docker/docker/client" "github.com/docker/docker/pkg/stdcopy" + "github.com/pkg/errors" "github.com/pkg/sftp" "github.com/spf13/afero" "github.com/spf13/afero/sftpfs" @@ -23,88 +27,30 @@ import ( "github.com/skip-mev/petri/core/v3/util" ) -func (p *Provider) CreateTask(ctx context.Context, logger *zap.Logger, definition provider.TaskDefinition) (string, error) { - if err := definition.ValidateBasic(); err != nil { - return "", fmt.Errorf("failed to validate task definition: %w", err) - } - - if definition.ProviderSpecificConfig == nil { - return "", fmt.Errorf("digitalocean specific config is nil for %s", definition.Name) - } - - doConfig := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig) - - if err := doConfig.ValidateBasic(); err != nil { - return "", fmt.Errorf("could not cast digitalocean specific config: %w", err) - } - - logger = logger.Named("digitalocean_provider") - - logger.Info("creating droplet", zap.String("name", definition.Name)) - - droplet, err := p.CreateDroplet(ctx, definition) - if err != nil { - return "", err - } - - ip, err := p.GetIP(ctx, droplet.Name) - if err != nil { - return "", err - } - - logger.Info("droplet created", zap.String("name", droplet.Name), zap.String("ip", ip)) - - dockerClient, err := p.getDropletDockerClient(ctx, droplet.Name) - defer dockerClient.Close() // nolint - - if err != nil { - return "", err - } - - _, _, err = dockerClient.ImageInspectWithRaw(ctx, definition.Image.Image) - if err != nil { - logger.Info("image not found, pulling", zap.String("image", definition.Image.Image)) - if err := p.pullImage(ctx, dockerClient, definition.Image.Image); err != nil { - return "", err - } - } +type TaskState struct { + ID int `json:"id"` + Name string `json:"name"` + Definition provider.TaskDefinition `json:"definition"` + Status provider.TaskStatus `json:"status"` + ProviderName string `json:"provider_name"` +} - _, err = dockerClient.ContainerCreate(ctx, &container.Config{ - Image: definition.Image.Image, - Entrypoint: definition.Entrypoint, - Cmd: definition.Command, - Tty: false, - Hostname: definition.Name, - Labels: 
map[string]string{ - providerLabelName: p.name, - }, - Env: convertEnvMapToList(definition.Environment), - }, &container.HostConfig{ - Mounts: []mount.Mount{ - { - Type: mount.TypeBind, - Source: "/docker_volumes", - Target: definition.DataDir, - }, - }, - NetworkMode: container.NetworkMode("host"), - }, nil, nil, definition.ContainerName) - if err != nil { - return "", err - } +type Task struct { + state *TaskState + stateMu sync.Mutex - return droplet.Name, nil + provider *Provider + logger *zap.Logger + sshKeyPair *SSHKeyPair + sshClient *ssh.Client + doClient DoClient + dockerClient DockerClient } -func (p *Provider) StartTask(ctx context.Context, taskName string) error { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return err - } - - defer dockerClient.Close() // nolint +var _ provider.TaskI = (*Task)(nil) - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) Start(ctx context.Context) error { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -113,23 +59,27 @@ func (p *Provider) StartTask(ctx context.Context, taskName string) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", taskName) + return fmt.Errorf("could not find container for %s", t.state.Name) } containerID := containers[0].ID - err = dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}) + err = t.dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}) if err != nil { return err } err = util.WaitForCondition(ctx, time.Second*300, time.Millisecond*100, func() (bool, error) { - status, err := p.GetTaskStatus(ctx, taskName) + status, err := t.GetStatus(ctx) if err != nil { return false, err } if status == provider.TASK_RUNNING { + t.stateMu.Lock() + defer t.stateMu.Unlock() + + t.state.Status = provider.TASK_RUNNING return true, nil } @@ -139,14 +89,8 @@ func (p *Provider) StartTask(ctx context.Context, taskName string) error { return err } -func (p *Provider) StopTask(ctx context.Context, taskName string) error { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return err - } - - defer dockerClient.Close() // nolint - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) Stop(ctx context.Context) error { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -155,30 +99,49 @@ func (p *Provider) StopTask(ctx context.Context, taskName string) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", taskName) + return fmt.Errorf("could not find container for %s", t.state.Name) } - return dockerClient.ContainerStop(ctx, containers[0].ID, container.StopOptions{}) + t.stateMu.Lock() + defer t.stateMu.Unlock() + + t.state.Status = provider.TASK_STOPPED + return t.dockerClient.ContainerStop(ctx, containers[0].ID, container.StopOptions{}) } -func (p *Provider) ModifyTask(ctx context.Context, taskName string, definition provider.TaskDefinition) error { - return nil +func (t *Task) Initialize(ctx context.Context) error { + panic("implement me") +} + +func (t *Task) Modify(ctx context.Context, definition provider.TaskDefinition) error { + panic("implement me") } -func (p *Provider) DestroyTask(ctx context.Context, taskName string) error { - logger := p.logger.With(zap.String("task", taskName)) +func (t *Task) Destroy(ctx context.Context) error { + logger := t.logger.With(zap.String("task", t.state.Name)) 
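+	// the droplet is torn down first; afterwards the task only needs to be
+	// removed from the provider's in-memory state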
logger.Info("deleting task") + defer t.dockerClient.Close() - err := p.deleteDroplet(ctx, taskName) + err := t.deleteDroplet(ctx) if err != nil { return err } + // TODO(nadim-az): remove reference to provider in Task struct + if err := t.provider.removeTask(ctx, t.state.ID); err != nil { + return err + } return nil } -func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider.TaskStatus, error) { - droplet, err := p.getDroplet(ctx, taskName) +func (t *Task) GetState() TaskState { + t.stateMu.Lock() + defer t.stateMu.Unlock() + return *t.state +} + +func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { + droplet, err := t.getDroplet(ctx) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } @@ -187,14 +150,7 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider return provider.TASK_STOPPED, nil } - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return provider.TASK_STATUS_UNDEFINED, err - } - - defer dockerClient.Close() - - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -203,23 +159,21 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider } if len(containers) != 1 { - return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", taskName) + return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", t.state.Name) } - container, err := dockerClient.ContainerInspect(ctx, containers[0].ID) + c, err := t.dockerClient.ContainerInspect(ctx, containers[0].ID) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } - switch state := container.State.Status; state { + switch state := c.State.Status; state { case "created": return provider.TASK_STOPPED, nil case "running": return provider.TASK_RUNNING, nil case "paused": return provider.TASK_PAUSED, nil - case "restarting": - return provider.TASK_RUNNING, nil // todo(zygimantass): is this sane? 
case "removing": return provider.TASK_STOPPED, nil case "exited": @@ -231,10 +185,10 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider return provider.TASK_STATUS_UNDEFINED, nil } -func (p *Provider) WriteFile(ctx context.Context, taskName string, relPath string, content []byte) error { +func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) error { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := p.getDropletSSHClient(ctx, taskName) + sshClient, err := t.getDropletSSHClient(ctx, t.state.Name) if err != nil { return err } @@ -266,10 +220,10 @@ func (p *Provider) WriteFile(ctx context.Context, taskName string, relPath strin return nil } -func (p *Provider) ReadFile(ctx context.Context, taskName string, relPath string) ([]byte, error) { +func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := p.getDropletSSHClient(ctx, taskName) + sshClient, err := t.getDropletSSHClient(ctx, t.state.Name) if err != nil { return nil, err } @@ -293,12 +247,12 @@ func (p *Provider) ReadFile(ctx context.Context, taskName string, relPath string return content, nil } -func (p *Provider) DownloadDir(ctx context.Context, s string, s2 string, s3 string) error { +func (t *Task) DownloadDir(ctx context.Context, s string, s2 string) error { panic("implement me") } -func (p *Provider) GetIP(ctx context.Context, taskName string) (string, error) { - droplet, err := p.getDroplet(ctx, taskName) +func (t *Task) GetIP(ctx context.Context) (string, error) { + droplet, err := t.getDroplet(ctx) if err != nil { return "", err @@ -307,8 +261,8 @@ func (p *Provider) GetIP(ctx context.Context, taskName string) (string, error) { return droplet.PublicIPv4() } -func (p *Provider) GetExternalAddress(ctx context.Context, taskName string, port string) (string, error) { - ip, err := p.GetIP(ctx, taskName) +func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) { + ip, err := t.GetIP(ctx) if err != nil { return "", err } @@ -316,14 +270,8 @@ func (p *Provider) GetExternalAddress(ctx context.Context, taskName string, port return net.JoinHostPort(ip, port), nil } -func (p *Provider) RunCommand(ctx context.Context, taskName string, command []string) (string, string, int, error) { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return "", "", 0, err - } - - defer dockerClient.Close() - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) RunCommand(ctx context.Context, command []string) (string, string, int, error) { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -332,14 +280,14 @@ func (p *Provider) RunCommand(ctx context.Context, taskName string, command []st } if len(containers) != 1 { - return "", "", 0, fmt.Errorf("could not find container for %s", taskName) + return "", "", 0, fmt.Errorf("could not find container for %s", t.state.Name) } id := containers[0].ID - p.logger.Debug("running command", zap.String("id", id), zap.Strings("command", command)) + t.logger.Debug("running command", zap.String("id", id), zap.Strings("command", command)) - exec, err := dockerClient.ContainerExecCreate(ctx, id, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, id, container.ExecOptions{ AttachStdout: true, AttachStderr: true, Cmd: command, @@ -348,7 +296,7 @@ func (p *Provider) RunCommand(ctx context.Context, taskName 
string, command []st return "", "", 0, err } - resp, err := dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) + resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { return "", "", 0, err } @@ -365,7 +313,7 @@ loop: case <-ctx.Done(): return "", "", lastExitCode, ctx.Err() case <-ticker.C: - execInspect, err := dockerClient.ContainerExecInspect(ctx, exec.ID) + execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID) if err != nil { return "", "", lastExitCode, err } @@ -390,77 +338,74 @@ loop: return stdout.String(), stderr.String(), lastExitCode, nil } -func (p *Provider) RunCommandWhileStopped(ctx context.Context, taskName string, definition provider.TaskDefinition, command []string) (string, string, int, error) { - if err := definition.ValidateBasic(); err != nil { +func (t *Task) RunCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { + if err := t.state.Definition.ValidateBasic(); err != nil { return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - p.logger.Error("failed to get docker client", zap.Error(err), zap.String("taskName", taskName)) - return "", "", 0, err - } + t.stateMu.Lock() + defer t.stateMu.Unlock() - definition.Entrypoint = []string{"sh", "-c"} - definition.Command = []string{"sleep 36000"} - definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", definition.Name, util.RandomString(5), time.Now().Unix()) - definition.Ports = []string{} + t.state.Definition.Entrypoint = []string{"sh", "-c"} + t.state.Definition.Command = []string{"sleep 36000"} + t.state.Definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", t.state.Definition.Name, util.RandomString(5), time.Now().Unix()) + t.state.Definition.Ports = []string{} - createdContainer, err := dockerClient.ContainerCreate(ctx, &container.Config{ - Image: definition.Image.Image, - Entrypoint: definition.Entrypoint, - Cmd: definition.Command, + createdContainer, err := t.dockerClient.ContainerCreate(ctx, &container.Config{ + Image: t.state.Definition.Image.Image, + Entrypoint: t.state.Definition.Entrypoint, + Cmd: t.state.Definition.Command, Tty: false, - Hostname: definition.Name, + Hostname: t.state.Definition.Name, Labels: map[string]string{ - providerLabelName: p.name, + providerLabelName: t.state.ProviderName, }, - Env: convertEnvMapToList(definition.Environment), + Env: convertEnvMapToList(t.state.Definition.Environment), }, &container.HostConfig{ Mounts: []mount.Mount{ { Type: mount.TypeBind, Source: "/docker_volumes", - Target: definition.DataDir, + Target: t.state.Definition.DataDir, }, }, NetworkMode: container.NetworkMode("host"), - }, nil, nil, definition.ContainerName) + }, nil, nil, t.state.Definition.ContainerName) if err != nil { - p.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", taskName)) + t.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", t.state.Name)) return "", "", 0, err } defer func() { - if _, err := dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { - // auto-removed, but not detected as autoremoved + if _, err := t.dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { + // container was auto-removed, no need to remove it again return } - if err := dockerClient.ContainerRemove(ctx, 
createdContainer.ID, container.RemoveOptions{Force: true}); err != nil {
-			p.logger.Error("failed to remove container", zap.Error(err), zap.String("taskName", taskName), zap.String("id", createdContainer.ID))
+		if err := t.dockerClient.ContainerRemove(ctx, createdContainer.ID, container.RemoveOptions{Force: true}); err != nil {
+			t.logger.Error("failed to remove container", zap.Error(err), zap.String("taskName", t.state.Name), zap.String("id", createdContainer.ID))
 		}
 	}()
 
-	if err := startContainerWithBlock(ctx, dockerClient, createdContainer.ID); err != nil {
-		p.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", taskName))
+	if err := startContainerWithBlock(ctx, t.dockerClient, createdContainer.ID); err != nil {
+		t.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", t.state.Name))
 		return "", "", 0, err
 	}
 
 	// wait for container start
-	exec, err := dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{
+	exec, err := t.dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{
 		AttachStdout: true,
 		AttachStderr: true,
-		Cmd:          command,
+		Cmd:          cmd,
 	})
 	if err != nil {
-		p.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", taskName))
+		t.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", t.state.Name))
 		return "", "", 0, err
 	}
 
-	resp, err := dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{})
+	resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{})
 	if err != nil {
-		p.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", taskName))
+		t.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", t.state.Name))
 		return "", "", 0, err
 	}
@@ -476,7 +421,7 @@ loop:
 		case <-ctx.Done():
 			return "", "", lastExitCode, ctx.Err()
 		case <-ticker.C:
-			execInspect, err := dockerClient.ContainerExecInspect(ctx, exec.ID)
+			execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID)
 			if err != nil {
 				return "", "", lastExitCode, err
 			}
@@ -500,7 +445,7 @@
 	return stdout.String(), stderr.String(), lastExitCode, err
 }
 
-func startContainerWithBlock(ctx context.Context, dockerClient *dockerclient.Client, containerID string) error {
+func startContainerWithBlock(ctx context.Context, dockerClient DockerClient, containerID string) error {
 	// start container
 	if err := dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil {
 		return err
@@ -532,18 +477,18 @@
 	}
 }
 
-func (p *Provider) pullImage(ctx context.Context, dockerClient *dockerclient.Client, img string) error {
-	p.logger.Info("pulling image", zap.String("image", img))
+func pullImage(ctx context.Context, dockerClient DockerClient, logger *zap.Logger, img string) error {
+	logger.Info("pulling image", zap.String("image", img))
 	resp, err := dockerClient.ImagePull(ctx, img, image.PullOptions{})
 	if err != nil {
-		return err
+		return errors.Wrap(err, "failed to pull docker image")
 	}
 	defer resp.Close()
 
 	// throw away the image pull stdout response
 	_, err = io.Copy(io.Discard, resp)
 	if err != nil {
-		return err
+		return errors.Wrap(err, "failed to read image pull response")
 	}
 	return nil
 }
diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go
new file mode 100644
index 00000000..9d0b94fc
--- /dev/null
+++ b/core/provider/digitalocean/task_test.go
@@ -0,0 +1,828
@@ +package digitalocean + +import ( + "bufio" + "bytes" + "context" + "fmt" + "net" + "net/http" + "testing" + "time" + + "github.com/digitalocean/godo" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +// mockConn implements net.Conn interface for testing +type mockConn struct { + *bytes.Buffer +} + +// mockNotFoundError implements errdefs.ErrNotFound interface for testing +type mockNotFoundError struct { + error +} + +func (e mockNotFoundError) NotFound() {} + +func (m mockConn) Close() error { return nil } +func (m mockConn) LocalAddr() net.Addr { return nil } +func (m mockConn) RemoteAddr() net.Addr { return nil } +func (m mockConn) SetDeadline(t time.Time) error { return nil } +func (m mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } + +func TestTaskLifecycle(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: "1.2.3.4", + }, + }, + }, + } + + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ + ID: "test-container-id", + } + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockDocker.On("ContainerStop", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + }, + Status: provider.TASK_STOPPED, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + err := task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + err = task.Stop(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_STOPPED, task.GetState().Status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestTaskRunCommand(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + testContainer := types.Container{ + ID: "test-testContainer-id", + } + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{testContainer}, nil) + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} 
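+	// an empty buffer behind the hijacked connection means the exec's output
+	// stream hits EOF immediately, so the command appears to produce no output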
+ mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +func TestTaskRunCommandWhileStopped(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + createResp := container.CreateResponse{ID: "test-container-id"} + mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: false, + }, + }, + }, nil).Once() + + mockDocker.On("ContainerRemove", mock.Anything, mock.Anything, container.RemoveOptions{Force: true}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +func TestTaskGetIP(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + expectedIP := "1.2.3.4" + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: expectedIP, + }, + }, + }, + } + + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: 
http.StatusOK}}, nil) + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + ip, err := task.GetIP(ctx) + require.NoError(t, err) + require.Equal(t, expectedIP, ip) + + externalAddr, err := task.GetExternalAddress(ctx, "80") + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%s:80", expectedIP), externalAddr) + + mockDO.AssertExpectations(t) +} + +func TestTaskDestroy(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("DeleteDropletByID", mock.Anything, droplet.ID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("Close").Return(nil) + + provider := &Provider{ + state: &ProviderState{ + TaskStates: make(map[int]*TaskState), + }, + } + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + provider: provider, + } + + provider.state.TaskStates[task.state.ID] = task.state + + err := task.Destroy(ctx) + require.NoError(t, err) + require.Empty(t, provider.state.TaskStates) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + createResp := container.CreateResponse{ID: "test-container-id"} + mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // first ContainerInspect for startup check + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + // second ContainerInspect for cleanup check - container exists + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{}, nil).Once() + + // container should be removed since it exists + mockDocker.On("ContainerRemove", mock.Anything, mock.Anything, container.RemoveOptions{Force: true}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: 
"test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +// this tests the case where the docker container is auto removed before cleanup doesnt return an error +func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + createResp := container.CreateResponse{ID: "test-container-id"} + mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + // first ContainerInspect for startup check + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + // second ContainerInspect for cleanup check - container not found, so ContainerRemove should not be called + mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{}, mockNotFoundError{fmt.Errorf("Error: No such container: test-container-id")}).Once() + mockDocker.AssertNotCalled(t, "ContainerRemove", mock.Anything, mock.Anything, mock.Anything) + + task := &Task{ + state: &TaskState{ + ID: 1, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) +} + +func TestTaskExposingPort(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: "1.2.3.4", + }, + }, + }, + } + + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ + ID: "test-container-id", + Ports: []types.Port{ + { + PrivatePort: 80, + PublicPort: 80, + Type: "tcp", + }, + }, + } + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", 
mock.Anything, mock.Anything).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: droplet.ID, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + Ports: []string{"80"}, + }, + Status: provider.TASK_STOPPED, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + err := task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + + externalAddr, err := task.GetExternalAddress(ctx, "80") + require.NoError(t, err) + require.Equal(t, "1.2.3.4:80", externalAddr) + + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s", externalAddr), nil) + require.NoError(t, err) + require.NotEmpty(t, req) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestGetStatus(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + tests := []struct { + name string + dropletStatus string + containerState string + setupMocks func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) + expectedStatus provider.TaskStatus + expectError bool + }{ + { + name: "droplet not active", + dropletStatus: "off", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "off", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container running", + dropletStatus: "active", + containerState: "running", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_RUNNING, + expectError: false, + }, + { + name: "container paused", + dropletStatus: "active", + containerState: "paused", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "paused", + 
}, + }, + }, nil) + }, + expectedStatus: provider.TASK_PAUSED, + expectError: false, + }, + { + name: "container stopped state", + dropletStatus: "active", + containerState: "exited", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "exited", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "no containers found", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{}, nil) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "container inspect error", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{}, fmt.Errorf("inspect error")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "container removing", + dropletStatus: "active", + containerState: "removing", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "removing", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container dead", + dropletStatus: "active", + containerState: "dead", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + 
mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "dead", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container created", + dropletStatus: "active", + containerState: "created", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "created", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "unknown container status", + dropletStatus: "active", + containerState: "unknown_status", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + container := types.Container{ID: "test-container-id"} + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) + mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "unknown_status", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: false, + }, + { + name: "getDroplet error", + dropletStatus: "", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", mock.Anything, mock.Anything).Return(nil, nil, fmt.Errorf("failed to get droplet")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "containerList error", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + } + mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed to list containers")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockDocker := mocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + tt.setupMocks(mockDocker, mockDO) + + task := &Task{ + state: &TaskState{ + ID: 123, + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + status, err := task.GetStatus(ctx) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.expectedStatus, status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) + }) + } +} From 
56d319913ea02cbcd4a25d4c80ca5ea566cf04b8 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Mon, 13 Jan 2025 18:27:22 +0200 Subject: [PATCH 03/50] cleanup --- core/provider/digitalocean/droplet.go | 6 +- .../examples/digitalocean_cosmos_hub.go | 73 +++++++++++++++++++ core/provider/digitalocean/provider.go | 4 +- core/provider/digitalocean/task.go | 4 +- core/provider/provider.go | 1 + 5 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 core/provider/digitalocean/examples/digitalocean_cosmos_hub.go diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index b9c9423d..010ba384 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -74,12 +74,12 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe } if d.Status != "active" { - return false, nil + return false, errors.Errorf("droplet failed to be active after 10 minutes. Current status: %s", d.Status) } ip, err := d.PublicIPv4() if err != nil { - return false, nil + return false, err } if p.dockerClients[ip] == nil { @@ -93,7 +93,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe _, err = p.dockerClients[ip].Ping(ctx) if err != nil { - return false, nil + return false, errors.Errorf("failed to ping docker client at: %s", ip) } p.logger.Info("droplet is active", zap.Duration("after", time.Since(start)), zap.String("task", definition.Name)) diff --git a/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go b/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go new file mode 100644 index 00000000..ac3a5b85 --- /dev/null +++ b/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go @@ -0,0 +1,73 @@ +package main + +import ( + "context" + "os" + + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/provider/digitalocean" + "go.uber.org/zap" +) + +func main() { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + doToken := os.Getenv("DO_API_TOKEN") + if doToken == "" { + logger.Fatal("DO_API_TOKEN environment variable not set") + } + + imageID := os.Getenv("DO_IMAGE_ID") + if imageID == "" { + logger.Fatal("DO_IMAGE_ID environment variable not set") + } + + sshKeyPair, err := digitalocean.MakeSSHKeyPair() + if err != nil { + logger.Fatal("failed to create SSH key pair", zap.Error(err)) + } + + doProvider, err := digitalocean.NewProvider(ctx, logger, "cosmos-hub", doToken, []string{}, sshKeyPair) + if err != nil { + logger.Fatal("failed to create DigitalOcean provider", zap.Error(err)) + } + + taskDef := provider.TaskDefinition{ + Name: "cosmos-hub-node", + ContainerName: "cosmos-hub-node", + Image: provider.ImageDefinition{ + Image: "ghcr.io/cosmos/gaia:v21.0.1", + UID: "1000", + GID: "1000", + }, + Ports: []string{"26656", "26657", "26660", "1317", "9090"}, + Environment: map[string]string{}, + DataDir: "/root/.gaia", + Entrypoint: []string{"/bin/sh", "-c"}, + Command: []string{ + "start", + }, + ProviderSpecificConfig: digitalocean.DigitalOceanTaskConfig{ + "size": "s-2vcpu-4gb", + "region": "ams3", + "image_id": imageID, + }, + } + + logger.Info("Creating new task.") + + task, err := doProvider.CreateTask(ctx, taskDef) + if err != nil { + logger.Fatal("failed to create task", zap.Error(err)) + } + + logger.Info("Successfully created task. 
Starting task now: ") + + err = task.Start(ctx) + if err != nil { + logger.Fatal("failed to start task", zap.Error(err)) + } + + logger.Info("Task is successfully running!") +} diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 713f465e..04f59d76 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -222,9 +222,9 @@ func (p *Provider) SerializeTask(ctx context.Context, task provider.TaskI) ([]by return nil, fmt.Errorf("task is not a Docker task") } - dockerTask := task.(*Task) + doTask := task.(*Task) - bz, err := json.Marshal(dockerTask.state) + bz, err := json.Marshal(doTask.state) if err != nil { return nil, err diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index 884a8610..509eb621 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -83,7 +83,7 @@ func (t *Task) Start(ctx context.Context) error { return true, nil } - return false, nil + return false, errors.New("task not running after 5 minutes") }) return err @@ -176,6 +176,8 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { return provider.TASK_PAUSED, nil case "removing": return provider.TASK_STOPPED, nil + case "restarting": + return provider.TASK_RESTARTING, nil case "exited": return provider.TASK_STOPPED, nil case "dead": diff --git a/core/provider/provider.go b/core/provider/provider.go index 64213dfb..9292be95 100644 --- a/core/provider/provider.go +++ b/core/provider/provider.go @@ -15,6 +15,7 @@ const ( TASK_RUNNING TASK_STOPPED TASK_PAUSED + TASK_RESTARTING ) // Task is a stateful object that holds the underlying workload's details and tracks the workload's lifecycle From 31b3e3f778ea159e89b273744fe50ccae2821be4 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Mon, 13 Jan 2025 21:22:26 +0200 Subject: [PATCH 04/50] fixes --- core/provider/digitalocean/droplet.go | 2 +- core/provider/digitalocean/provider.go | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 010ba384..ffefb977 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -74,7 +74,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe } if d.Status != "active" { - return false, errors.Errorf("droplet failed to be active after 10 minutes. Current status: %s", d.Status) + return false, nil } ip, err := d.PublicIPv4() diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 04f59d76..ad791464 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -67,6 +67,10 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName userIPs = append(userIPs, additionalUserIPS...) 
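// The nil check added in the hunk below matters because, in Go, writing to a
// nil map panics at runtime while reads merely return zero values, so the
// provider must allocate the map before later writes into p.dockerClients.
// A minimal sketch of the failure mode being guarded against (the `clients`
// name is illustrative, not from this codebase):
//
//	var clients map[string]DockerClient     // nil map: reads are safe
//	_, ok := clients["10.0.0.1"]            // ok == false, no panic
//	clients["10.0.0.1"] = nil               // panic: assignment to entry in nil map
//
//	clients = make(map[string]DockerClient) // allocate before the first write
//	clients["10.0.0.1"] = nil               // fine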
+ if dockerClients == nil { + dockerClients = make(map[string]DockerClient) + } + digitalOceanProvider := &Provider{ logger: logger.Named("digitalocean_provider"), name: providerName, From 1ef6607e8ee0c9a4dedc4b46355d8851d1afd589 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 14 Jan 2025 00:13:19 +0200 Subject: [PATCH 05/50] mo fixes mo problems --- core/provider/digitalocean/droplet.go | 2 +- core/provider/digitalocean/task.go | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index ffefb977..5e69bbd0 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -93,7 +93,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe _, err = p.dockerClients[ip].Ping(ctx) if err != nil { - return false, errors.Errorf("failed to ping docker client at: %s", ip) + return false, nil } p.logger.Info("droplet is active", zap.Duration("after", time.Since(start)), zap.String("task", definition.Name)) diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index 509eb621..4b15cba9 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -83,9 +83,10 @@ func (t *Task) Start(ctx context.Context) error { return true, nil } - return false, errors.New("task not running after 5 minutes") + return false, nil }) + t.logger.Info("Final task status after start", zap.Any("status", t.state.Status)) return err } From 1d29ffb99f9f03a67a5b5c1ce5dc3289532b289c Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 14 Jan 2025 01:50:24 +0200 Subject: [PATCH 06/50] more cleanup after rebase --- core/provider/digitalocean/task.go | 4 ++++ core/provider/provider.go | 1 + core/types/chain.go | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index 4b15cba9..f8de5bd4 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -141,6 +141,10 @@ func (t *Task) GetState() TaskState { return *t.state } +func (t *Task) GetDefinition() provider.TaskDefinition { + return t.GetState().Definition +} + func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { droplet, err := t.getDroplet(ctx) if err != nil { diff --git a/core/provider/provider.go b/core/provider/provider.go index 9292be95..f7c02fdb 100644 --- a/core/provider/provider.go +++ b/core/provider/provider.go @@ -50,6 +50,7 @@ type TaskI interface { GetExternalAddress(context.Context, string) (string, error) RunCommand(context.Context, []string) (string, string, int, error) + RunCommandWhileStopped(context.Context, []string) (string, string, int, error) } type ProviderI interface { diff --git a/core/types/chain.go b/core/types/chain.go index 7d2afed6..490ad095 100644 --- a/core/types/chain.go +++ b/core/types/chain.go @@ -100,7 +100,7 @@ func (c ChainConfig) GetGenesisDelegation() *big.Int { return c.GenesisDelegation } -func (c *ChainConfig) ValidateBasic() error { +func (c ChainConfig) ValidateBasic() error { if c.Denom == "" { return fmt.Errorf("denom cannot be empty") } From 10e6b3420e1ba4d5b036baba1dc99d2ada109b05 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 14 Jan 2025 01:51:46 +0200 Subject: [PATCH 07/50] more cleanup after rebase --- core/provider/digitalocean/droplet.go | 2 +- 
.../examples/digitalocean_cosmos_hub.go | 59 ++++++++++++------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 5e69bbd0..e6d8d416 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -106,7 +106,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe end := time.Now() - p.logger.Info("droplet %s is ready after %s", zap.String("name", droplet.Name), zap.Duration("took", end.Sub(start))) + p.logger.Info(fmt.Sprintf("droplet %s is ready after %s", droplet.Name, end.Sub(start))) return droplet, nil } diff --git a/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go b/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go index ac3a5b85..ca8a17e1 100644 --- a/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go +++ b/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go @@ -4,6 +4,10 @@ import ( "context" "os" + petritypes "github.com/skip-mev/petri/core/v2/types" + "github.com/skip-mev/petri/cosmos/v2/chain" + "github.com/skip-mev/petri/cosmos/v2/node" + "github.com/skip-mev/petri/core/v2/provider" "github.com/skip-mev/petri/core/v2/provider/digitalocean" "go.uber.org/zap" @@ -33,41 +37,52 @@ func main() { logger.Fatal("failed to create DigitalOcean provider", zap.Error(err)) } - taskDef := provider.TaskDefinition{ - Name: "cosmos-hub-node", - ContainerName: "cosmos-hub-node", + chainConfig := petritypes.ChainConfig{ + ChainId: "cosmoshub-4", + NumValidators: 1, + NumNodes: 1, + BinaryName: "gaiad", + Denom: "uatom", + Decimals: 6, + GasPrices: "0.0025uatom", Image: provider.ImageDefinition{ Image: "ghcr.io/cosmos/gaia:v21.0.1", UID: "1000", GID: "1000", }, - Ports: []string{"26656", "26657", "26660", "1317", "9090"}, - Environment: map[string]string{}, - DataDir: "/root/.gaia", - Entrypoint: []string{"/bin/sh", "-c"}, - Command: []string{ - "start", - }, - ProviderSpecificConfig: digitalocean.DigitalOceanTaskConfig{ - "size": "s-2vcpu-4gb", - "region": "ams3", - "image_id": imageID, + HomeDir: "/root/.gaia", + Bech32Prefix: "cosmos", + CoinType: "118", + UseGenesisSubCommand: true, + NodeCreator: node.CreateNode, + NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig petritypes.NodeConfig) provider.TaskDefinition { + doConfig := digitalocean.DigitalOceanTaskConfig{ + "size": "s-2vcpu-4gb", + "region": "ams3", + "image_id": imageID, + } + def.ProviderSpecificConfig = doConfig + return def }, } - logger.Info("Creating new task.") - - task, err := doProvider.CreateTask(ctx, taskDef) + logger.Info("Creating chain") + cosmosChain, err := chain.CreateChain(ctx, logger, doProvider, chainConfig) if err != nil { - logger.Fatal("failed to create task", zap.Error(err)) + logger.Fatal("failed to create chain", zap.Error(err)) } - logger.Info("Successfully created task. 
Starting task now: ") + logger.Info("Initializing chain") + err = cosmosChain.Init(ctx) + if err != nil { + logger.Fatal("failed to initialize chain", zap.Error(err)) + } - err = task.Start(ctx) + logger.Info("Waiting for chain to produce blocks") + err = cosmosChain.WaitForBlocks(ctx, 1) if err != nil { - logger.Fatal("failed to start task", zap.Error(err)) + logger.Fatal("failed waiting for blocks", zap.Error(err)) } - logger.Info("Task is successfully running!") + logger.Info("Chain is successfully running!") } From ea45999f2450ca381099c0d9692d1db29ba64d6d Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 14 Jan 2025 23:39:41 +0200 Subject: [PATCH 08/50] genesis functions should use RunCommandWhileStopped --- ...n_cosmos_hub.go => digitalocean_simapp.go} | 44 +++++++++++++------ core/provider/digitalocean/task.go | 4 ++ cosmos/node/genesis.go | 6 +-- cosmos/node/init.go | 1 + 4 files changed, 38 insertions(+), 17 deletions(-) rename core/provider/digitalocean/examples/{digitalocean_cosmos_hub.go => digitalocean_simapp.go} (61%) diff --git a/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go b/core/provider/digitalocean/examples/digitalocean_simapp.go similarity index 61% rename from core/provider/digitalocean/examples/digitalocean_cosmos_hub.go rename to core/provider/digitalocean/examples/digitalocean_simapp.go index ca8a17e1..a35b2c5e 100644 --- a/core/provider/digitalocean/examples/digitalocean_cosmos_hub.go +++ b/core/provider/digitalocean/examples/digitalocean_simapp.go @@ -4,6 +4,8 @@ import ( "context" "os" + "github.com/cosmos/cosmos-sdk/crypto/hd" + "github.com/skip-mev/petri/core/v2/types" petritypes "github.com/skip-mev/petri/core/v2/types" "github.com/skip-mev/petri/cosmos/v2/chain" "github.com/skip-mev/petri/cosmos/v2/node" @@ -32,29 +34,29 @@ func main() { logger.Fatal("failed to create SSH key pair", zap.Error(err)) } - doProvider, err := digitalocean.NewProvider(ctx, logger, "cosmos-hub", doToken, []string{}, sshKeyPair) + // Add your IP address below + doProvider, err := digitalocean.NewProvider(ctx, logger, "cosmos-hub", doToken, []string{"INSERT IP ADDRESS HERE"}, sshKeyPair) if err != nil { logger.Fatal("failed to create DigitalOcean provider", zap.Error(err)) } chainConfig := petritypes.ChainConfig{ - ChainId: "cosmoshub-4", + Denom: "stake", + Decimals: 6, NumValidators: 1, NumNodes: 1, - BinaryName: "gaiad", - Denom: "uatom", - Decimals: 6, - GasPrices: "0.0025uatom", + BinaryName: "/usr/bin/simd", Image: provider.ImageDefinition{ - Image: "ghcr.io/cosmos/gaia:v21.0.1", + Image: "interchainio/simapp:latest", UID: "1000", GID: "1000", }, - HomeDir: "/root/.gaia", - Bech32Prefix: "cosmos", - CoinType: "118", - UseGenesisSubCommand: true, - NodeCreator: node.CreateNode, + GasPrices: "0.0005stake", + Bech32Prefix: "cosmos", + HomeDir: "/gaia", + CoinType: "118", + ChainId: "stake-1", + NodeCreator: node.CreateNode, NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig petritypes.NodeConfig) provider.TaskDefinition { doConfig := digitalocean.DigitalOceanTaskConfig{ "size": "s-2vcpu-4gb", @@ -64,6 +66,13 @@ func main() { def.ProviderSpecificConfig = doConfig return def }, + WalletConfig: types.WalletConfig{ + SigningAlgorithm: string(hd.Secp256k1.Name()), + Bech32Prefix: "cosmos", + HDPath: hd.CreateHDPath(118, 0, 0), + DerivationFn: hd.Secp256k1.Derive(), + GenerationFn: hd.Secp256k1.Generate(), + }, } logger.Info("Creating chain") @@ -78,11 +87,18 @@ func main() { logger.Fatal("failed to initialize 
chain", zap.Error(err)) } - logger.Info("Waiting for chain to produce blocks") + logger.Info("Chain is successfully running! Waiting for chain to produce blocks") err = cosmosChain.WaitForBlocks(ctx, 1) if err != nil { logger.Fatal("failed waiting for blocks", zap.Error(err)) } - logger.Info("Chain is successfully running!") + // Comment out section below if you want to persist your Digital Ocean resources + logger.Info("Chain has successfully produced required number of blocks. Tearing down Digital Ocean resources.") + err = doProvider.Teardown(ctx) + if err != nil { + logger.Fatal("failed to teardown provider", zap.Error(err)) + } + + logger.Info("All Digital Ocean resources created have been successfully deleted!") } diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index f8de5bd4..e609d90f 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -383,6 +383,8 @@ func (t *Task) RunCommandWhileStopped(ctx context.Context, cmd []string) (string return "", "", 0, err } + t.logger.Debug("container created successfully", zap.String("id", createdContainer.ID), zap.String("taskName", t.state.Name)) + defer func() { if _, err := t.dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { // container was auto-removed, no need to remove it again @@ -399,6 +401,8 @@ func (t *Task) RunCommandWhileStopped(ctx context.Context, cmd []string) (string return "", "", 0, err } + t.logger.Debug("container started successfully", zap.String("id", createdContainer.ID), zap.String("taskName", t.state.Name)) + // wait for container start exec, err := t.dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{ AttachStdout: true, diff --git a/cosmos/node/genesis.go b/cosmos/node/genesis.go index d6de7aee..7adf3f34 100644 --- a/cosmos/node/genesis.go +++ b/cosmos/node/genesis.go @@ -71,7 +71,7 @@ func (n *Node) AddGenesisAccount(ctx context.Context, address string, genesisAmo command = append(command, "add-genesis-account", address, amount) command = n.BinCommand(command...) - stdout, stderr, exitCode, err := n.RunCommand(ctx, command) + stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) n.logger.Debug("add-genesis-account", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -101,7 +101,7 @@ func (n *Node) GenerateGenTx(ctx context.Context, genesisSelfDelegation types.Co command = n.BinCommand(command...) 
- stdout, stderr, exitCode, err := n.RunCommand(ctx, command) + stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) n.logger.Debug("gentx", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -127,7 +127,7 @@ func (n *Node) CollectGenTxs(ctx context.Context) error { command = append(command, "collect-gentxs") - stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand(command...)) + stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, n.BinCommand(command...)) n.logger.Debug("collect-gentxs", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { diff --git a/cosmos/node/init.go b/cosmos/node/init.go index cd5d9970..92675e29 100644 --- a/cosmos/node/init.go +++ b/cosmos/node/init.go @@ -12,6 +12,7 @@ func (n *Node) InitHome(ctx context.Context) error { n.logger.Info("initializing home", zap.String("name", n.GetDefinition().Name)) stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand([]string{"init", n.GetDefinition().Name, "--chain-id", n.GetChainConfig().ChainId}...)) + n.logger.Debug("init home", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { From d8dd5034881e6c4f6e660f34faf4ca39b4f135ce Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Wed, 15 Jan 2025 15:22:45 +0200 Subject: [PATCH 09/50] runcommand should determine task state and how to run container accordingly --- core/provider/digitalocean/task.go | 17 +++++++++++++++-- core/provider/provider.go | 1 - cosmos/node/genesis.go | 6 +++--- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index e609d90f..24f3eddc 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -277,7 +277,20 @@ func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, err return net.JoinHostPort(ip, port), nil } -func (t *Task) RunCommand(ctx context.Context, command []string) (string, string, int, error) { +func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, int, error) { + status, err := t.GetStatus(ctx) + if err != nil { + return "", "", 0, err + } + + if status != provider.TASK_RUNNING { + return t.runCommandWhileStopped(ctx, cmd) + } + + return t.runCommand(ctx, cmd) +} + +func (t *Task) runCommand(ctx context.Context, command []string) (string, string, int, error) { containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -345,7 +358,7 @@ loop: return stdout.String(), stderr.String(), lastExitCode, nil } -func (t *Task) RunCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { +func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { if err := t.state.Definition.ValidateBasic(); err != nil { return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } diff --git a/core/provider/provider.go b/core/provider/provider.go index f7c02fdb..9292be95 100644 --- a/core/provider/provider.go +++ b/core/provider/provider.go @@ -50,7 +50,6 @@ type TaskI interface { GetExternalAddress(context.Context, string) (string, error) RunCommand(context.Context, []string) (string, string, int, error) - RunCommandWhileStopped(context.Context, []string) (string, string, int, error) } type ProviderI interface { diff --git a/cosmos/node/genesis.go 
b/cosmos/node/genesis.go index 7adf3f34..d6de7aee 100644 --- a/cosmos/node/genesis.go +++ b/cosmos/node/genesis.go @@ -71,7 +71,7 @@ func (n *Node) AddGenesisAccount(ctx context.Context, address string, genesisAmo command = append(command, "add-genesis-account", address, amount) command = n.BinCommand(command...) - stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) + stdout, stderr, exitCode, err := n.RunCommand(ctx, command) n.logger.Debug("add-genesis-account", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -101,7 +101,7 @@ func (n *Node) GenerateGenTx(ctx context.Context, genesisSelfDelegation types.Co command = n.BinCommand(command...) - stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) + stdout, stderr, exitCode, err := n.RunCommand(ctx, command) n.logger.Debug("gentx", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -127,7 +127,7 @@ func (n *Node) CollectGenTxs(ctx context.Context) error { command = append(command, "collect-gentxs") - stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, n.BinCommand(command...)) + stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand(command...)) n.logger.Debug("collect-gentxs", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { From ee404573139b9d08b0574b159c25d96706a262ab Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Wed, 15 Jan 2025 15:31:14 +0200 Subject: [PATCH 10/50] harden task tests --- core/provider/digitalocean/task_test.go | 458 +++++++++++++++--------- 1 file changed, 288 insertions(+), 170 deletions(-) diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go index 9d0b94fc..73b39d19 100644 --- a/core/provider/digitalocean/task_test.go +++ b/core/provider/digitalocean/task_test.go @@ -13,6 +13,9 @@ import ( "github.com/digitalocean/godo" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + specs "github.com/opencontainers/image-spec/specs-go/v1" "github.com/skip-mev/petri/core/v2/provider" "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" "github.com/stretchr/testify/mock" @@ -39,6 +42,15 @@ func (m mockConn) SetDeadline(t time.Time) error { return nil } func (m mockConn) SetReadDeadline(t time.Time) error { return nil } func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } +const ( + testContainerID = "test-container-id" +) + +var ( + testContainer = types.Container{ID: testContainerID} + testDroplet = &godo.Droplet{ID: 123, Status: "active"} +) + func TestTaskLifecycle(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() @@ -59,21 +71,22 @@ func TestTaskLifecycle(t *testing.T) { }, } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - container := types.Container{ - ID: "test-container-id", - } - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + mockDocker.On("ContainerList", ctx, 
container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "running", }, }, }, nil) - mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) - mockDocker.On("ContainerStop", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + mockDocker.On("ContainerStop", ctx, testContainerID, container.StopOptions{}).Return(nil) task := &Task{ state: &TaskState{ @@ -118,20 +131,34 @@ func TestTaskRunCommand(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - testContainer := types.Container{ - ID: "test-testContainer-id", - } - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{testContainer}, nil) + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - execCreateResp := types.IDResponse{ID: "test-exec-id"} - mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + execID := "test-exec-id" + execCreateResp := types.IDResponse{ID: execID} + mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ + AttachStdout: true, + AttachStderr: true, + Cmd: []string{"echo", "hello"}, + }).Return(execCreateResp, nil) conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} - mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + mockDocker.On("ContainerExecAttach", ctx, execID, container.ExecAttachOptions{}).Return(types.HijackedResponse{ Conn: conn, Reader: bufio.NewReader(conn), }, nil) - mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + mockDocker.On("ContainerExecInspect", ctx, execID).Return(container.ExecInspect{ ExitCode: 0, Running: false, }, nil) @@ -161,6 +188,7 @@ func TestTaskRunCommand(t *testing.T) { require.Empty(t, stderr) mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) } func TestTaskRunCommandWhileStopped(t *testing.T) { @@ -170,32 +198,60 @@ func TestTaskRunCommandWhileStopped(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - createResp := container.CreateResponse{ID: "test-container-id"} - mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) - mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) - - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + createResp := container.CreateResponse{ID: testContainerID} + mockDocker.On("ContainerCreate", ctx, &container.Config{ + Image: "nginx:latest", + Entrypoint: []string{"sh", "-c"}, + Cmd: []string{"sleep 36000"}, + Tty: false, + Hostname: "test-task", + Labels: map[string]string{ + providerLabelName: "test-provider", + }, + Env: []string{}, + }, 
&container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: "", + }, + }, + NetworkMode: container.NetworkMode("host"), + }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Running: true, }, }, - }, nil).Once() + }, nil).Twice() + + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) execCreateResp := types.IDResponse{ID: "test-exec-id"} - mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ + AttachStdout: true, + AttachStderr: true, + Cmd: []string{"echo", "hello"}, + }).Return(execCreateResp, nil) conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} - mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + mockDocker.On("ContainerExecAttach", ctx, "test-exec-id", container.ExecAttachOptions{}).Return(types.HijackedResponse{ Conn: conn, Reader: bufio.NewReader(conn), }, nil) - mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + mockDocker.On("ContainerExecInspect", ctx, "test-exec-id").Return(container.ExecInspect{ ExitCode: 0, Running: false, }, nil) - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Running: false, @@ -203,7 +259,7 @@ func TestTaskRunCommandWhileStopped(t *testing.T) { }, }, nil).Once() - mockDocker.On("ContainerRemove", mock.Anything, mock.Anything, container.RemoveOptions{Force: true}).Return(nil) + mockDocker.On("ContainerRemove", ctx, testContainerID, container.RemoveOptions{Force: true}).Return(nil) task := &Task{ state: &TaskState{ @@ -225,7 +281,7 @@ func TestTaskRunCommandWhileStopped(t *testing.T) { doClient: mockDO, } - _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) require.NoError(t, err) require.Equal(t, 0, exitCode) require.Empty(t, stderr) @@ -285,13 +341,8 @@ func TestTaskDestroy(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("DeleteDropletByID", mock.Anything, droplet.ID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, testDroplet.ID).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("DeleteDropletByID", ctx, testDroplet.ID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) mockDocker.On("Close").Return(nil) provider := &Provider{ @@ -302,7 
+353,7 @@ func TestTaskDestroy(t *testing.T) { task := &Task{ state: &TaskState{ - ID: droplet.ID, + ID: testDroplet.ID, Name: "test-task", ProviderName: "test-provider", }, @@ -329,12 +380,46 @@ func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - createResp := container.CreateResponse{ID: "test-container-id"} - mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil) - mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{ + {ID: testContainerID}, + }, nil).Once() + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "exited", + }, + }, + }, nil).Once() + + createResp := container.CreateResponse{ID: testContainerID} + mockDocker.On("ContainerCreate", ctx, &container.Config{ + Image: "nginx:latest", + Entrypoint: []string{"sh", "-c"}, + Cmd: []string{"sleep 36000"}, + Tty: false, + Hostname: "test-task", + Labels: map[string]string{ + providerLabelName: "test-provider", + }, + Env: []string{}, + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: "", + }, + }, + NetworkMode: container.NetworkMode("host"), + }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) // first ContainerInspect for startup check - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Running: true, @@ -343,23 +428,27 @@ func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { }, nil).Once() execCreateResp := types.IDResponse{ID: "test-exec-id"} - mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil) + mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ + AttachStdout: true, + AttachStderr: true, + Cmd: []string{"echo", "hello"}, + }).Return(execCreateResp, nil) conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} - mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{ + mockDocker.On("ContainerExecAttach", ctx, "test-exec-id", container.ExecAttachOptions{}).Return(types.HijackedResponse{ Conn: conn, Reader: bufio.NewReader(conn), }, nil) - mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + mockDocker.On("ContainerExecInspect", ctx, "test-exec-id").Return(container.ExecInspect{ ExitCode: 0, Running: false, }, nil) // second ContainerInspect for cleanup check - container exists - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{}, nil).Once() + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{}, nil).Once() // container should be removed since it exists - mockDocker.On("ContainerRemove", 
mock.Anything, mock.Anything, container.RemoveOptions{Force: true}).Return(nil)
+	mockDocker.On("ContainerRemove", ctx, testContainerID, container.RemoveOptions{Force: true}).Return(nil)
 
 	task := &Task{
 		state: &TaskState{
@@ -381,12 +470,13 @@ func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) {
 		doClient:     mockDO,
 	}
 
-	_, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"})
+	_, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"})
 	require.NoError(t, err)
 	require.Equal(t, 0, exitCode)
 	require.Empty(t, stderr)
 
 	mockDocker.AssertExpectations(t)
+	mockDO.AssertExpectations(t)
 }
 
 // this tests that RunCommand doesn't return an error when the docker container is auto-removed before cleanup
@@ -397,12 +487,46 @@ func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) {
 	mockDocker := mocks.NewDockerClient(t)
 	mockDO := mocks.NewDoClient(t)
 
-	createResp := container.CreateResponse{ID: "test-container-id"}
-	mockDocker.On("ContainerCreate", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(createResp, nil)
-	mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil)
+	mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+
+	mockDocker.On("ContainerList", ctx, container.ListOptions{
+		Limit: 1,
+	}).Return([]types.Container{
+		{ID: testContainerID},
+	}, nil).Once()
+	mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{
+		ContainerJSONBase: &types.ContainerJSONBase{
+			State: &types.ContainerState{
+				Status: "exited",
+			},
+		},
+	}, nil).Once()
+
+	createResp := container.CreateResponse{ID: testContainerID}
+	mockDocker.On("ContainerCreate", ctx, &container.Config{
+		Image:      "nginx:latest",
+		Entrypoint: []string{"sh", "-c"},
+		Cmd:        []string{"sleep 36000"},
+		Tty:        false,
+		Hostname:   "test-task",
+		Labels: map[string]string{
+			providerLabelName: "test-provider",
+		},
+		Env: []string{},
+	}, &container.HostConfig{
+		Mounts: []mount.Mount{
+			{
+				Type:   mount.TypeBind,
+				Source: "/docker_volumes",
+				Target: "",
+			},
+		},
+		NetworkMode: container.NetworkMode("host"),
+	}, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.Anything).Return(createResp, nil)
+	mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil)
 
 	// first ContainerInspect for startup check
-	mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{
+	mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{
 		ContainerJSONBase: &types.ContainerJSONBase{
 			State: &types.ContainerState{
 				Running: true,
@@ -411,21 +535,25 @@ func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) {
 	}, nil).Once()
 
 	execCreateResp := types.IDResponse{ID: "test-exec-id"}
-	mockDocker.On("ContainerExecCreate", mock.Anything, mock.Anything, mock.Anything).Return(execCreateResp, nil)
+	mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{
+		AttachStdout: true,
+		AttachStderr: true,
+		Cmd:          []string{"echo", "hello"},
+	}).Return(execCreateResp, nil)
 
 	conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})}
-	mockDocker.On("ContainerExecAttach", mock.Anything, mock.Anything, mock.Anything).Return(types.HijackedResponse{
+	mockDocker.On("ContainerExecAttach", ctx, "test-exec-id", container.ExecAttachOptions{}).Return(types.HijackedResponse{
 		Conn:   conn,
 		Reader: bufio.NewReader(conn),
 	}, nil)
-	
mockDocker.On("ContainerExecInspect", mock.Anything, mock.Anything).Return(container.ExecInspect{ + mockDocker.On("ContainerExecInspect", ctx, "test-exec-id").Return(container.ExecInspect{ ExitCode: 0, Running: false, }, nil) // second ContainerInspect for cleanup check - container not found, so ContainerRemove should not be called - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{}, mockNotFoundError{fmt.Errorf("Error: No such container: test-container-id")}).Once() - mockDocker.AssertNotCalled(t, "ContainerRemove", mock.Anything, mock.Anything, mock.Anything) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{}, mockNotFoundError{fmt.Errorf("Error: No such container: test-container-id")}).Once() + mockDocker.AssertNotCalled(t, "ContainerRemove", ctx, testContainerID, mock.Anything) task := &Task{ state: &TaskState{ @@ -447,12 +575,13 @@ func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) { doClient: mockDO, } - _, stderr, exitCode, err := task.RunCommandWhileStopped(ctx, []string{"echo", "hello"}) + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) require.NoError(t, err) require.Equal(t, 0, exitCode) require.Empty(t, stderr) mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) } func TestTaskExposingPort(t *testing.T) { @@ -475,10 +604,10 @@ func TestTaskExposingPort(t *testing.T) { }, } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - container := types.Container{ - ID: "test-container-id", + testContainer := types.Container{ + ID: testContainerID, Ports: []types.Port{ { PrivatePort: 80, @@ -487,15 +616,19 @@ func TestTaskExposingPort(t *testing.T) { }, }, } - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "running", }, }, }, nil) - mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) task := &Task{ state: &TaskState{ @@ -541,6 +674,14 @@ func TestTaskExposingPort(t *testing.T) { func TestGetStatus(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() + testDropletActive := &godo.Droplet{ + ID: 123, + Status: "active", + } + testDropletOff := &godo.Droplet{ + ID: 123, + Status: "off", + } tests := []struct { name string @@ -555,11 +696,8 @@ func TestGetStatus(t *testing.T) { dropletStatus: "off", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "off", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDO.On("GetDroplet", ctx, testDropletOff.ID).Return(testDropletOff, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) }, 
expectedStatus: provider.TASK_STOPPED, expectError: false, @@ -569,15 +707,13 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "running", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "running", @@ -593,15 +729,12 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "paused", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "paused", @@ -617,15 +750,13 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "exited", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "exited", @@ -636,53 +767,18 @@ func TestGetStatus(t *testing.T) { expectedStatus: provider.TASK_STOPPED, expectError: false, }, - { - name: "no containers found", - dropletStatus: "active", 
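// The hardening pattern throughout this hunk replaces loose mock.Anything
// matchers with the exact arguments the code under test is expected to pass,
// so a regression in those arguments now fails the test instead of silently
// matching. A minimal testify sketch of the difference (the mock variable m
// is illustrative, not from this codebase):
//
//	// loose: matches ContainerList regardless of the arguments it receives
//	m.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{}, nil)
//
//	// strict: only matches the expected ctx and container.ListOptions{Limit: 1}
//	m.On("ContainerList", ctx, container.ListOptions{Limit: 1}).Return([]types.Container{}, nil)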
- containerState: "", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{}, nil) - }, - expectedStatus: provider.TASK_STATUS_UNDEFINED, - expectError: true, - }, - { - name: "container inspect error", - dropletStatus: "active", - containerState: "", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{}, fmt.Errorf("inspect error")) - }, - expectedStatus: provider.TASK_STATUS_UNDEFINED, - expectError: true, - }, { name: "container removing", dropletStatus: "active", containerState: "removing", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "removing", @@ -698,15 +794,13 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "dead", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "dead", @@ -722,15 +816,13 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "created", setupMocks: func(mockDocker 
*mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "created", @@ -746,15 +838,13 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "unknown_status", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - container := types.Container{ID: "test-container-id"} - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{container}, nil) - mockDocker.On("ContainerInspect", mock.Anything, container.ID).Return(types.ContainerJSON{ + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "unknown_status", @@ -765,12 +855,41 @@ func TestGetStatus(t *testing.T) { expectedStatus: provider.TASK_STATUS_UNDEFINED, expectError: false, }, + { + name: "no containers found", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{}, nil) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "container inspect error", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{}, fmt.Errorf("inspect error")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, { name: "getDroplet error", dropletStatus: "", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - mockDO.On("GetDroplet", mock.Anything, 
mock.Anything).Return(nil, nil, fmt.Errorf("failed to get droplet")) + mockDO.On("GetDroplet", ctx, 123).Return(nil, nil, fmt.Errorf("failed to get droplet")) }, expectedStatus: provider.TASK_STATUS_UNDEFINED, expectError: true, @@ -780,12 +899,11 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - droplet := &godo.Droplet{ - ID: 123, - Status: "active", - } - mockDO.On("GetDroplet", mock.Anything, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return(nil, fmt.Errorf("failed to list containers")) + + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return(nil, fmt.Errorf("failed to list containers")) }, expectedStatus: provider.TASK_STATUS_UNDEFINED, expectError: true, From 46ee0b2d71fa8cb331cdbaf7ae2a93e656322800 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Wed, 15 Jan 2025 22:40:33 +0200 Subject: [PATCH 11/50] harden provider tests --- core/provider/digitalocean/provider_test.go | 421 +++++++------------- 1 file changed, 154 insertions(+), 267 deletions(-) diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index af84b079..7731aae5 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -13,6 +13,10 @@ import ( "github.com/digitalocean/godo" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/image" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + specs "github.com/opencontainers/image-spec/specs-go/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -23,108 +27,47 @@ import ( "github.com/skip-mev/petri/core/v2/util" ) -func setupMockClient(t *testing.T) *mocks.DoClient { - return mocks.NewDoClient(t) -} - -func setupMockDockerClient(t *testing.T) *mocks.DockerClient { - return mocks.NewDockerClient(t) -} - -func TestNewProvider(t *testing.T) { +func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoClient, *mocks.DockerClient) { logger := zap.NewExample() - ctx := context.Background() - - tests := []struct { - name string - token string - additionalIPs []string - sshKeyPair *SSHKeyPair - expectedError bool - mockSetup func(*mocks.DoClient) - }{ - { - name: "valid provider creation", - token: "test-token", - additionalIPs: []string{"1.2.3.4"}, - sshKeyPair: nil, - expectedError: false, - mockSetup: func(m *mocks.DoClient) { - m.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). - Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - m.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). - Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - m.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). 
- Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - m.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). - Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - }, + mockDO := mocks.NewDoClient(t) + mockDocker := mocks.NewDockerClient(t) + + mockDocker.On("Ping", ctx).Return(types.Ping{}, nil) + mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) + mockDocker.On("ImagePull", ctx, "ubuntu:latest", image.PullOptions{}).Return(io.NopCloser(strings.NewReader("")), nil) + mockDocker.On("ContainerCreate", ctx, &container.Config{ + Image: "ubuntu:latest", + Entrypoint: []string{"/bin/bash"}, + Cmd: []string{"-c", "echo hello"}, + Env: []string{"TEST=value"}, + Hostname: "test-task", + Labels: map[string]string{ + providerLabelName: "test-provider", }, - { - name: "bad token", - token: "foobar", - additionalIPs: []string{}, - sshKeyPair: nil, - expectedError: true, - mockSetup: func(m *mocks.DoClient) { - m.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). - Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusUnauthorized}}, fmt.Errorf("unauthorized")) + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: "/data", }, }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - mockClient := setupMockClient(t) - mockDocker := setupMockDockerClient(t) - tc.mockSetup(mockClient) - mockDockerClients := map[string]DockerClient{ - "test-ip": mockDocker, - } + NetworkMode: container.NetworkMode("host"), + }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), "test-container").Return(container.CreateResponse{ID: "test-container"}, nil) - provider, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, tc.additionalIPs, tc.sshKeyPair) - if tc.expectedError { - assert.Error(t, err) - assert.Nil(t, provider) - } else { - assert.NoError(t, err) - assert.NotNil(t, provider) - assert.NotEmpty(t, provider.petriTag) - assert.NotEmpty(t, provider.firewallID) - } - }) - } -} + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) -func TestCreateTask(t *testing.T) { - logger := zap.NewExample() - ctx := context.Background() - mockClient := setupMockClient(t) - mockDocker := setupMockDockerClient(t) - - mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil) - mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) - mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil) - mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), 
mock.AnythingOfType("*container.HostConfig"), mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: "test-container"}, nil) - - mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). - Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). - Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). - Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). - Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) mockDockerClients := map[string]DockerClient{ "10.0.0.1": mockDocker, } - p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, []string{}, nil) + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil) require.NoError(t, err) - mockClient.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")). + mockDO.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")). Return(&godo.Droplet{ ID: 123, Name: "test-droplet", @@ -139,135 +82,104 @@ func TestCreateTask(t *testing.T) { }, }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("GetDroplet", mock.Anything, mock.AnythingOfType("int")). 
- Return(&godo.Droplet{ - ID: 123, - Name: "test-droplet", - Status: "active", - Networks: &godo.Networks{ - V4: []godo.NetworkV4{ - { - IPAddress: "10.0.0.1", - Type: "public", - }, - }, - }, - }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - tests := []struct { - name string - taskDef provider.TaskDefinition - expectedError bool - }{ - { - name: "valid task creation", - taskDef: provider.TaskDefinition{ - Name: "test-task", - Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, - Entrypoint: []string{"/bin/bash"}, - Command: []string{"-c", "echo hello"}, - Environment: map[string]string{"TEST": "value"}, - DataDir: "/data", - ContainerName: "test-container", - ProviderSpecificConfig: DigitalOceanTaskConfig{ - "size": "s-1vcpu-1gb", - "region": "nyc1", - "image_id": "123456", + mockDO.On("GetDroplet", ctx, 123).Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", }, }, - expectedError: false, }, - { - name: "missing provider specific config", - taskDef: provider.TaskDefinition{ - Name: "test-task", - Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, - ProviderSpecificConfig: nil, - }, - expectedError: true, - }, - { - name: "invalid provider specific config - missing region", - taskDef: provider.TaskDefinition{ - Name: "test-task", - Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, - ProviderSpecificConfig: DigitalOceanTaskConfig{ - "size": "s-1vcpu-1gb", - "image_id": "123456", - }, - }, - expectedError: true, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + + return p, mockDO, mockDocker +} + +func TestCreateTask_ValidTask(t *testing.T) { + ctx := context.Background() + p, _, _ := setupTestProvider(t, ctx) + + taskDef := provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + Entrypoint: []string{"/bin/bash"}, + Command: []string{"-c", "echo hello"}, + Environment: map[string]string{"TEST": "value"}, + DataDir: "/data", + ContainerName: "test-container", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", }, } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - task, err := p.CreateTask(ctx, tc.taskDef) - if tc.expectedError { - assert.Error(t, err) - assert.Nil(t, task) - } else { - assert.NoError(t, err) - assert.NotNil(t, task) - } - }) - } + task, err := p.CreateTask(ctx, taskDef) + assert.NoError(t, err) + assert.Equal(t, task.GetDefinition(), taskDef) + assert.NotNil(t, task) } -func TestSerializeAndRestore(t *testing.T) { +func setupValidationTestProvider(t *testing.T, ctx context.Context) *Provider { logger := zap.NewExample() - ctx := context.Background() - mockClient := setupMockClient(t) - mockDocker := setupMockDockerClient(t) + mockDO := mocks.NewDoClient(t) + mockDocker := mocks.NewDockerClient(t) + + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, &godo.Response{Response: 
&http.Response{StatusCode: http.StatusNotFound}}, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDockerClients := map[string]DockerClient{ "10.0.0.1": mockDocker, } - mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil) - mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) - mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil) - mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: "test-container"}, nil) - - mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). - Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). - Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). - Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). - Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")). - Return(&godo.Droplet{ - ID: 123, - Name: "test-droplet", - Status: "active", - Networks: &godo.Networks{ - V4: []godo.NetworkV4{ - { - IPAddress: "10.0.0.1", - Type: "public", - }, - }, - }, - }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil) + require.NoError(t, err) - mockClient.On("GetDroplet", mock.Anything, mock.AnythingOfType("int")). 
- Return(&godo.Droplet{ - ID: 123, - Name: "test-droplet", - Status: "active", - Networks: &godo.Networks{ - V4: []godo.NetworkV4{ - { - IPAddress: "10.0.0.1", - Type: "public", - }, - }, - }, - }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + return p +} - p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, mockDockerClients, []string{}, nil) - require.NoError(t, err) +func TestCreateTask_MissingProviderConfig(t *testing.T) { + ctx := context.Background() + p := setupValidationTestProvider(t, ctx) + + taskDef := provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + ProviderSpecificConfig: nil, + } + + task, err := p.CreateTask(ctx, taskDef) + assert.Error(t, err) + assert.Nil(t, task) +} + +func TestCreateTask_MissingRegion(t *testing.T) { + ctx := context.Background() + p := setupValidationTestProvider(t, ctx) + + taskDef := provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "image_id": "123456", + }, + } + + task, err := p.CreateTask(ctx, taskDef) + assert.Error(t, err) + assert.Nil(t, task) +} + +func TestSerializeAndRestore(t *testing.T) { + ctx := context.Background() + p, mockDO, mockDocker := setupTestProvider(t, ctx) providerData, err := p.SerializeProvider(ctx) assert.NoError(t, err) @@ -298,38 +210,21 @@ func TestSerializeAndRestore(t *testing.T) { deserializedTask, err := p.DeserializeTask(ctx, taskData) assert.NoError(t, err) assert.NotNil(t, deserializedTask) -} -func TestTeardown(t *testing.T) { - logger := zap.NewExample() - ctx := context.Background() - mockClient := setupMockClient(t) - mockDocker := setupMockDockerClient(t) - - mockClient.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). - Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). - Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). - Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - mockClient.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")). - Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - mockClient.On("DeleteDropletByTag", mock.Anything, mock.AnythingOfType("string")). - Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("DeleteFirewall", mock.Anything, mock.AnythingOfType("string")). - Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("DeleteKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")). - Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockClient.On("DeleteTag", mock.Anything, mock.AnythingOfType("string")). 
- Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - - p, err := NewProviderWithClient(ctx, logger, "test-provider", mockClient, nil, []string{}, nil) - require.NoError(t, err) + t1 := task.(*Task) + t2 := deserializedTask.(*Task) - err = p.Teardown(ctx) - assert.NoError(t, err) - mockClient.AssertExpectations(t) + if configMap, ok := t2.state.Definition.ProviderSpecificConfig.(map[string]interface{}); ok { + doConfig := make(DigitalOceanTaskConfig) + for k, v := range configMap { + doConfig[k] = v.(string) + } + t2.state.Definition.ProviderSpecificConfig = doConfig + } + + assert.Equal(t, t1.state, t2.state) + + mockDO.AssertExpectations(t) mockDocker.AssertExpectations(t) } @@ -347,20 +242,22 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDockerClients[ip] = mockDocker - mockDocker.On("Ping", mock.Anything).Return(types.Ping{}, nil).Once() - mockDocker.On("ImageInspectWithRaw", mock.Anything, mock.AnythingOfType("string")).Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once() - mockDocker.On("ImagePull", mock.Anything, mock.AnythingOfType("string"), mock.AnythingOfType("image.PullOptions")).Return(io.NopCloser(strings.NewReader("")), nil).Once() - mockDocker.On("ContainerCreate", mock.Anything, mock.AnythingOfType("*container.Config"), mock.AnythingOfType("*container.HostConfig"), - mock.AnythingOfType("*network.NetworkingConfig"), mock.AnythingOfType("*v1.Platform"), - mock.AnythingOfType("string")).Return(container.CreateResponse{ID: fmt.Sprintf("container-%d", i)}, nil).Once() - mockDocker.On("ContainerStart", mock.Anything, mock.Anything, mock.Anything).Return(nil).Once() - mockDocker.On("ContainerList", mock.Anything, mock.Anything).Return([]types.Container{ + mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Once() + mockDocker.On("ImageInspectWithRaw", ctx, "nginx:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once() + mockDocker.On("ImagePull", ctx, "nginx:latest", image.PullOptions{}).Return(io.NopCloser(strings.NewReader("")), nil).Once() + mockDocker.On("ContainerCreate", ctx, mock.MatchedBy(func(config *container.Config) bool { + return config.Image == "nginx:latest" + }), mock.Anything, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: fmt.Sprintf("container-%d", i)}, nil).Once() + mockDocker.On("ContainerStart", ctx, fmt.Sprintf("container-%d", i), container.StartOptions{}).Return(nil).Once() + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{ { ID: fmt.Sprintf("container-%d", i), State: "running", }, }, nil).Times(3) - mockDocker.On("ContainerInspect", mock.Anything, mock.Anything).Return(types.ContainerJSON{ + mockDocker.On("ContainerInspect", ctx, fmt.Sprintf("container-%d", i)).Return(types.ContainerJSON{ ContainerJSONBase: &types.ContainerJSONBase{ State: &types.ContainerState{ Status: "running", @@ -370,14 +267,13 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { mockDocker.On("Close").Return(nil).Once() } - mockDO.On("CreateTag", mock.Anything, mock.AnythingOfType("*godo.TagCreateRequest")). - Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("CreateFirewall", mock.Anything, mock.AnythingOfType("*godo.FirewallRequest")). 
-		Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
-	mockDO.On("GetKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")).
+	mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+
+	mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+
+	mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).
 		Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil)
-	mockDO.On("CreateKey", mock.Anything, mock.AnythingOfType("*godo.KeyCreateRequest")).
-		Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+	mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)

 	p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil)
 	require.NoError(t, err)
@@ -405,20 +301,19 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 			},
 		}

-		mockDO.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusCreated}}, nil).Once()
+		mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusCreated}}, nil).Once()
 		// we can't predict how many times GetDroplet will be called exactly, as the provider polls while waiting for its creation
-		mockDO.On("GetDroplet", mock.Anything, dropletID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe()
-		mockDO.On("DeleteDropletByID", mock.Anything, dropletID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusNoContent}}, nil).Once()
+		mockDO.On("GetDroplet", ctx, dropletID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe()
+		mockDO.On("DeleteDropletByID", ctx, dropletID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusNoContent}}, nil).Once()
 	}

-	// these are called once per provider, not per task
-	mockDO.On("DeleteDropletByTag", mock.Anything, mock.AnythingOfType("string")).
+	mockDO.On("DeleteDropletByTag", ctx, mock.AnythingOfType("string")).
 		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()
-	mockDO.On("DeleteFirewall", mock.Anything, mock.AnythingOfType("string")).
+	mockDO.On("DeleteFirewall", ctx, mock.AnythingOfType("string")).
 		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()
-	mockDO.On("DeleteKeyByFingerprint", mock.Anything, mock.AnythingOfType("string")).
+	mockDO.On("DeleteKeyByFingerprint", ctx, mock.AnythingOfType("string")).
 		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()
-	mockDO.On("DeleteTag", mock.Anything, mock.AnythingOfType("string")).
+	mockDO.On("DeleteTag", ctx, mock.AnythingOfType("string")).
 		Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once()

 	for i := 0; i < numTasks; i++ {
@@ -438,7 +333,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 			ProviderSpecificConfig: DigitalOceanTaskConfig{
 				"size":     "s-1vcpu-1gb",
 				"region":   "nyc1",
-				"image_id": "123456789", // Using a string for image_id
+				"image_id": "123456789",
 			},
 		})
@@ -452,12 +347,10 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 				return
 			}

-			// thread-safe recording of task details
 			taskMutex.Lock()
 			doTask := task.(*Task)
 			state := doTask.GetState()

-			// check for duplicate droplet IDs or IP addresses
 			if dropletIDs[state.ID] {
 				errors <- fmt.Errorf("duplicate droplet ID found: %d", state.ID)
 			}
@@ -480,14 +373,12 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 	wg.Wait()
 	close(errors)

-	// Check for any task creation errors before proceeding
 	for err := range errors {
 		require.NoError(t, err)
 	}

 	require.Equal(t, numTasks, len(p.state.TaskStates), "Provider state should contain all tasks")

-	// Collect all tasks in a slice before closing the channel
 	var tasksToCleanup []*Task
 	close(tasks)
 	for task := range tasks {
@@ -502,7 +393,6 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 		tasksToCleanup = append(tasksToCleanup, task)
 	}

-	// test cleanup
 	var cleanupWg sync.WaitGroup
 	cleanupErrors := make(chan error, numTasks)
@@ -515,7 +405,6 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 				return
 			}

-			// Wait for the task to be removed from the provider state
 			err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) {
 				taskMutex.Lock()
 				defer taskMutex.Unlock()
@@ -535,13 +424,11 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 		require.NoError(t, err)
 	}

-	// Wait for provider state to be empty
 	err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) {
 		return len(p.state.TaskStates) == 0, nil
 	})
 	require.NoError(t, err, "Provider state should be empty after cleanup")

-	// Teardown the provider to clean up resources
 	err = p.Teardown(ctx)
 	require.NoError(t, err)
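The hardened expectations in the patch above lean on three testify call-count modifiers: Once() pins an expectation to exactly one call, Times(n) to n calls, and Maybe() permits any number of calls, including zero, which is the right fit for the GetDroplet polling loop whose iteration count is nondeterministic. A self-contained sketch of the pattern, assuming the mockery-generated mocks and the v2 module path used by these tests:

    // Sketch only: testify call-count expectations with a mockery-generated mock.
    package digitalocean_test

    import (
    	"context"
    	"net/http"
    	"testing"

    	"github.com/digitalocean/godo"
    	"github.com/stretchr/testify/mock"
    	"github.com/stretchr/testify/require"

    	"github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks"
    )

    func TestCallCountSketch(t *testing.T) {
    	ctx := context.Background()
    	// NewDoClient registers AssertExpectations via t.Cleanup, so an unmet
    	// Once() expectation fails the test automatically.
    	mockDO := mocks.NewDoClient(t)

    	ok := &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}
    	droplet := &godo.Droplet{ID: 123, Status: "active"}

    	// Exactly one creation call is allowed; a second call would fail the test.
    	mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, ok, nil).Once()
    	// The polling call count is unknowable up front; Maybe() accepts zero or
    	// more calls without tripping AssertExpectations.
    	mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, ok, nil).Maybe()

    	created, _, err := mockDO.CreateDroplet(ctx, &godo.DropletCreateRequest{Name: "sketch"})
    	require.NoError(t, err)

    	// Simulated polling: one call here, but zero or ten would also pass.
    	polled, _, err := mockDO.GetDroplet(ctx, created.ID)
    	require.NoError(t, err)
    	require.Equal(t, "active", polled.Status)
    }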
From c948b2c703d4d9f841294a62bf88c6e331cf8fef Mon Sep 17 00:00:00 2001
From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com>
Date: Thu, 16 Jan 2025 15:37:48 +0200
Subject: [PATCH 12/50] fix

---
 core/provider/digitalocean/provider_test.go | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go
index 7731aae5..1ce42aa7 100644
--- a/core/provider/digitalocean/provider_test.go
+++ b/core/provider/digitalocean/provider_test.go
@@ -364,9 +364,8 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 				ipAddresses[ip] = true
 			}

-			taskMutex.Unlock()
-
 			tasks <- doTask
+			taskMutex.Unlock()
 		}(i)
 	}

@@ -404,13 +403,6 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 				cleanupErrors <- fmt.Errorf("cleanup error: %v", err)
 				return
 			}
-
-			err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) {
-				taskMutex.Lock()
-				defer taskMutex.Unlock()
-				_, exists := p.state.TaskStates[t.GetState().ID]
-				return !exists, nil
-			})
 			if err != nil {
 				cleanupErrors <- fmt.Errorf("task state cleanup error: %v", err)
 			}

From b9acaca6c9e3bbf3074d22f301a60d9ae55ca3a3 Mon Sep 17 00:00:00 2001
From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com>
Date: Thu, 16 Jan 2025 21:37:48 +0200
Subject: [PATCH 13/50] move serializable fields to state

---
 core/provider/digitalocean/droplet.go | 6 ++--
core/provider/digitalocean/firewall.go | 8 +++--- core/provider/digitalocean/provider.go | 39 +++++++++++++------------- core/provider/digitalocean/ssh.go | 2 +- 4 files changed, 28 insertions(+), 27 deletions(-) diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index e6d8d416..c2f4a9a9 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -42,7 +42,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe } req := &godo.DropletCreateRequest{ - Name: fmt.Sprintf("%s-%s", p.petriTag, definition.Name), + Name: fmt.Sprintf("%s-%s", p.state.petriTag, definition.Name), Region: doConfig["region"], Size: doConfig["size"], Image: godo.DropletCreateImage{ @@ -50,10 +50,10 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe }, SSHKeys: []godo.DropletCreateSSHKey{ { - Fingerprint: p.sshKeyPair.Fingerprint, + Fingerprint: p.state.sshKeyPair.Fingerprint, }, }, - Tags: []string{p.petriTag}, + Tags: []string{p.state.petriTag}, } droplet, res, err := p.doClient.CreateDroplet(ctx, req) diff --git a/core/provider/digitalocean/firewall.go b/core/provider/digitalocean/firewall.go index f8c73d3c..1ccbb035 100644 --- a/core/provider/digitalocean/firewall.go +++ b/core/provider/digitalocean/firewall.go @@ -9,8 +9,8 @@ import ( func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*godo.Firewall, error) { req := &godo.FirewallRequest{ - Name: fmt.Sprintf("%s-firewall", p.petriTag), - Tags: []string{p.petriTag}, + Name: fmt.Sprintf("%s-firewall", p.state.petriTag), + Tags: []string{p.state.petriTag}, OutboundRules: []godo.OutboundRule{ { Protocol: "tcp", @@ -39,7 +39,7 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go Protocol: "tcp", PortRange: "0", Sources: &godo.Sources{ - Tags: []string{p.petriTag}, + Tags: []string{p.state.petriTag}, Addresses: allowedIPs, }, }, @@ -47,7 +47,7 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go Protocol: "udp", PortRange: "0", Sources: &godo.Sources{ - Tags: []string{p.petriTag}, + Tags: []string{p.state.petriTag}, Addresses: allowedIPs, }, }, diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index ad791464..9e539ec3 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -26,6 +26,10 @@ const ( type ProviderState struct { TaskStates map[int]*TaskState `json:"task_states"` // map of task ids to the corresponding task state Name string `json:"name"` + petriTag string + userIPs []string + sshKeyPair *SSHKeyPair + firewallID string } type Provider struct { @@ -33,12 +37,7 @@ type Provider struct { stateMu sync.Mutex logger *zap.Logger - name string doClient DoClient - petriTag string - userIPs []string - sshKeyPair *SSHKeyPair - firewallID string dockerClients map[string]DockerClient // map of droplet ip address to docker clients } @@ -73,16 +72,18 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName digitalOceanProvider := &Provider{ logger: logger.Named("digitalocean_provider"), - name: providerName, doClient: doClient, - petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)), - userIPs: userIPs, - sshKeyPair: sshKeyPair, dockerClients: dockerClients, - state: &ProviderState{TaskStates: make(map[int]*TaskState)}, + state: &ProviderState{ + TaskStates: make(map[int]*TaskState), + userIPs: userIPs, + Name: providerName, + 
sshKeyPair: sshKeyPair,
+			petriTag:   fmt.Sprintf("petri-droplet-%s", util.RandomString(5)),
+		},
 	}

-	_, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag)
+	_, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.state.petriTag)
 	if err != nil {
 		return nil, err
 	}
@@ -92,7 +93,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName
 		return nil, fmt.Errorf("failed to create firewall: %w", err)
 	}

-	digitalOceanProvider.firewallID = firewall.ID
+	digitalOceanProvider.state.firewallID = firewall.ID

 	//TODO(Zygimantass): TOCTOU issue
 	if key, _, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
@@ -163,7 +164,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
 		Tty:      false,
 		Hostname: definition.Name,
 		Labels: map[string]string{
-			providerLabelName: p.name,
+			providerLabelName: p.state.Name,
 		},
 		Env: convertEnvMapToList(definition.Environment),
 	}, &container.HostConfig{
@@ -185,7 +186,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
 		Name:         definition.Name,
 		Definition:   definition,
 		Status:       provider.TASK_STOPPED,
-		ProviderName: p.name,
+		ProviderName: p.state.Name,
 	}

 	p.stateMu.Lock()
@@ -196,7 +197,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
 	return &Task{
 		state:        taskState,
 		provider:     p,
-		sshKeyPair:   p.sshKeyPair,
+		sshKeyPair:   p.state.sshKeyPair,
 		logger:       p.logger.With(zap.String("task", definition.Name)),
 		doClient:     p.doClient,
 		dockerClient: p.dockerClients[ip],
@@ -271,7 +272,7 @@ func (p *Provider) Teardown(ctx context.Context) error {
 }

 func (p *Provider) teardownTasks(ctx context.Context) error {
-	res, err := p.doClient.DeleteDropletByTag(ctx, p.petriTag)
+	res, err := p.doClient.DeleteDropletByTag(ctx, p.state.petriTag)
 	if err != nil {
 		return err
 	}
@@ -284,7 +285,7 @@ func (p *Provider) teardownTasks(ctx context.Context) error {
 }

 func (p *Provider) teardownFirewall(ctx context.Context) error {
-	res, err := p.doClient.DeleteFirewall(ctx, p.firewallID)
+	res, err := p.doClient.DeleteFirewall(ctx, p.state.firewallID)
 	if err != nil {
 		return err
 	}
@@ -297,7 +298,7 @@ func (p *Provider) teardownFirewall(ctx context.Context) error {
 }

 func (p *Provider) teardownSSHKey(ctx context.Context) error {
-	res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.sshKeyPair.Fingerprint)
+	res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.state.sshKeyPair.Fingerprint)
 	if err != nil {
 		return err
 	}
@@ -310,7 +311,7 @@ func (p *Provider) teardownSSHKey(ctx context.Context) error {
 }

 func (p *Provider) teardownTag(ctx context.Context) error {
-	res, err := p.doClient.DeleteTag(ctx, p.petriTag)
+	res, err := p.doClient.DeleteTag(ctx, p.state.petriTag)
 	if err != nil {
 		return err
 	}

diff --git a/core/provider/digitalocean/ssh.go b/core/provider/digitalocean/ssh.go
index 0b604dee..a94220cf 100644
--- a/core/provider/digitalocean/ssh.go
+++ b/core/provider/digitalocean/ssh.go
@@ -96,7 +96,7 @@ func getUserIPs(ctx context.Context) (ips []string, err error) {
 }

 func (p *Provider) createSSHKey(ctx context.Context, pubKey string) (*godo.Key, error) {
-	req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.petriTag)}
+	req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.state.petriTag)}

 	key, res, err := p.doClient.CreateKey(ctx, req)
 	if err != nil {
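A note on the state refactor above: petriTag, userIPs, sshKeyPair and firewallID are moved into ProviderState as unexported fields, but encoding/json ignores unexported fields, so SerializeProvider silently drops them and a restored provider would come back without its tag, firewall ID or SSH key. Patch 17 below fixes this by exporting the fields with JSON tags. A minimal, self-contained demonstration of the pitfall (field names shortened for illustration):

    // Sketch: encoding/json cannot see unexported struct fields.
    package main

    import (
    	"encoding/json"
    	"fmt"
    )

    type providerState struct {
    	Name     string `json:"name"`
    	petriTag string // unexported: invisible to encoding/json
    }

    func main() {
    	in := providerState{Name: "test-provider", petriTag: "petri-droplet-abcde"}

    	bz, err := json.Marshal(in)
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(string(bz)) // {"name":"test-provider"} — petriTag was silently dropped

    	var out providerState
    	if err := json.Unmarshal(bz, &out); err != nil {
    		panic(err)
    	}
    	fmt.Println(out.petriTag == "") // true: the tag did not survive the round trip
    }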
From d527c0b6ac44d8015069d18140c7fa3c50e08d04 Mon Sep 17 00:00:00 2001
From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com>
Date: Thu, 16 Jan 2025 21:40:57 +0200
Subject: [PATCH 14/50] lint

---
 core/provider/digitalocean/docker.go  | 10 ++++-----
 core/provider/docker/provider_test.go | 31 +++++++++++----------------
 2 files changed, 17 insertions(+), 24 deletions(-)

diff --git a/core/provider/digitalocean/docker.go b/core/provider/digitalocean/docker.go
index 10a6597a..5b6c1ab6 100644
--- a/core/provider/digitalocean/docker.go
+++ b/core/provider/digitalocean/docker.go
@@ -22,9 +22,9 @@ type DockerClient interface {
 	ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error
 	ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error
 	ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error)
-	ContainerExecCreate(ctx context.Context, container string, config types.ExecConfig) (types.IDResponse, error)
-	ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error)
-	ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error)
+	ContainerExecCreate(ctx context.Context, container string, config container.ExecOptions) (types.IDResponse, error)
+	ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error)
+	ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error)
 	ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error
 	Close() error
 }
@@ -77,8 +77,8 @@ func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, container
 	return d.client.ContainerExecCreate(ctx, container, options)
 }

-func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, config container.ExecAttachOptions) (types.HijackedResponse, error) {
-	return d.client.ContainerExecAttach(ctx, execID, config)
+func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, options container.ExecStartOptions) (types.HijackedResponse, error) {
+	return d.client.ContainerExecAttach(ctx, execID, options)
 }

 func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) {

diff --git a/core/provider/docker/provider_test.go b/core/provider/docker/provider_test.go
index 837182d6..cc4d7df5 100644
--- a/core/provider/docker/provider_test.go
+++ b/core/provider/docker/provider_test.go
@@ -17,7 +17,6 @@ import (
 	"github.com/skip-mev/petri/core/v3/provider"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
-	"go.uber.org/zap"
 )

 const idAlphabet = "abcdefghijklqmnoqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"
@@ -49,9 +48,9 @@ func TestCreateProviderDuplicateNetwork(t *testing.T) {
 	p1, err := docker.CreateProvider(ctx, logger, providerName)
 	require.NoError(t, err)

-	defer func(ctx context.Context, p provider.ProviderI) {
-		require.NoError(t, p.Teardown(ctx))
-	}(ctx, p1)
+	defer func() {
+		require.NoError(t, p1.Teardown(context.Background()))
+	}()

 	p2, err := docker.CreateProvider(ctx, logger, providerName)
 	require.Error(t, err)
@@ -68,9 +67,9 @@ func TestCreateProvider(t *testing.T) {
 	p, err := docker.CreateProvider(ctx, logger, providerName)
 	require.NoError(t, err)

-	defer func(ctx context.Context, p provider.ProviderI) {
+	defer func() {
 		require.NoError(t, p.Teardown(ctx))
-	}(ctx, p)
+	}()

 	state := p.GetState()
 	assert.Equal(t, providerName, state.Name)
@@ -108,11 +107,9 @@ func 
TestCreateTask(t *testing.T) { p, err := docker.CreateProvider(context.Background(), logger, providerName) require.NoError(t, err) - defer p.Teardown(context.Background()) - - defer func(ctx context.Context, p provider.ProviderI) { + defer func() { require.NoError(t, p.Teardown(ctx)) - }(ctx, p) + }() tests := []struct { name string @@ -185,11 +182,9 @@ func TestConcurrentTaskCreation(t *testing.T) { p, err := docker.CreateProvider(ctx, logger, providerName) require.NoError(t, err) - defer p.Teardown(ctx) - - defer func(ctx context.Context, p provider.ProviderI) { + defer func() { require.NoError(t, p.Teardown(ctx)) - }(ctx, p) + }() numTasks := 10 var wg sync.WaitGroup @@ -253,11 +248,9 @@ func TestProviderSerialization(t *testing.T) { p1, err := docker.CreateProvider(ctx, logger, providerName) require.NoError(t, err) - defer p1.Teardown(ctx) - - defer func(ctx context.Context, p provider.ProviderI) { - require.NoError(t, p.Teardown(ctx)) - }(ctx, p1) + defer func() { + require.NoError(t, p1.Teardown(ctx)) + }() _, err = p1.CreateTask(ctx, provider.TaskDefinition{ Name: fmt.Sprintf("%s-test-task", providerName), From 99eaf18a313634b9cacabc0d4dcfa289ed496937 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Thu, 16 Jan 2025 21:54:13 +0200 Subject: [PATCH 15/50] docker lint --- core/provider/docker/task.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index f1e30b0d..5a5df334 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -98,10 +98,6 @@ func (t *Task) Destroy(ctx context.Context) error { return nil } -func (t *Task) ensure(_ context.Context) error { - return nil -} - func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) { t.provider.logger.Debug("getting external address", zap.String("id", t.state.Id)) From eab035569afebb0ee6ea8a3bab2c61ee6373f4ea Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Fri, 17 Jan 2025 04:41:46 +0200 Subject: [PATCH 16/50] initialize missing fields after deserialization --- core/provider/digitalocean/provider.go | 32 +++++++++++++++++++++ core/provider/digitalocean/provider_test.go | 19 ++++++++++++ 2 files changed, 51 insertions(+) diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 9e539ec3..eb77bf69 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -250,9 +250,41 @@ func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.Tas state: &taskState, } + if err := p.initializeDeserializedTask(ctx, task); err != nil { + return nil, err + } + return task, nil } +func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) error { + task.logger = p.logger.With(zap.String("task", task.state.Name)) + task.sshKeyPair = p.state.sshKeyPair + task.doClient = p.doClient + task.provider = p + + droplet, err := task.getDroplet(ctx) + if err != nil { + return fmt.Errorf("failed to get droplet for task initialization: %w", err) + } + + ip, err := droplet.PublicIPv4() + if err != nil { + return fmt.Errorf("failed to get droplet IP: %w", err) + } + + if p.dockerClients[ip] == nil { + client, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort)) + if err != nil { + return fmt.Errorf("failed to create docker client: %w", err) + } + p.dockerClients[ip] = client + } + + task.dockerClient = p.dockerClients[ip] + return nil +} + func (p *Provider) 
Teardown(ctx context.Context) error { p.logger.Info("tearing down DigitalOcean provider") diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index 1ce42aa7..203d3664 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -207,6 +207,20 @@ func TestSerializeAndRestore(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, taskData) + mockDO.On("GetDroplet", ctx, 123).Return(&godo.Droplet{ + ID: 123, + Name: "test-droplet", + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + deserializedTask, err := p.DeserializeTask(ctx, taskData) assert.NoError(t, err) assert.NotNil(t, deserializedTask) @@ -223,6 +237,11 @@ func TestSerializeAndRestore(t *testing.T) { } assert.Equal(t, t1.state, t2.state) + assert.NotNil(t, t2.logger) + assert.NotNil(t, t2.sshKeyPair) + assert.NotNil(t, t2.doClient) + assert.NotNil(t, t2.dockerClient) + assert.NotNil(t, t2.provider) mockDO.AssertExpectations(t) mockDocker.AssertExpectations(t) From 0e8104ffe5a01faea21e6fc77b3daae9f508db16 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Fri, 17 Jan 2025 19:24:59 +0200 Subject: [PATCH 17/50] restore provider functionality --- core/provider/digitalocean/droplet.go | 6 +- core/provider/digitalocean/firewall.go | 8 +- core/provider/digitalocean/provider.go | 82 +++++++--- core/provider/digitalocean/provider_test.go | 163 +++++++++++++++++--- core/provider/digitalocean/ssh.go | 2 +- 5 files changed, 208 insertions(+), 53 deletions(-) diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index c2f4a9a9..78e875b0 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -42,7 +42,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe } req := &godo.DropletCreateRequest{ - Name: fmt.Sprintf("%s-%s", p.state.petriTag, definition.Name), + Name: fmt.Sprintf("%s-%s", p.state.PetriTag, definition.Name), Region: doConfig["region"], Size: doConfig["size"], Image: godo.DropletCreateImage{ @@ -50,10 +50,10 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe }, SSHKeys: []godo.DropletCreateSSHKey{ { - Fingerprint: p.state.sshKeyPair.Fingerprint, + Fingerprint: p.state.SSHKeyPair.Fingerprint, }, }, - Tags: []string{p.state.petriTag}, + Tags: []string{p.state.PetriTag}, } droplet, res, err := p.doClient.CreateDroplet(ctx, req) diff --git a/core/provider/digitalocean/firewall.go b/core/provider/digitalocean/firewall.go index 1ccbb035..a5cb0448 100644 --- a/core/provider/digitalocean/firewall.go +++ b/core/provider/digitalocean/firewall.go @@ -9,8 +9,8 @@ import ( func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*godo.Firewall, error) { req := &godo.FirewallRequest{ - Name: fmt.Sprintf("%s-firewall", p.state.petriTag), - Tags: []string{p.state.petriTag}, + Name: fmt.Sprintf("%s-firewall", p.state.PetriTag), + Tags: []string{p.state.PetriTag}, OutboundRules: []godo.OutboundRule{ { Protocol: "tcp", @@ -39,7 +39,7 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go Protocol: "tcp", PortRange: "0", Sources: &godo.Sources{ - Tags: []string{p.state.petriTag}, + Tags: []string{p.state.PetriTag}, Addresses: allowedIPs, }, }, @@ -47,7 +47,7 @@ func (p 
*Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go Protocol: "udp", PortRange: "0", Sources: &godo.Sources{ - Tags: []string{p.state.petriTag}, + Tags: []string{p.state.PetriTag}, Addresses: allowedIPs, }, }, diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index eb77bf69..222e03c1 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -3,6 +3,7 @@ package digitalocean import ( "context" "encoding/json" + "errors" "fmt" "strings" "sync" @@ -26,10 +27,10 @@ const ( type ProviderState struct { TaskStates map[int]*TaskState `json:"task_states"` // map of task ids to the corresponding task state Name string `json:"name"` - petriTag string - userIPs []string - sshKeyPair *SSHKeyPair - firewallID string + PetriTag string `json:"petri_tag"` + UserIPs []string `json:"user_ips"` + SSHKeyPair *SSHKeyPair `json:"ssh_key_pair"` + FirewallID string `json:"firewall_id"` } type Provider struct { @@ -76,14 +77,14 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName dockerClients: dockerClients, state: &ProviderState{ TaskStates: make(map[int]*TaskState), - userIPs: userIPs, + UserIPs: userIPs, Name: providerName, - sshKeyPair: sshKeyPair, - petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)), + SSHKeyPair: sshKeyPair, + PetriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)), }, } - _, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.state.petriTag) + _, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.state.PetriTag) if err != nil { return nil, err } @@ -93,7 +94,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName return nil, fmt.Errorf("failed to create firewall: %w", err) } - digitalOceanProvider.state.firewallID = firewall.ID + digitalOceanProvider.state.FirewallID = firewall.ID //TODO(Zygimantass): TOCTOU issue if key, _, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil { @@ -197,7 +198,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin return &Task{ state: taskState, provider: p, - sshKeyPair: p.state.sshKeyPair, + sshKeyPair: p.state.SSHKeyPair, logger: p.logger.With(zap.String("task", definition.Name)), doClient: p.doClient, dockerClient: p.dockerClients[ip], @@ -213,13 +214,54 @@ func (p *Provider) SerializeProvider(context.Context) ([]byte, error) { return bz, err } -func (p *Provider) DeserializeProvider(context.Context) ([]byte, error) { - p.stateMu.Lock() - defer p.stateMu.Unlock() +func RestoreProvider(ctx context.Context, token string, state []byte, doClient DoClient, dockerClients map[string]DockerClient) (*Provider, error) { + if doClient == nil && token == "" { + return nil, errors.New("a valid token or digital ocean client must be passed when restoring the provider") + } + var providerState ProviderState - bz, err := json.Marshal(p.state) + err := json.Unmarshal(state, &providerState) + if err != nil { + return nil, err + } - return bz, err + if dockerClients == nil { + dockerClients = make(map[string]DockerClient) + } + + digitalOceanProvider := &Provider{ + state: &providerState, + dockerClients: dockerClients, + logger: zap.L().Named("digitalocean_provider"), + } + + if doClient != nil { + digitalOceanProvider.doClient = doClient + } else { + digitalOceanProvider.doClient = NewGodoClient(token) + } + + for _, taskState := range providerState.TaskStates { + droplet, _, err := 
digitalOceanProvider.doClient.GetDroplet(ctx, taskState.ID)
+		if err != nil {
+			return nil, fmt.Errorf("failed to get droplet for task state: %w", err)
+		}
+
+		ip, err := droplet.PublicIPv4()
+		if err != nil {
+			return nil, fmt.Errorf("failed to get droplet IP: %w", err)
+		}
+
+		if digitalOceanProvider.dockerClients[ip] == nil {
+			client, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort))
+			if err != nil {
+				return nil, fmt.Errorf("failed to create docker client: %w", err)
+			}
+			digitalOceanProvider.dockerClients[ip] = client
+		}
+	}
+
+	return digitalOceanProvider, nil
 }

 func (p *Provider) SerializeTask(ctx context.Context, task provider.TaskI) ([]byte, error) {
@@ -259,7 +301,7 @@ func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.Tas

 func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) error {
 	task.logger = p.logger.With(zap.String("task", task.state.Name))
-	task.sshKeyPair = p.state.sshKeyPair
+	task.sshKeyPair = p.state.SSHKeyPair
 	task.doClient = p.doClient
 	task.provider = p

@@ -304,7 +346,7 @@ func (p *Provider) Teardown(ctx context.Context) error {
 }

 func (p *Provider) teardownTasks(ctx context.Context) error {
-	res, err := p.doClient.DeleteDropletByTag(ctx, p.state.petriTag)
+	res, err := p.doClient.DeleteDropletByTag(ctx, p.state.PetriTag)
 	if err != nil {
 		return err
 	}
@@ -317,7 +359,7 @@ func (p *Provider) teardownTasks(ctx context.Context) error {
 }

 func (p *Provider) teardownFirewall(ctx context.Context) error {
-	res, err := p.doClient.DeleteFirewall(ctx, p.state.firewallID)
+	res, err := p.doClient.DeleteFirewall(ctx, p.state.FirewallID)
 	if err != nil {
 		return err
 	}
@@ -330,7 +372,7 @@ func (p *Provider) teardownFirewall(ctx context.Context) error {
 }

 func (p *Provider) teardownSSHKey(ctx context.Context) error {
-	res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.state.sshKeyPair.Fingerprint)
+	res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.state.SSHKeyPair.Fingerprint)
 	if err != nil {
 		return err
 	}
@@ -343,7 +385,7 @@ func (p *Provider) teardownSSHKey(ctx context.Context) error {
 }

 func (p *Provider) teardownTag(ctx context.Context) error {
-	res, err := p.doClient.DeleteTag(ctx, p.state.petriTag)
+	res, err := p.doClient.DeleteTag(ctx, p.state.PetriTag)
 	if err != nil {
 		return err
 	}
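RestoreProvider above is the inverse of SerializeProvider: json.Unmarshal only recovers the ProviderState, so the function rebuilds everything that cannot be serialized, re-creating the godo client from the token (unless one is injected) and re-dialing a Docker client for each task droplet's public IPv4. A sketch of the intended round trip (the environment variable and the nil arguments are illustrative, and the v2 import path is assumed from the test imports elsewhere in this series):

    // Sketch: persisting a provider and restoring it, e.g. in a fresh process.
    package main

    import (
    	"context"
    	"os"

    	"github.com/skip-mev/petri/core/v2/provider/digitalocean"
    )

    func saveAndRestore(ctx context.Context, p *digitalocean.Provider) (*digitalocean.Provider, error) {
    	// SerializeProvider emits only the JSON-encoded ProviderState.
    	bz, err := p.SerializeProvider(ctx)
    	if err != nil {
    		return nil, err
    	}

    	// With a nil DoClient, RestoreProvider builds a godo client from the token;
    	// with a nil docker-client map, it re-dials one client per droplet IP
    	// recorded in the task states.
    	return digitalocean.RestoreProvider(ctx, os.Getenv("DO_API_TOKEN"), bz, nil, nil)
    }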
diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go
index 203d3664..246573be 100644
--- a/core/provider/digitalocean/provider_test.go
+++ b/core/provider/digitalocean/provider_test.go
@@ -67,25 +67,8 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC
 	p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil)
 	require.NoError(t, err)

-	mockDO.On("CreateDroplet", mock.Anything, mock.AnythingOfType("*godo.DropletCreateRequest")).
-		Return(&godo.Droplet{
-			ID:     123,
-			Name:   "test-droplet",
-			Status: "active",
-			Networks: &godo.Networks{
-				V4: []godo.NetworkV4{
-					{
-						IPAddress: "10.0.0.1",
-						Type:      "public",
-					},
-				},
-			},
-		}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
-
-	mockDO.On("GetDroplet", mock.Anything, mock.AnythingOfType("int")).
-		Return(&godo.Droplet{
-			ID:     123,
-			Name:   "test-droplet",
-			Status: "active",
+	droplet := &godo.Droplet{
+		ID: 123,
 		Networks: &godo.Networks{
 			V4: []godo.NetworkV4{
 				{
@@ -94,7 +77,36 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC
 				},
 			},
 		},
-	}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+		Status: "active",
+	}
+
+	var callCount int
+	mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+	mockDO.On("TagResource", ctx, mock.Anything, mock.Anything).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+	mockDO.On("GetDroplet", ctx, droplet.ID).Return(func(ctx context.Context, id int) *godo.Droplet {
+		if callCount == 0 {
+			callCount++
+			return &godo.Droplet{
+				ID: id,
+				Networks: &godo.Networks{
+					V4: []godo.NetworkV4{
+						{
+							IPAddress: "10.0.0.1",
+							Type:      "public",
+						},
+					},
+				},
+				Status: "new",
+			}
+		}
+		return droplet
+	}, func(ctx context.Context, id int) *godo.Response {
+		return &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}
+	}, func(ctx context.Context, id int) error {
+		return nil
+	}).Maybe()
+
+	mockDO.On("DeleteDropletByID", ctx, droplet.ID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusNoContent}}, nil).Once()

 	return p, mockDO, mockDocker
 }
@@ -177,14 +189,10 @@ func TestCreateTask_MissingRegion(t *testing.T) {
 	assert.Nil(t, task)
 }

-func TestSerializeAndRestore(t *testing.T) {
+func TestSerializeAndRestoreTask(t *testing.T) {
 	ctx := context.Background()
 	p, mockDO, mockDocker := setupTestProvider(t, ctx)

-	providerData, err := p.SerializeProvider(ctx)
-	assert.NoError(t, err)
-	assert.NotNil(t, providerData)
-
 	taskDef := provider.TaskDefinition{
 		Name:  "test-task",
 		Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"},
@@ -448,3 +456,108 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 		client.(*mocks.DockerClient).AssertExpectations(t)
 	}
 }
+
+func TestProviderSerialization(t *testing.T) {
+	ctx := context.Background()
+	mockDO := mocks.NewDoClient(t)
+	mockDocker := mocks.NewDockerClient(t)
+
+	mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "petri-droplet-test"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+	mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+	mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil)
+	mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil)
+
+	mockDockerClients := map[string]DockerClient{
+		"10.0.0.1": mockDocker,
+	}
+
+	p1, err := NewProviderWithClient(ctx, zap.NewExample(), "test-provider", mockDO, mockDockerClients, []string{}, nil)
+	require.NoError(t, err)
+
+	droplet := &godo.Droplet{
+		ID: 123,
+		Networks: &godo.Networks{
+			V4: []godo.NetworkV4{
+				{
+					IPAddress: "10.0.0.1",
+					Type:      "public",
+				},
+			},
+		},
+		Status: "active",
+	}
+
mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("TagResource", ctx, mock.Anything, mock.Anything).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe() + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe() + + mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Maybe() + mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) + mockDocker.On("ImagePull", ctx, "ubuntu:latest", image.PullOptions{}).Return(io.NopCloser(strings.NewReader("")), nil) + mockDocker.On("ContainerCreate", ctx, mock.MatchedBy(func(config *container.Config) bool { + return config.Image == "ubuntu:latest" && + config.Hostname == "test-task" && + len(config.Labels) > 0 && + config.Labels[providerLabelName] == "test-provider" + }), mock.MatchedBy(func(hostConfig *container.HostConfig) bool { + return len(hostConfig.Mounts) == 1 && + hostConfig.Mounts[0].Target == "/data" && + hostConfig.NetworkMode == "host" + }), mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: "test-container"}, nil) + + _, err = p1.CreateTask(ctx, provider.TaskDefinition{ + Name: "test-task", + ContainerName: "test-container", + Image: provider.ImageDefinition{ + Image: "ubuntu:latest", + UID: "1000", + GID: "1000", + }, + DataDir: "/data", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", + }, + }) + require.NoError(t, err) + + state1 := p1.state + serialized, err := p1.SerializeProvider(ctx) + require.NoError(t, err) + + mockDO2 := mocks.NewDoClient(t) + mockDO2.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe() + + mockDocker2 := mocks.NewDockerClient(t) + mockDocker2.On("Ping", ctx).Return(types.Ping{}, nil).Maybe() + + mockDockerClients2 := map[string]DockerClient{ + "10.0.0.1": mockDocker2, + } + + p2, err := RestoreProvider(ctx, "test-token", serialized, mockDO2, mockDockerClients2) + require.NoError(t, err) + + state2 := p2.state + assert.Equal(t, state1.Name, state2.Name) + assert.Equal(t, state1.PetriTag, state2.PetriTag) + assert.Equal(t, state1.FirewallID, state2.FirewallID) + assert.Equal(t, len(state1.TaskStates), len(state2.TaskStates)) + + for id, task1 := range state1.TaskStates { + task2, exists := state2.TaskStates[id] + assert.True(t, exists) + assert.Equal(t, task1.Name, task2.Name) + assert.Equal(t, task1.Status, task2.Status) + + if configMap, ok := task2.Definition.ProviderSpecificConfig.(map[string]interface{}); ok { + doConfig := make(DigitalOceanTaskConfig) + for k, v := range configMap { + doConfig[k] = v.(string) + } + task2.Definition.ProviderSpecificConfig = doConfig + } + assert.Equal(t, task1.Definition, task2.Definition) + } +} diff --git a/core/provider/digitalocean/ssh.go b/core/provider/digitalocean/ssh.go index a94220cf..78d34909 100644 --- a/core/provider/digitalocean/ssh.go +++ b/core/provider/digitalocean/ssh.go @@ -96,7 +96,7 @@ func getUserIPs(ctx context.Context) (ips []string, err error) { } func (p *Provider) createSSHKey(ctx context.Context, pubKey string) (*godo.Key, error) { - req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.state.petriTag)} + req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: 
fmt.Sprintf("%s-key", p.state.PetriTag)} key, res, err := p.doClient.CreateKey(ctx, req) if err != nil { From d4a7ef20675327dda5bbbb887556f942eb490979 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Mon, 20 Jan 2025 20:17:52 +0200 Subject: [PATCH 18/50] remove docker changes --- core/provider/docker/network.go | 2 -- core/provider/docker/provider.go | 1 - core/provider/docker/provider_test.go | 34 ++++++++++++++------------- core/provider/docker/task.go | 6 ----- 4 files changed, 18 insertions(+), 25 deletions(-) diff --git a/core/provider/docker/network.go b/core/provider/docker/network.go index 8392380a..f18d42cf 100644 --- a/core/provider/docker/network.go +++ b/core/provider/docker/network.go @@ -141,8 +141,6 @@ func (p *Provider) nextAvailableIP() (string, error) { return "", err } - p.state.AllocatedIPs = append(p.state.AllocatedIPs, ip.To4().String()) - return ip.String(), nil } diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go index 80bba660..f0636305 100644 --- a/core/provider/docker/provider.go +++ b/core/provider/docker/provider.go @@ -224,7 +224,6 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin logger.Debug("creating container", zap.String("name", definition.Name), zap.String("image", definition.Image.Image)) - // network map is volatile, so we need to mutex update it ip, err := p.nextAvailableIP() if err != nil { return nil, err diff --git a/core/provider/docker/provider_test.go b/core/provider/docker/provider_test.go index cc4d7df5..d8207737 100644 --- a/core/provider/docker/provider_test.go +++ b/core/provider/docker/provider_test.go @@ -3,16 +3,15 @@ package docker_test import ( "context" "fmt" - "sync" - "testing" - "time" - "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/network" "github.com/docker/docker/client" gonanoid "github.com/matoous/go-nanoid/v2" "github.com/skip-mev/petri/core/v3/provider/docker" "go.uber.org/zap/zaptest" + "sync" + "testing" + "time" "github.com/skip-mev/petri/core/v3/provider" "github.com/stretchr/testify/assert" @@ -48,9 +47,9 @@ func TestCreateProviderDuplicateNetwork(t *testing.T) { p1, err := docker.CreateProvider(ctx, logger, providerName) require.NoError(t, err) - defer func() { - require.NoError(t, p1.Teardown(context.Background())) - }() + defer func(ctx context.Context, p provider.ProviderI) { + require.NoError(t, p.Teardown(ctx)) + }(ctx, p1) p2, err := docker.CreateProvider(ctx, logger, providerName) require.Error(t, err) @@ -67,9 +66,9 @@ func TestCreateProvider(t *testing.T) { p, err := docker.CreateProvider(ctx, logger, providerName) require.NoError(t, err) - defer func() { + defer func(ctx context.Context, p provider.ProviderI) { require.NoError(t, p.Teardown(ctx)) - }() + }(ctx, p) state := p.GetState() assert.Equal(t, providerName, state.Name) @@ -107,9 +106,10 @@ func TestCreateTask(t *testing.T) { p, err := docker.CreateProvider(context.Background(), logger, providerName) require.NoError(t, err) - defer func() { + + defer func(ctx context.Context, p provider.ProviderI) { require.NoError(t, p.Teardown(ctx)) - }() + }(ctx, p) tests := []struct { name string @@ -182,9 +182,10 @@ func TestConcurrentTaskCreation(t *testing.T) { p, err := docker.CreateProvider(ctx, logger, providerName) require.NoError(t, err) - defer func() { + + defer func(ctx context.Context, p provider.ProviderI) { require.NoError(t, p.Teardown(ctx)) - }() + }(ctx, p) numTasks := 10 var wg sync.WaitGroup @@ -248,9 +249,10 @@ 
func TestProviderSerialization(t *testing.T) { p1, err := docker.CreateProvider(ctx, logger, providerName) require.NoError(t, err) - defer func() { - require.NoError(t, p1.Teardown(ctx)) - }() + + defer func(ctx context.Context, p provider.ProviderI) { + require.NoError(t, p.Teardown(ctx)) + }(ctx, p1) _, err = p1.CreateTask(ctx, provider.TaskDefinition{ Name: fmt.Sprintf("%s-test-task", providerName), diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index 5a5df334..ea569f8c 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -154,8 +154,6 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { return provider.TASK_STATUS_UNDEFINED, err } - fmt.Println(containerJSON.State.Status) - switch state := containerJSON.State.Status; state { case "created": return provider.TASK_STOPPED, nil @@ -176,10 +174,6 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { return provider.TASK_STATUS_UNDEFINED, nil } -func (t *Task) Initialize(ctx context.Context) error { - return nil -} - func (t *Task) Modify(ctx context.Context, td provider.TaskDefinition) error { panic("unimplemented") } From d80df8e796ed97d22086c05d9707305c6be7e79e Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 00:12:23 +0200 Subject: [PATCH 19/50] pr feedback --- core/provider/digitalocean/client.go | 108 +++++--- core/provider/digitalocean/droplet.go | 47 +--- .../examples/digitalocean_simapp.go | 35 +-- .../digitalocean/files/docker-cloud-init.yaml | 37 --- core/provider/digitalocean/firewall.go | 20 +- .../digitalocean/mocks/do_client_mock.go | 236 +++++------------- core/provider/digitalocean/provider.go | 146 ++++++----- core/provider/digitalocean/provider_test.go | 85 +++---- core/provider/digitalocean/ssh.go | 14 +- core/provider/digitalocean/tag.go | 25 -- core/provider/digitalocean/task.go | 171 ++++++------- core/provider/digitalocean/task_test.go | 100 +++++--- 12 files changed, 420 insertions(+), 604 deletions(-) delete mode 100644 core/provider/digitalocean/files/docker-cloud-init.yaml delete mode 100644 core/provider/digitalocean/tag.go diff --git a/core/provider/digitalocean/client.go b/core/provider/digitalocean/client.go index 753abc6f..9375dc34 100644 --- a/core/provider/digitalocean/client.go +++ b/core/provider/digitalocean/client.go @@ -2,6 +2,7 @@ package digitalocean import ( "context" + "fmt" "github.com/digitalocean/godo" ) @@ -9,23 +10,23 @@ import ( // DoClient defines the interface for DigitalOcean API operations used by the provider type DoClient interface { // Droplet operations - CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) - GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) - DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) - DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) + CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, error) + GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, error) + DeleteDropletByTag(ctx context.Context, tag string) error + DeleteDropletByID(ctx context.Context, id int) error // Firewall operations - CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) - DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) + CreateFirewall(ctx context.Context, req 
*godo.FirewallRequest) (*godo.Firewall, error) + DeleteFirewall(ctx context.Context, firewallID string) error // SSH Key operations - CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) - DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) - GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) + CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, error) + DeleteKeyByFingerprint(ctx context.Context, fingerprint string) error + GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, error) // Tag operations - CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) - DeleteTag(ctx context.Context, tag string) (*godo.Response, error) + CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, error) + DeleteTag(ctx context.Context, tag string) error } // godoClient implements the DoClient interface using the actual godo.Client @@ -37,50 +38,91 @@ func NewGodoClient(token string) DoClient { return &godoClient{Client: godo.NewFromToken(token)} } +func checkResponse(res *godo.Response, err error) error { + if err != nil { + return err + } + + if res.StatusCode > 299 || res.StatusCode < 200 { + return fmt.Errorf("unexpected status code: %d", res.StatusCode) + } + + return nil +} + // Droplet operations -func (c *godoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) { - return c.Droplets.Create(ctx, req) +func (c *godoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, error) { + droplet, res, err := c.Droplets.Create(ctx, req) + if err := checkResponse(res, err); err != nil { + return nil, err + } + return droplet, nil } -func (c *godoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) { - return c.Droplets.Get(ctx, dropletID) +func (c *godoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, error) { + droplet, res, err := c.Droplets.Get(ctx, dropletID) + if err := checkResponse(res, err); err != nil { + return nil, err + } + return droplet, nil } -func (c *godoClient) DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) { - return c.Droplets.DeleteByTag(ctx, tag) +func (c *godoClient) DeleteDropletByTag(ctx context.Context, tag string) error { + res, err := c.Droplets.DeleteByTag(ctx, tag) + return checkResponse(res, err) } -func (c *godoClient) DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) { - return c.Droplets.Delete(ctx, id) +func (c *godoClient) DeleteDropletByID(ctx context.Context, id int) error { + res, err := c.Droplets.Delete(ctx, id) + return checkResponse(res, err) } // Firewall operations -func (c *godoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) { - return c.Firewalls.Create(ctx, req) +func (c *godoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, error) { + firewall, res, err := c.Firewalls.Create(ctx, req) + if err := checkResponse(res, err); err != nil { + return nil, err + } + return firewall, nil } -func (c *godoClient) DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) { - return c.Firewalls.Delete(ctx, firewallID) +func (c *godoClient) DeleteFirewall(ctx context.Context, firewallID string) error { + res, err := c.Firewalls.Delete(ctx, 
firewallID) + return checkResponse(res, err) } // SSH Key operations -func (c *godoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) { - return c.Keys.Create(ctx, req) +func (c *godoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, error) { + key, res, err := c.Keys.Create(ctx, req) + if err := checkResponse(res, err); err != nil { + return nil, err + } + return key, nil } -func (c *godoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) { - return c.Keys.DeleteByFingerprint(ctx, fingerprint) +func (c *godoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) error { + res, err := c.Keys.DeleteByFingerprint(ctx, fingerprint) + return checkResponse(res, err) } -func (c *godoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) { - return c.Keys.GetByFingerprint(ctx, fingerprint) +func (c *godoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, error) { + key, res, err := c.Keys.GetByFingerprint(ctx, fingerprint) + if err := checkResponse(res, err); err != nil { + return nil, err + } + return key, nil } // Tag operations -func (c *godoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) { - return c.Tags.Create(ctx, req) +func (c *godoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, error) { + tag, res, err := c.Tags.Create(ctx, req) + if err := checkResponse(res, err); err != nil { + return nil, err + } + return tag, nil } -func (c *godoClient) DeleteTag(ctx context.Context, tag string) (*godo.Response, error) { - return c.Tags.Delete(ctx, tag) +func (c *godoClient) DeleteTag(ctx context.Context, tag string) error { + res, err := c.Tags.Delete(ctx, tag) + return checkResponse(res, err) } diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 78e875b0..35a6e8bd 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -19,11 +19,6 @@ import ( _ "embed" ) -// nolint -// -//go:embed files/docker-cloud-init.yaml -var dockerCloudInit string - func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDefinition) (*godo.Droplet, error) { if err := definition.ValidateBasic(); err != nil { return nil, fmt.Errorf("failed to validate task definition: %w", err) @@ -41,8 +36,9 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe return nil, fmt.Errorf("failed to parse image ID: %w", err) } + state := p.GetState() req := &godo.DropletCreateRequest{ - Name: fmt.Sprintf("%s-%s", p.state.PetriTag, definition.Name), + Name: fmt.Sprintf("%s-%s", state.PetriTag, definition.Name), Region: doConfig["region"], Size: doConfig["size"], Image: godo.DropletCreateImage{ @@ -50,25 +46,21 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe }, SSHKeys: []godo.DropletCreateSSHKey{ { - Fingerprint: p.state.SSHKeyPair.Fingerprint, + Fingerprint: state.SSHKeyPair.Fingerprint, }, }, - Tags: []string{p.state.PetriTag}, + Tags: []string{state.PetriTag}, } - droplet, res, err := p.doClient.CreateDroplet(ctx, req) + droplet, err := p.doClient.CreateDroplet(ctx, req) if err != nil { return nil, err } - if res.StatusCode > 299 || res.StatusCode < 200 { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - start := time.Now() err = 
util.WaitForCondition(ctx, time.Second*600, time.Millisecond*300, func() (bool, error) { - d, _, err := p.doClient.GetDroplet(ctx, droplet.ID) + d, err := p.doClient.GetDroplet(ctx, droplet.ID) if err != nil { return false, err } @@ -83,7 +75,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe } if p.dockerClients[ip] == nil { - dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort)) + dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { p.logger.Error("failed to create docker client", zap.Error(err)) return false, err @@ -113,34 +105,15 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe func (t *Task) deleteDroplet(ctx context.Context) error { droplet, err := t.getDroplet(ctx) - if err != nil { return err } - res, err := t.doClient.DeleteDropletByID(ctx, droplet.ID) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil + return t.doClient.DeleteDropletByID(ctx, droplet.ID) } func (t *Task) getDroplet(ctx context.Context) (*godo.Droplet, error) { - droplet, res, err := t.doClient.GetDroplet(ctx, t.state.ID) - if err != nil { - return nil, err - } - - if res.StatusCode < 200 || res.StatusCode > 299 { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return droplet, nil + return t.doClient.GetDroplet(ctx, t.GetState().ID) } func (t *Task) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) { @@ -161,7 +134,7 @@ func (t *Task) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.C return nil, err } - parsedSSHKey, err := ssh.ParsePrivateKey([]byte(t.sshKeyPair.PrivateKey)) + parsedSSHKey, err := ssh.ParsePrivateKey([]byte(t.GetState().SSHKeyPair.PrivateKey)) if err != nil { return nil, err } diff --git a/core/provider/digitalocean/examples/digitalocean_simapp.go b/core/provider/digitalocean/examples/digitalocean_simapp.go index a35b2c5e..cc00c298 100644 --- a/core/provider/digitalocean/examples/digitalocean_simapp.go +++ b/core/provider/digitalocean/examples/digitalocean_simapp.go @@ -2,16 +2,14 @@ package main import ( "context" - "os" - "github.com/cosmos/cosmos-sdk/crypto/hd" - "github.com/skip-mev/petri/core/v2/types" - petritypes "github.com/skip-mev/petri/core/v2/types" "github.com/skip-mev/petri/cosmos/v2/chain" "github.com/skip-mev/petri/cosmos/v2/node" + "os" "github.com/skip-mev/petri/core/v2/provider" "github.com/skip-mev/petri/core/v2/provider/digitalocean" + petritypes "github.com/skip-mev/petri/core/v2/types" "go.uber.org/zap" ) @@ -56,17 +54,22 @@ func main() { HomeDir: "/gaia", CoinType: "118", ChainId: "stake-1", - NodeCreator: node.CreateNode, - NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig petritypes.NodeConfig) provider.TaskDefinition { - doConfig := digitalocean.DigitalOceanTaskConfig{ - "size": "s-2vcpu-4gb", - "region": "ams3", - "image_id": imageID, - } - def.ProviderSpecificConfig = doConfig - return def + } + + chainOptions := petritypes.ChainOptions{ + NodeCreator: node.CreateNode, + NodeOptions: petritypes.NodeOptions{ + NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig petritypes.NodeConfig) provider.TaskDefinition { + doConfig := digitalocean.DigitalOceanTaskConfig{ + "size": "s-2vcpu-4gb", + "region": "ams3", + "image_id": imageID, + } + def.ProviderSpecificConfig = doConfig + return def + }, }, - 
WalletConfig: types.WalletConfig{ + WalletConfig: petritypes.WalletConfig{ SigningAlgorithm: string(hd.Secp256k1.Name()), Bech32Prefix: "cosmos", HDPath: hd.CreateHDPath(118, 0, 0), @@ -76,13 +79,13 @@ func main() { } logger.Info("Creating chain") - cosmosChain, err := chain.CreateChain(ctx, logger, doProvider, chainConfig) + cosmosChain, err := chain.CreateChain(ctx, logger, doProvider, chainConfig, chainOptions) if err != nil { logger.Fatal("failed to create chain", zap.Error(err)) } logger.Info("Initializing chain") - err = cosmosChain.Init(ctx) + err = cosmosChain.Init(ctx, chainOptions) if err != nil { logger.Fatal("failed to initialize chain", zap.Error(err)) } diff --git a/core/provider/digitalocean/files/docker-cloud-init.yaml b/core/provider/digitalocean/files/docker-cloud-init.yaml deleted file mode 100644 index eaf85b03..00000000 --- a/core/provider/digitalocean/files/docker-cloud-init.yaml +++ /dev/null @@ -1,37 +0,0 @@ -#cloud-config - -package_update: true -package_upgrade: true - -# create the docker group -groups: - - docker - -# Setup Docker daemon to listen on tcp and unix socket -write_files: - - path: /etc/sysctl.d/enabled_ipv4_forwarding.conf - content: | - net.ipv4.conf.all.forwarding=1 - - path: /etc/docker/daemon.json - content: | - { - "hosts": ["unix:///var/run/docker.sock", "tcp://0.0.0.0:2375"] - } - owner: root:root - permissions: '0644' - - path: /etc/systemd/system/docker.service.d/override.conf - content: | - [Service] - ExecStart= - ExecStart=/usr/bin/dockerd --containerd=/run/containerd/containerd.sock --tls=false - owner: root:root - - -# Create a directory for Docker volumes -runcmd: - - curl -fsSL https://get.docker.com | sh - - mkdir /docker_volumes - - chmod 755 /docker_volumes - - chown root:docker /docker_volumes - - systemctl daemon-reload - - systemctl restart docker \ No newline at end of file diff --git a/core/provider/digitalocean/firewall.go b/core/provider/digitalocean/firewall.go index a5cb0448..ec819adc 100644 --- a/core/provider/digitalocean/firewall.go +++ b/core/provider/digitalocean/firewall.go @@ -8,9 +8,10 @@ import ( ) func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*godo.Firewall, error) { + state := p.GetState() req := &godo.FirewallRequest{ - Name: fmt.Sprintf("%s-firewall", p.state.PetriTag), - Tags: []string{p.state.PetriTag}, + Name: fmt.Sprintf("%s-firewall", state.PetriTag), + Tags: []string{state.PetriTag}, OutboundRules: []godo.OutboundRule{ { Protocol: "tcp", @@ -39,7 +40,7 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go Protocol: "tcp", PortRange: "0", Sources: &godo.Sources{ - Tags: []string{p.state.PetriTag}, + Tags: []string{state.PetriTag}, Addresses: allowedIPs, }, }, @@ -47,21 +48,12 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go Protocol: "udp", PortRange: "0", Sources: &godo.Sources{ - Tags: []string{p.state.PetriTag}, + Tags: []string{state.PetriTag}, Addresses: allowedIPs, }, }, }, } - firewall, res, err := p.doClient.CreateFirewall(ctx, req) - if err != nil { - return nil, err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return firewall, nil + return p.doClient.CreateFirewall(ctx, req) } diff --git a/core/provider/digitalocean/mocks/do_client_mock.go b/core/provider/digitalocean/mocks/do_client_mock.go index a94deb7c..302823c2 100644 --- a/core/provider/digitalocean/mocks/do_client_mock.go +++ 
b/core/provider/digitalocean/mocks/do_client_mock.go @@ -16,7 +16,7 @@ type DoClient struct { } // CreateDroplet provides a mock function with given fields: ctx, req -func (_m *DoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error) { +func (_m *DoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, error) { ret := _m.Called(ctx, req) if len(ret) == 0 { @@ -24,9 +24,8 @@ func (_m *DoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRe } var r0 *godo.Droplet - var r1 *godo.Response - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) (*godo.Droplet, *godo.Response, error)); ok { + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) (*godo.Droplet, error)); ok { return rf(ctx, req) } if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) *godo.Droplet); ok { @@ -37,25 +36,17 @@ func (_m *DoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRe } } - if rf, ok := ret.Get(1).(func(context.Context, *godo.DropletCreateRequest) *godo.Response); ok { + if rf, ok := ret.Get(1).(func(context.Context, *godo.DropletCreateRequest) error); ok { r1 = rf(ctx, req) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(*godo.Response) - } - } - - if rf, ok := ret.Get(2).(func(context.Context, *godo.DropletCreateRequest) error); ok { - r2 = rf(ctx, req) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // CreateFirewall provides a mock function with given fields: ctx, req -func (_m *DoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error) { +func (_m *DoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, error) { ret := _m.Called(ctx, req) if len(ret) == 0 { @@ -63,9 +54,8 @@ func (_m *DoClient) CreateFirewall(ctx context.Context, req *godo.FirewallReques } var r0 *godo.Firewall - var r1 *godo.Response - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) (*godo.Firewall, *godo.Response, error)); ok { + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) (*godo.Firewall, error)); ok { return rf(ctx, req) } if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) *godo.Firewall); ok { @@ -76,25 +66,17 @@ func (_m *DoClient) CreateFirewall(ctx context.Context, req *godo.FirewallReques } } - if rf, ok := ret.Get(1).(func(context.Context, *godo.FirewallRequest) *godo.Response); ok { + if rf, ok := ret.Get(1).(func(context.Context, *godo.FirewallRequest) error); ok { r1 = rf(ctx, req) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(*godo.Response) - } - } - - if rf, ok := ret.Get(2).(func(context.Context, *godo.FirewallRequest) error); ok { - r2 = rf(ctx, req) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // CreateKey provides a mock function with given fields: ctx, req -func (_m *DoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, *godo.Response, error) { +func (_m *DoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, error) { ret := _m.Called(ctx, req) if len(ret) == 0 { @@ -102,9 +84,8 @@ func (_m *DoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) ( } var r0 *godo.Key - var r1 *godo.Response - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, 
*godo.KeyCreateRequest) (*godo.Key, *godo.Response, error)); ok { + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) (*godo.Key, error)); ok { return rf(ctx, req) } if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) *godo.Key); ok { @@ -115,25 +96,17 @@ func (_m *DoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) ( } } - if rf, ok := ret.Get(1).(func(context.Context, *godo.KeyCreateRequest) *godo.Response); ok { + if rf, ok := ret.Get(1).(func(context.Context, *godo.KeyCreateRequest) error); ok { r1 = rf(ctx, req) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(*godo.Response) - } - } - - if rf, ok := ret.Get(2).(func(context.Context, *godo.KeyCreateRequest) error); ok { - r2 = rf(ctx, req) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // CreateTag provides a mock function with given fields: ctx, req -func (_m *DoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error) { +func (_m *DoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, error) { ret := _m.Called(ctx, req) if len(ret) == 0 { @@ -141,9 +114,8 @@ func (_m *DoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) ( } var r0 *godo.Tag - var r1 *godo.Response - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) (*godo.Tag, *godo.Response, error)); ok { + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) (*godo.Tag, error)); ok { return rf(ctx, req) } if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) *godo.Tag); ok { @@ -154,175 +126,107 @@ func (_m *DoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) ( } } - if rf, ok := ret.Get(1).(func(context.Context, *godo.TagCreateRequest) *godo.Response); ok { + if rf, ok := ret.Get(1).(func(context.Context, *godo.TagCreateRequest) error); ok { r1 = rf(ctx, req) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(*godo.Response) - } - } - - if rf, ok := ret.Get(2).(func(context.Context, *godo.TagCreateRequest) error); ok { - r2 = rf(ctx, req) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // DeleteDropletByID provides a mock function with given fields: ctx, id -func (_m *DoClient) DeleteDropletByID(ctx context.Context, id int) (*godo.Response, error) { +func (_m *DoClient) DeleteDropletByID(ctx context.Context, id int) error { ret := _m.Called(ctx, id) if len(ret) == 0 { panic("no return value specified for DeleteDropletByID") } - var r0 *godo.Response - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Response, error)); ok { - return rf(ctx, id) - } - if rf, ok := ret.Get(0).(func(context.Context, int) *godo.Response); ok { + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int) error); ok { r0 = rf(ctx, id) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*godo.Response) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { - r1 = rf(ctx, id) - } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // DeleteDropletByTag provides a mock function with given fields: ctx, tag -func (_m *DoClient) DeleteDropletByTag(ctx context.Context, tag string) (*godo.Response, error) { +func (_m *DoClient) DeleteDropletByTag(ctx context.Context, tag string) error { ret := _m.Called(ctx, tag) if len(ret) == 0 { panic("no return value specified for 
DeleteDropletByTag") } - var r0 *godo.Response - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { - return rf(ctx, tag) - } - if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, tag) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*godo.Response) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, tag) - } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // DeleteFirewall provides a mock function with given fields: ctx, firewallID -func (_m *DoClient) DeleteFirewall(ctx context.Context, firewallID string) (*godo.Response, error) { +func (_m *DoClient) DeleteFirewall(ctx context.Context, firewallID string) error { ret := _m.Called(ctx, firewallID) if len(ret) == 0 { panic("no return value specified for DeleteFirewall") } - var r0 *godo.Response - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { - return rf(ctx, firewallID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, firewallID) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*godo.Response) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, firewallID) - } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // DeleteKeyByFingerprint provides a mock function with given fields: ctx, fingerprint -func (_m *DoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Response, error) { +func (_m *DoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) error { ret := _m.Called(ctx, fingerprint) if len(ret) == 0 { panic("no return value specified for DeleteKeyByFingerprint") } - var r0 *godo.Response - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { - return rf(ctx, fingerprint) - } - if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, fingerprint) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*godo.Response) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, fingerprint) - } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // DeleteTag provides a mock function with given fields: ctx, tag -func (_m *DoClient) DeleteTag(ctx context.Context, tag string) (*godo.Response, error) { +func (_m *DoClient) DeleteTag(ctx context.Context, tag string) error { ret := _m.Called(ctx, tag) if len(ret) == 0 { panic("no return value specified for DeleteTag") } - var r0 *godo.Response - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Response, error)); ok { - return rf(ctx, tag) - } - if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Response); ok { + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { r0 = rf(ctx, tag) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*godo.Response) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, tag) - } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // GetDroplet provides a mock 
function with given fields: ctx, dropletID -func (_m *DoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, *godo.Response, error) { +func (_m *DoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, error) { ret := _m.Called(ctx, dropletID) if len(ret) == 0 { @@ -330,9 +234,8 @@ func (_m *DoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Drople } var r0 *godo.Droplet - var r1 *godo.Response - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Droplet, *godo.Response, error)); ok { + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Droplet, error)); ok { return rf(ctx, dropletID) } if rf, ok := ret.Get(0).(func(context.Context, int) *godo.Droplet); ok { @@ -343,25 +246,17 @@ func (_m *DoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Drople } } - if rf, ok := ret.Get(1).(func(context.Context, int) *godo.Response); ok { + if rf, ok := ret.Get(1).(func(context.Context, int) error); ok { r1 = rf(ctx, dropletID) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(*godo.Response) - } - } - - if rf, ok := ret.Get(2).(func(context.Context, int) error); ok { - r2 = rf(ctx, dropletID) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // GetKeyByFingerprint provides a mock function with given fields: ctx, fingerprint -func (_m *DoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, *godo.Response, error) { +func (_m *DoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, error) { ret := _m.Called(ctx, fingerprint) if len(ret) == 0 { @@ -369,9 +264,8 @@ func (_m *DoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) } var r0 *godo.Key - var r1 *godo.Response - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Key, *godo.Response, error)); ok { + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Key, error)); ok { return rf(ctx, fingerprint) } if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Key); ok { @@ -382,21 +276,13 @@ func (_m *DoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) } } - if rf, ok := ret.Get(1).(func(context.Context, string) *godo.Response); ok { + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, fingerprint) } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).(*godo.Response) - } - } - - if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { - r2 = rf(ctx, fingerprint) - } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // NewDoClient creates a new instance of DoClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
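The regenerated mock above is a direct consequence of the interface change at the top of this patch: every DoClient method now returns a plain error (or a value plus an error) instead of also surfacing *godo.Response, because client.go funnels each godo call through checkResponse. A minimal, self-contained sketch of that pattern follows — checkResponse is as it appears in the diff, while getDroplet is a hypothetical call site for illustration, not the provider's exact code:

package digitalocean

import (
	"context"
	"fmt"

	"github.com/digitalocean/godo"
)

// checkResponse collapses godo's (*godo.Response, error) pair into a single
// error, treating any status outside the 2xx range as a failure.
func checkResponse(res *godo.Response, err error) error {
	if err != nil {
		return err
	}
	if res.StatusCode > 299 || res.StatusCode < 200 {
		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
	}
	return nil
}

// getDroplet is an illustrative call site: one checkResponse call replaces
// the status-code range checks that earlier patches repeated after every
// godo call, and it is why each mock method above dropped a return value.
func getDroplet(ctx context.Context, c *godo.Client, id int) (*godo.Droplet, error) {
	droplet, res, err := c.Droplets.Get(ctx, id)
	if err := checkResponse(res, err); err != nil {
		return nil, err
	}
	return droplet, nil
}

The same shape repeats for the Create/Delete operations, which is what shrinks the mocks here and lets the provider.go teardown helpers below reduce to one-line delegations.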
diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 222e03c1..087a3ca7 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -7,9 +7,12 @@ import ( "fmt" "strings" "sync" + "time" + "github.com/digitalocean/godo" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/client" "go.uber.org/zap" @@ -21,7 +24,7 @@ var _ provider.ProviderI = (*Provider)(nil) const ( providerLabelName = "petri-provider" - sshPort = "2375" + dockerPort = "2375" ) type ProviderState struct { @@ -97,7 +100,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName digitalOceanProvider.state.FirewallID = firewall.ID //TODO(Zygimantass): TOCTOU issue - if key, _, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil { + if key, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil { _, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey) if err != nil { if !strings.Contains(err.Error(), "422") { @@ -142,44 +145,60 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin p.logger.Info("droplet created", zap.String("name", droplet.Name), zap.String("ip", ip)) - if p.dockerClients[ip] == nil { - p.dockerClients[ip], err = NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort)) + dockerClient := p.dockerClients[ip] + if dockerClient == nil { + dockerClient, err = NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { return nil, err } } - _, _, err = p.dockerClients[ip].ImageInspectWithRaw(ctx, definition.Image.Image) + _, _, err = dockerClient.ImageInspectWithRaw(ctx, definition.Image.Image) if err != nil { p.logger.Info("image not found, pulling", zap.String("image", definition.Image.Image)) - err = pullImage(ctx, p.dockerClients[ip], p.logger, definition.Image.Image) + err = pullImage(ctx, dockerClient, p.logger, definition.Image.Image) if err != nil { return nil, err } } - _, err = p.dockerClients[ip].ContainerCreate(ctx, &container.Config{ - Image: definition.Image.Image, - Entrypoint: definition.Entrypoint, - Cmd: definition.Command, - Tty: false, - Hostname: definition.Name, - Labels: map[string]string{ - providerLabelName: p.state.Name, - }, - Env: convertEnvMapToList(definition.Environment), - }, &container.HostConfig{ - Mounts: []mount.Mount{ - { - Type: mount.TypeBind, - Source: "/docker_volumes", - Target: definition.DataDir, + state := p.GetState() + + err = util.WaitForCondition(ctx, 30*time.Second, 1*time.Second, func() (bool, error) { + _, err := dockerClient.ContainerCreate(ctx, &container.Config{ + Image: definition.Image.Image, + Entrypoint: definition.Entrypoint, + Cmd: definition.Command, + Tty: false, + Hostname: definition.Name, + Labels: map[string]string{ + providerLabelName: state.Name, }, - }, - NetworkMode: container.NetworkMode("host"), - }, nil, nil, definition.ContainerName) + Env: convertEnvMapToList(definition.Environment), + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: definition.DataDir, + }, + }, + NetworkMode: container.NetworkMode("host"), + }, nil, nil, definition.ContainerName) + + if err != nil { + if client.IsErrConnectionFailed(err) { + p.logger.Warn("connection failed while creating container, will retry", zap.Error(err)) + return false, nil + } + return false, err + } + + return true, nil + 
}) + if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create container after retries: %w", err) } taskState := &TaskState{ @@ -187,7 +206,8 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin Name: definition.Name, Definition: definition, Status: provider.TASK_STOPPED, - ProviderName: p.state.Name, + ProviderName: state.Name, + SSHKeyPair: state.SSHKeyPair, } p.stateMu.Lock() @@ -198,10 +218,9 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin return &Task{ state: taskState, provider: p, - sshKeyPair: p.state.SSHKeyPair, logger: p.logger.With(zap.String("task", definition.Name)), doClient: p.doClient, - dockerClient: p.dockerClients[ip], + dockerClient: dockerClient, }, nil } @@ -242,7 +261,7 @@ func RestoreProvider(ctx context.Context, token string, state []byte, doClient D } for _, taskState := range providerState.TaskStates { - droplet, _, err := digitalOceanProvider.doClient.GetDroplet(ctx, taskState.ID) + droplet, err := digitalOceanProvider.doClient.GetDroplet(ctx, taskState.ID) if err != nil { return nil, fmt.Errorf("failed to get droplet for task state: %w", err) } @@ -253,11 +272,11 @@ func RestoreProvider(ctx context.Context, token string, state []byte, doClient D } if digitalOceanProvider.dockerClients[ip] == nil { - client, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort)) + dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { return nil, fmt.Errorf("failed to create docker client: %w", err) } - digitalOceanProvider.dockerClients[ip] = client + digitalOceanProvider.dockerClients[ip] = dockerClient } } @@ -301,7 +320,6 @@ func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.Tas func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) error { task.logger = p.logger.With(zap.String("task", task.state.Name)) - task.sshKeyPair = p.state.SSHKeyPair task.doClient = p.doClient task.provider = p @@ -316,11 +334,11 @@ func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) e } if p.dockerClients[ip] == nil { - client, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, sshPort)) + dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { return fmt.Errorf("failed to create docker client: %w", err) } - p.dockerClients[ip] = client + p.dockerClients[ip] = dockerClient } task.dockerClient = p.dockerClients[ip] @@ -346,55 +364,19 @@ func (p *Provider) Teardown(ctx context.Context) error { } func (p *Provider) teardownTasks(ctx context.Context) error { - res, err := p.doClient.DeleteDropletByTag(ctx, p.state.PetriTag) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil + return p.doClient.DeleteDropletByTag(ctx, p.GetState().PetriTag) } func (p *Provider) teardownFirewall(ctx context.Context) error { - res, err := p.doClient.DeleteFirewall(ctx, p.state.FirewallID) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil + return p.doClient.DeleteFirewall(ctx, p.GetState().FirewallID) } func (p *Provider) teardownSSHKey(ctx context.Context) error { - res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.state.SSHKeyPair.Fingerprint) - if err != nil { - return err - } - - if res.StatusCode > 299 || 
res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil + return p.doClient.DeleteKeyByFingerprint(ctx, p.GetState().SSHKeyPair.Fingerprint) } func (p *Provider) teardownTag(ctx context.Context) error { - res, err := p.doClient.DeleteTag(ctx, p.state.PetriTag) - if err != nil { - return err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return nil + return p.doClient.DeleteTag(ctx, p.GetState().PetriTag) } func (p *Provider) removeTask(_ context.Context, taskID int) error { @@ -405,3 +387,17 @@ func (p *Provider) removeTask(_ context.Context, taskID int) error { return nil } + +func (p *Provider) createTag(ctx context.Context, tagName string) (*godo.Tag, error) { + req := &godo.TagCreateRequest{ + Name: tagName, + } + + return p.doClient.CreateTag(ctx, req) +} + +func (p *Provider) GetState() ProviderState { + p.stateMu.Lock() + defer p.stateMu.Unlock() + return *p.state +} diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index 246573be..628e7902 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "net/http" "strings" "sync" "testing" @@ -55,10 +54,10 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC NetworkMode: container.NetworkMode("host"), }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), "test-container").Return(container.CreateResponse{ID: "test-container"}, nil) - mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) + mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) mockDockerClients := map[string]DockerClient{ "10.0.0.1": mockDocker, @@ -81,8 +80,7 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC } var callCount int - mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("TagResource", ctx, mock.Anything, mock.Anything).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, nil) mockDO.On("GetDroplet", ctx, droplet.ID).Return(func(ctx context.Context, id int) *godo.Droplet { if callCount == 0 { callCount++ @@ -100,13 +98,11 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC } } return droplet - }, func(ctx context.Context, id int) *godo.Response { - return &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}} }, func(ctx context.Context, id 
int) error { return nil }).Maybe() - mockDO.On("DeleteDropletByID", ctx, droplet.ID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusNoContent}}, nil).Once() + mockDO.On("DeleteDropletByID", ctx, droplet.ID).Return(nil).Maybe() return p, mockDO, mockDocker } @@ -141,10 +137,10 @@ func setupValidationTestProvider(t *testing.T, ctx context.Context) *Provider { mockDO := mocks.NewDoClient(t) mockDocker := mocks.NewDockerClient(t) - mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) + mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) mockDockerClients := map[string]DockerClient{ "10.0.0.1": mockDocker, @@ -227,7 +223,7 @@ func TestSerializeAndRestoreTask(t *testing.T) { }, }, }, - }, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + }, nil) deserializedTask, err := p.DeserializeTask(ctx, taskData) assert.NoError(t, err) @@ -235,18 +231,20 @@ func TestSerializeAndRestoreTask(t *testing.T) { t1 := task.(*Task) t2 := deserializedTask.(*Task) + t1State := t1.GetState() + t2State := t2.GetState() - if configMap, ok := t2.state.Definition.ProviderSpecificConfig.(map[string]interface{}); ok { + if configMap, ok := t2State.Definition.ProviderSpecificConfig.(map[string]interface{}); ok { doConfig := make(DigitalOceanTaskConfig) for k, v := range configMap { doConfig[k] = v.(string) } - t2.state.Definition.ProviderSpecificConfig = doConfig + t2State.Definition.ProviderSpecificConfig = doConfig } - assert.Equal(t, t1.state, t2.state) + assert.Equal(t, t1State, t2State) assert.NotNil(t, t2.logger) - assert.NotNil(t, t2.sshKeyPair) + assert.NotNil(t, t2State.SSHKeyPair) assert.NotNil(t, t2.doClient) assert.NotNil(t, t2.dockerClient) assert.NotNil(t, t2.provider) @@ -294,13 +292,13 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { mockDocker.On("Close").Return(nil).Once() } - mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil) - mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")). 
- Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + Return(nil, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil) require.NoError(t, err) @@ -328,20 +326,16 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { }, } - mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusCreated}}, nil).Once() + mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, nil).Once() // we cant predict how many times GetDroplet will be called exactly as the provider polls waiting for its creation - mockDO.On("GetDroplet", ctx, dropletID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe() - mockDO.On("DeleteDropletByID", ctx, dropletID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusNoContent}}, nil).Once() + mockDO.On("GetDroplet", ctx, dropletID).Return(droplet, nil).Maybe() + mockDO.On("DeleteDropletByID", ctx, dropletID).Return(nil).Once() } - mockDO.On("DeleteDropletByTag", ctx, mock.AnythingOfType("string")). - Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() - mockDO.On("DeleteFirewall", ctx, mock.AnythingOfType("string")). - Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() - mockDO.On("DeleteKeyByFingerprint", ctx, mock.AnythingOfType("string")). - Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() - mockDO.On("DeleteTag", ctx, mock.AnythingOfType("string")). 
- Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Once() + mockDO.On("DeleteDropletByTag", ctx, mock.AnythingOfType("string")).Return(nil).Once() + mockDO.On("DeleteFirewall", ctx, mock.AnythingOfType("string")).Return(nil).Once() + mockDO.On("DeleteKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil).Once() + mockDO.On("DeleteTag", ctx, mock.AnythingOfType("string")).Return(nil).Once() for i := 0; i < numTasks; i++ { wg.Add(1) @@ -403,7 +397,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { require.NoError(t, err) } - require.Equal(t, numTasks, len(p.state.TaskStates), "Provider state should contain all tasks") + require.Equal(t, numTasks, len(p.GetState().TaskStates), "Provider state should contain all tasks") var tasksToCleanup []*Task close(tasks) @@ -444,7 +438,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { } err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) { - return len(p.state.TaskStates) == 0, nil + return len(p.GetState().TaskStates) == 0, nil }) require.NoError(t, err, "Provider state should be empty after cleanup") @@ -462,10 +456,10 @@ func TestProviderSerialization(t *testing.T) { mockDO := mocks.NewDoClient(t) mockDocker := mocks.NewDockerClient(t) - mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "petri-droplet-test"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, &godo.Response{Response: &http.Response{StatusCode: http.StatusNotFound}}, nil) - mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "petri-droplet-test"}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) + mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) mockDockerClients := map[string]DockerClient{ "10.0.0.1": mockDocker, @@ -487,9 +481,8 @@ func TestProviderSerialization(t *testing.T) { Status: "active", } - mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("TagResource", ctx, mock.Anything, mock.Anything).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe() - mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil).Maybe() + mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil).Maybe() mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Maybe() mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) @@ -522,12 +515,12 @@ func TestProviderSerialization(t *testing.T) { }) require.NoError(t, err) - state1 := p1.state + state1 := p1.GetState() serialized, err := p1.SerializeProvider(ctx) require.NoError(t, err) mockDO2 := mocks.NewDoClient(t) - mockDO2.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: 
&http.Response{StatusCode: http.StatusOK}}, nil).Maybe() + mockDO2.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil).Maybe() mockDocker2 := mocks.NewDockerClient(t) mockDocker2.On("Ping", ctx).Return(types.Ping{}, nil).Maybe() @@ -539,7 +532,7 @@ func TestProviderSerialization(t *testing.T) { p2, err := RestoreProvider(ctx, "test-token", serialized, mockDO2, mockDockerClients2) require.NoError(t, err) - state2 := p2.state + state2 := p2.GetState() assert.Equal(t, state1.Name, state2.Name) assert.Equal(t, state1.PetriTag, state2.PetriTag) assert.Equal(t, state1.FirewallID, state2.FirewallID) diff --git a/core/provider/digitalocean/ssh.go b/core/provider/digitalocean/ssh.go index 78d34909..57af65f7 100644 --- a/core/provider/digitalocean/ssh.go +++ b/core/provider/digitalocean/ssh.go @@ -96,16 +96,6 @@ func getUserIPs(ctx context.Context) (ips []string, err error) { } func (p *Provider) createSSHKey(ctx context.Context, pubKey string) (*godo.Key, error) { - req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.state.PetriTag)} - - key, res, err := p.doClient.CreateKey(ctx, req) - if err != nil { - return nil, err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return key, nil + req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.GetState().PetriTag)} + return p.doClient.CreateKey(ctx, req) } diff --git a/core/provider/digitalocean/tag.go b/core/provider/digitalocean/tag.go deleted file mode 100644 index 9d8654fb..00000000 --- a/core/provider/digitalocean/tag.go +++ /dev/null @@ -1,25 +0,0 @@ -package digitalocean - -import ( - "context" - "fmt" - - "github.com/digitalocean/godo" -) - -func (p *Provider) createTag(ctx context.Context, tagName string) (*godo.Tag, error) { - req := &godo.TagCreateRequest{ - Name: tagName, - } - - tag, res, err := p.doClient.CreateTag(ctx, req) - if err != nil { - return nil, err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return tag, nil -} diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index 24f3eddc..fc0e9034 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -33,6 +33,7 @@ type TaskState struct { Definition provider.TaskDefinition `json:"definition"` Status provider.TaskStatus `json:"status"` ProviderName string `json:"provider_name"` + SSHKeyPair *SSHKeyPair `json:"ssh_key_pair"` } type Task struct { @@ -41,7 +42,6 @@ type Task struct { provider *Provider logger *zap.Logger - sshKeyPair *SSHKeyPair sshClient *ssh.Client doClient DoClient dockerClient DockerClient @@ -59,7 +59,7 @@ func (t *Task) Start(ctx context.Context) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", t.state.Name) + return fmt.Errorf("could not find container for %s", t.GetState().Name) } containerID := containers[0].ID @@ -75,18 +75,18 @@ func (t *Task) Start(ctx context.Context) error { return false, err } - if status == provider.TASK_RUNNING { - t.stateMu.Lock() - defer t.stateMu.Unlock() - - t.state.Status = provider.TASK_RUNNING - return true, nil + if status != provider.TASK_RUNNING { + return false, nil } - return false, nil + t.stateMu.Lock() + defer t.stateMu.Unlock() + + t.state.Status = provider.TASK_RUNNING + return true, nil }) - t.logger.Info("Final task status after start", zap.Any("status", t.state.Status)) + 
t.logger.Info("final task status after start", zap.Any("status", t.GetState().Status)) return err } @@ -100,7 +100,7 @@ func (t *Task) Stop(ctx context.Context) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", t.state.Name) + return fmt.Errorf("could not find container for %s", t.GetState().Name) } t.stateMu.Lock() @@ -119,7 +119,7 @@ func (t *Task) Modify(ctx context.Context, definition provider.TaskDefinition) e } func (t *Task) Destroy(ctx context.Context) error { - logger := t.logger.With(zap.String("task", t.state.Name)) + logger := t.logger.With(zap.String("task", t.GetState().Name)) logger.Info("deleting task") defer t.dockerClient.Close() @@ -129,7 +129,7 @@ func (t *Task) Destroy(ctx context.Context) error { } // TODO(nadim-az): remove reference to provider in Task struct - if err := t.provider.removeTask(ctx, t.state.ID); err != nil { + if err := t.provider.removeTask(ctx, t.GetState().ID); err != nil { return err } return nil @@ -164,7 +164,7 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { } if len(containers) != 1 { - return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", t.state.Name) + return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", t.GetState().Name) } c, err := t.dockerClient.ContainerInspect(ctx, containers[0].ID) @@ -195,7 +195,7 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) error { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := t.getDropletSSHClient(ctx, t.state.Name) + sshClient, err := t.getDropletSSHClient(ctx, t.GetState().Name) if err != nil { return err } @@ -230,7 +230,7 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := t.getDropletSSHClient(ctx, t.state.Name) + sshClient, err := t.getDropletSSHClient(ctx, t.GetState().Name) if err != nil { return nil, err } @@ -290,6 +290,34 @@ func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, in return t.runCommand(ctx, cmd) } +func waitForExec(ctx context.Context, dockerClient DockerClient, execID string) (int, error) { + lastExitCode := 0 + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + +loop: + for { + select { + case <-ctx.Done(): + return lastExitCode, ctx.Err() + case <-ticker.C: + execInspect, err := dockerClient.ContainerExecInspect(ctx, execID) + if err != nil { + return lastExitCode, err + } + + if execInspect.Running { + continue + } + + lastExitCode = execInspect.ExitCode + break loop + } + } + + return lastExitCode, nil +} + func (t *Task) runCommand(ctx context.Context, command []string) (string, string, int, error) { containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, @@ -300,7 +328,7 @@ func (t *Task) runCommand(ctx context.Context, command []string) (string, string } if len(containers) != 1 { - return "", "", 0, fmt.Errorf("could not find container for %s", t.state.Name) + return "", "", 0, fmt.Errorf("could not find container for %s", t.GetState().Name) } id := containers[0].ID @@ -323,80 +351,53 @@ func (t *Task) runCommand(ctx context.Context, command []string) (string, string defer resp.Close() - lastExitCode := 0 - - ticker := time.NewTicker(100 * time.Millisecond) - 
-loop: - for { - select { - case <-ctx.Done(): - return "", "", lastExitCode, ctx.Err() - case <-ticker.C: - execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID) - if err != nil { - return "", "", lastExitCode, err - } - - if execInspect.Running { - continue - } - - lastExitCode = execInspect.ExitCode - - break loop - } - } - var stdout, stderr bytes.Buffer - _, err = stdcopy.StdCopy(&stdout, &stderr, resp.Reader) if err != nil { - return "", "", lastExitCode, err + return "", "", 0, err + } + + exitCode, err := waitForExec(ctx, t.dockerClient, exec.ID) + if err != nil { + return stdout.String(), stderr.String(), exitCode, err } - return stdout.String(), stderr.String(), lastExitCode, nil + return stdout.String(), stderr.String(), exitCode, nil } func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { - if err := t.state.Definition.ValidateBasic(); err != nil { + state := t.GetState() + if err := state.Definition.ValidateBasic(); err != nil { return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } - t.stateMu.Lock() - defer t.stateMu.Unlock() - - t.state.Definition.Entrypoint = []string{"sh", "-c"} - t.state.Definition.Command = []string{"sleep 36000"} - t.state.Definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", t.state.Definition.Name, util.RandomString(5), time.Now().Unix()) - t.state.Definition.Ports = []string{} - + containerName := fmt.Sprintf("%s-executor-%s-%d", state.Definition.Name, util.RandomString(5), time.Now().Unix()) createdContainer, err := t.dockerClient.ContainerCreate(ctx, &container.Config{ - Image: t.state.Definition.Image.Image, - Entrypoint: t.state.Definition.Entrypoint, - Cmd: t.state.Definition.Command, + Image: state.Definition.Image.Image, + Entrypoint: []string{"sh", "-c"}, + Cmd: []string{"sleep 36000"}, Tty: false, - Hostname: t.state.Definition.Name, + Hostname: state.Definition.Name, Labels: map[string]string{ - providerLabelName: t.state.ProviderName, + providerLabelName: state.ProviderName, }, - Env: convertEnvMapToList(t.state.Definition.Environment), + Env: convertEnvMapToList(state.Definition.Environment), }, &container.HostConfig{ Mounts: []mount.Mount{ { Type: mount.TypeBind, Source: "/docker_volumes", - Target: t.state.Definition.DataDir, + Target: state.Definition.DataDir, }, }, NetworkMode: container.NetworkMode("host"), - }, nil, nil, t.state.Definition.ContainerName) + }, nil, nil, containerName) if err != nil { - t.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", t.state.Name)) + t.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } - t.logger.Debug("container created successfully", zap.String("id", createdContainer.ID), zap.String("taskName", t.state.Name)) + t.logger.Debug("container created successfully", zap.String("id", createdContainer.ID), zap.String("taskName", state.Name)) defer func() { if _, err := t.dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { @@ -405,16 +406,16 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string } if err := t.dockerClient.ContainerRemove(ctx, createdContainer.ID, container.RemoveOptions{Force: true}); err != nil { - t.logger.Error("failed to remove container", zap.Error(err), zap.String("taskName", t.state.Name), zap.String("id", createdContainer.ID)) + t.logger.Error("failed to remove container", zap.Error(err), 
zap.String("taskName", state.Name), zap.String("id", createdContainer.ID)) } }() if err := startContainerWithBlock(ctx, t.dockerClient, createdContainer.ID); err != nil { - t.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", t.state.Name)) + t.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } - t.logger.Debug("container started successfully", zap.String("id", createdContainer.ID), zap.String("taskName", t.state.Name)) + t.logger.Debug("container started successfully", zap.String("id", createdContainer.ID), zap.String("taskName", state.Name)) // wait for container start exec, err := t.dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{ @@ -423,50 +424,30 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string Cmd: cmd, }) if err != nil { - t.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", t.state.Name)) + t.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { - t.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", t.state.Name)) + t.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } defer resp.Close() - lastExitCode := 0 - - ticker := time.NewTicker(100 * time.Millisecond) - -loop: - for { - select { - case <-ctx.Done(): - return "", "", lastExitCode, ctx.Err() - case <-ticker.C: - execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID) - if err != nil { - return "", "", lastExitCode, err - } - - if execInspect.Running { - continue - } - - lastExitCode = execInspect.ExitCode - - break loop - } - } - var stdout, stderr bytes.Buffer _, err = stdcopy.StdCopy(&stdout, &stderr, resp.Reader) if err != nil { return "", "", 0, err } - return stdout.String(), stderr.String(), lastExitCode, err + exitCode, err := waitForExec(ctx, t.dockerClient, exec.ID) + if err != nil { + return stdout.String(), stderr.String(), exitCode, err + } + + return stdout.String(), stderr.String(), exitCode, nil } func startContainerWithBlock(ctx context.Context, dockerClient DockerClient, containerID string) error { diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go index 73b39d19..fdc2fbc8 100644 --- a/core/provider/digitalocean/task_test.go +++ b/core/provider/digitalocean/task_test.go @@ -71,7 +71,7 @@ func TestTaskLifecycle(t *testing.T) { }, } - mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -131,7 +131,7 @@ func TestTaskRunCommand(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -187,6 +187,25 @@ func TestTaskRunCommand(t *testing.T) { require.Equal(t, 0, exitCode) require.Empty(t, stderr) + // Start command assertions + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + 
}).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + err = task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + mockDocker.AssertExpectations(t) mockDO.AssertExpectations(t) } @@ -232,7 +251,7 @@ func TestTaskRunCommandWhileStopped(t *testing.T) { }, }, nil).Twice() - mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) execCreateResp := types.IDResponse{ID: "test-exec-id"} mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ @@ -286,7 +305,27 @@ func TestTaskRunCommandWhileStopped(t *testing.T) { require.Equal(t, 0, exitCode) require.Empty(t, stderr) + // Start command assertions + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + err = task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) } func TestTaskGetIP(t *testing.T) { @@ -310,7 +349,7 @@ func TestTaskGetIP(t *testing.T) { }, } - mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil) task := &Task{ state: &TaskState{ @@ -341,8 +380,8 @@ func TestTaskDestroy(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - mockDO.On("GetDroplet", ctx, testDroplet.ID).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - mockDO.On("DeleteDropletByID", ctx, testDroplet.ID).Return(&godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, testDroplet.ID).Return(testDroplet, nil) + mockDO.On("DeleteDropletByID", ctx, testDroplet.ID).Return(nil) mockDocker.On("Close").Return(nil) provider := &Provider{ @@ -363,7 +402,7 @@ func TestTaskDestroy(t *testing.T) { provider: provider, } - provider.state.TaskStates[task.state.ID] = task.state + provider.state.TaskStates[task.GetState().ID] = task.state err := task.Destroy(ctx) require.NoError(t, err) @@ -380,7 +419,7 @@ func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -487,7 +526,7 @@ func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) { mockDocker := mocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) - mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, &godo.Response{Response: 
&http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -604,7 +643,7 @@ func TestTaskExposingPort(t *testing.T) { }, } - mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil) testContainer := types.Container{ ID: testContainerID, @@ -696,8 +735,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "off", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletOff.ID).Return(testDropletOff, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, testDropletOff.ID).Return(testDropletOff, nil) }, expectedStatus: provider.TASK_STOPPED, expectError: false, @@ -707,9 +745,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "running", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -729,8 +765,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "paused", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -750,9 +785,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "exited", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -772,9 +805,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "removing", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -794,9 +825,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "dead", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) 
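// Note the mock shape across these TestGetStatus cases: GetStatus checks the
// droplet first, so the dropletStatus: "off" case above stubs only GetDroplet
// (the Docker client is never reached and TASK_STOPPED is returned), while
// every "active" case also stubs ContainerList/ContainerInspect to drive the
// container-state-to-TaskStatus mapping.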
mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -816,9 +845,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "created", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -838,9 +865,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "unknown_status", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -860,7 +885,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{}, nil) @@ -873,9 +898,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) - + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return([]types.Container{testContainer}, nil) @@ -889,7 +912,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - mockDO.On("GetDroplet", ctx, 123).Return(nil, nil, fmt.Errorf("failed to get droplet")) + mockDO.On("GetDroplet", ctx, 123).Return(nil, fmt.Errorf("failed to get droplet")) }, expectedStatus: provider.TASK_STATUS_UNDEFINED, expectError: true, @@ -899,8 +922,7 @@ func TestGetStatus(t *testing.T) { dropletStatus: "active", containerState: "", setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { - - mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, &godo.Response{Response: &http.Response{StatusCode: http.StatusOK}}, nil) + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, }).Return(nil, fmt.Errorf("failed to list containers")) From 9cd4c3a54c1a1a60297c8095bd038904c6d63ffb Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 00:13:53 +0200 Subject: [PATCH 20/50] use zap fields in log statement --- core/provider/digitalocean/droplet.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 35a6e8bd..77df050d 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -98,7 +98,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe end := time.Now() - p.logger.Info(fmt.Sprintf("droplet %s is ready after %s", droplet.Name, end.Sub(start))) + p.logger.Info("droplet is ready", zap.String("droplet_name", droplet.Name), zap.Duration("startup_time", end.Sub(start))) return droplet, nil } From 5e34abe1b724d0e62b3ad0c2ab537ba0420365b8 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 00:21:15 +0200 Subject: [PATCH 21/50] reorder import --- core/provider/digitalocean/examples/digitalocean_simapp.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/provider/digitalocean/examples/digitalocean_simapp.go b/core/provider/digitalocean/examples/digitalocean_simapp.go index cc00c298..766c065a 100644 --- a/core/provider/digitalocean/examples/digitalocean_simapp.go +++ b/core/provider/digitalocean/examples/digitalocean_simapp.go @@ -2,10 +2,11 @@ package main import ( "context" + "os" + "github.com/cosmos/cosmos-sdk/crypto/hd" "github.com/skip-mev/petri/cosmos/v2/chain" "github.com/skip-mev/petri/cosmos/v2/node" - "os" "github.com/skip-mev/petri/core/v2/provider" "github.com/skip-mev/petri/core/v2/provider/digitalocean" From bd3d7d8ff691984ca83ff039ff2d0fcc24fd3b23 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 00:27:08 +0200 Subject: [PATCH 22/50] refactor: simplify SSH key pair initialization --- core/provider/digitalocean/provider.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 087a3ca7..f0bde274 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -55,12 +55,12 @@ func NewProvider(ctx context.Context, logger *zap.Logger, providerName string, t // NewProviderWithClient creates a provider with custom digitalocean/docker client implementation. // This is primarily used for testing. 
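// For instance, a test might wire it up with the generated mocks (a hedged
// sketch: zap.NewNop, the "test-provider" name, and the "10.0.0.1" docker
// client key mirror what the tests in this series use and are not part of
// the API itself):
//
//	mockDO := mocks.NewDoClient(t)
//	mockDocker := mocks.NewDockerClient(t)
//	dockerClients := map[string]DockerClient{"10.0.0.1": mockDocker}
//	p, err := NewProviderWithClient(ctx, zap.NewNop(), "test-provider", mockDO, dockerClients, nil, nil)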
func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { + var err error if sshKeyPair == nil { - newSshKeyPair, err := MakeSSHKeyPair() + sshKeyPair, err = MakeSSHKeyPair() if err != nil { return nil, err } - sshKeyPair = newSshKeyPair } userIPs, err := getUserIPs(ctx) From 0845905a7b4035e12bd848a30924453cd98cb6d5 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 02:02:21 +0200 Subject: [PATCH 23/50] use getState everywhere --- core/provider/digitalocean/provider.go | 24 ++++++++++++++---------- core/provider/digitalocean/task_test.go | 5 +++-- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index f0bde274..90fd3f87 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -74,20 +74,14 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName dockerClients = make(map[string]DockerClient) } + petriTag := fmt.Sprintf("petri-droplet-%s", util.RandomString(5)) digitalOceanProvider := &Provider{ logger: logger.Named("digitalocean_provider"), doClient: doClient, dockerClients: dockerClients, - state: &ProviderState{ - TaskStates: make(map[int]*TaskState), - UserIPs: userIPs, - Name: providerName, - SSHKeyPair: sshKeyPair, - PetriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)), - }, } - _, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.state.PetriTag) + _, err = digitalOceanProvider.createTag(ctx, petriTag) if err != nil { return nil, err } @@ -97,7 +91,16 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName return nil, fmt.Errorf("failed to create firewall: %w", err) } - digitalOceanProvider.state.FirewallID = firewall.ID + pState := &ProviderState{ + TaskStates: make(map[int]*TaskState), + UserIPs: userIPs, + Name: providerName, + SSHKeyPair: sshKeyPair, + PetriTag: petriTag, + FirewallID: firewall.ID, + } + + digitalOceanProvider.state = pState //TODO(Zygimantass): TOCTOU issue if key, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil { @@ -319,7 +322,8 @@ func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.Tas } func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) error { - task.logger = p.logger.With(zap.String("task", task.state.Name)) + taskState := task.GetState() + task.logger = p.logger.With(zap.String("task", taskState.Name)) task.doClient = p.doClient task.provider = p diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go index fdc2fbc8..68b28dee 100644 --- a/core/provider/digitalocean/task_test.go +++ b/core/provider/digitalocean/task_test.go @@ -389,6 +389,7 @@ func TestTaskDestroy(t *testing.T) { TaskStates: make(map[int]*TaskState), }, } + providerState := provider.GetState() task := &Task{ state: &TaskState{ @@ -402,11 +403,11 @@ func TestTaskDestroy(t *testing.T) { provider: provider, } - provider.state.TaskStates[task.GetState().ID] = task.state + providerState.TaskStates[task.GetState().ID] = task.state err := task.Destroy(ctx) require.NoError(t, err) - require.Empty(t, provider.state.TaskStates) + require.Empty(t, providerState.TaskStates) mockDocker.AssertExpectations(t) mockDO.AssertExpectations(t) From 
a1198183ea8e3ad87cf4a8e032c29929e38b6399 Mon Sep 17 00:00:00 2001 From: Zygimantas <zygis@skip.build> Date: Tue, 21 Jan 2025 13:42:50 +0100 Subject: [PATCH 24/50] fix: init provider state before using it in creating the firewall --- core/provider/digitalocean/provider.go | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 90fd3f87..dd33f38d 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -81,6 +81,16 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName dockerClients: dockerClients, } + pState := &ProviderState{ + TaskStates: make(map[int]*TaskState), + UserIPs: userIPs, + Name: providerName, + SSHKeyPair: sshKeyPair, + PetriTag: petriTag, + } + + digitalOceanProvider.state = pState + _, err = digitalOceanProvider.createTag(ctx, petriTag) if err != nil { return nil, err @@ -91,16 +101,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName return nil, fmt.Errorf("failed to create firewall: %w", err) } - pState := &ProviderState{ - TaskStates: make(map[int]*TaskState), - UserIPs: userIPs, - Name: providerName, - SSHKeyPair: sshKeyPair, - PetriTag: petriTag, - FirewallID: firewall.ID, - } - - digitalOceanProvider.state = pState + pState.FirewallID = firewall.ID //TODO(Zygimantass): TOCTOU issue if key, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil { From d6c2a1efb10990397186b8b42de3cc95c305733a Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 17:52:00 +0200 Subject: [PATCH 25/50] droplet id in task states should be a string --- core/provider/digitalocean/droplet.go | 6 ++++- core/provider/digitalocean/provider.go | 26 +++++++++++++-------- core/provider/digitalocean/provider_test.go | 4 ++-- core/provider/digitalocean/task.go | 2 +- core/provider/digitalocean/task_test.go | 21 +++++++++-------- 5 files changed, 35 insertions(+), 24 deletions(-) diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 77df050d..01d0cd88 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -113,7 +113,11 @@ func (t *Task) deleteDroplet(ctx context.Context) error { } func (t *Task) getDroplet(ctx context.Context) (*godo.Droplet, error) { - return t.doClient.GetDroplet(ctx, t.GetState().ID) + dropletId, err := strconv.Atoi(t.GetState().ID) + if err != nil { + return nil, err + } + return t.doClient.GetDroplet(ctx, dropletId) } func (t *Task) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) { diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index dd33f38d..12fa8529 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "strconv" "strings" "sync" "time" @@ -28,12 +29,12 @@ const ( ) type ProviderState struct { - TaskStates map[int]*TaskState `json:"task_states"` // map of task ids to the corresponding task state - Name string `json:"name"` - PetriTag string `json:"petri_tag"` - UserIPs []string `json:"user_ips"` - SSHKeyPair *SSHKeyPair `json:"ssh_key_pair"` - FirewallID string `json:"firewall_id"` + TaskStates map[string]*TaskState `json:"task_states"` // map of task ids to the corresponding task state + Name string `json:"name"` + 
PetriTag string `json:"petri_tag"` + UserIPs []string `json:"user_ips"` + SSHKeyPair *SSHKeyPair `json:"ssh_key_pair"` + FirewallID string `json:"firewall_id"` } type Provider struct { @@ -82,7 +83,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName } pState := &ProviderState{ - TaskStates: make(map[int]*TaskState), + TaskStates: make(map[string]*TaskState), UserIPs: userIPs, Name: providerName, SSHKeyPair: sshKeyPair, @@ -206,7 +207,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin } taskState := &TaskState{ - ID: droplet.ID, + ID: strconv.Itoa(droplet.ID), Name: definition.Name, Definition: definition, Status: provider.TASK_STOPPED, @@ -265,7 +266,12 @@ func RestoreProvider(ctx context.Context, token string, state []byte, doClient D } for _, taskState := range providerState.TaskStates { - droplet, err := digitalOceanProvider.doClient.GetDroplet(ctx, taskState.ID) + id, err := strconv.Atoi(taskState.ID) + if err != nil { + return nil, err + } + + droplet, err := digitalOceanProvider.doClient.GetDroplet(ctx, id) if err != nil { return nil, fmt.Errorf("failed to get droplet for task state: %w", err) } @@ -384,7 +390,7 @@ func (p *Provider) teardownTag(ctx context.Context) error { return p.doClient.DeleteTag(ctx, p.GetState().PetriTag) } -func (p *Provider) removeTask(_ context.Context, taskID int) error { +func (p *Provider) removeTask(_ context.Context, taskID string) error { p.stateMu.Lock() defer p.stateMu.Unlock() diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index 628e7902..bfe04c08 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -308,7 +308,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { errors := make(chan error, numTasks) tasks := make(chan *Task, numTasks) taskMutex := sync.Mutex{} - dropletIDs := make(map[int]bool) + dropletIDs := make(map[string]bool) ipAddresses := make(map[string]bool) for i := 0; i < numTasks; i++ { @@ -373,7 +373,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { state := doTask.GetState() if dropletIDs[state.ID] { - errors <- fmt.Errorf("duplicate droplet ID found: %d", state.ID) + errors <- fmt.Errorf("duplicate droplet ID found: %s", state.ID) } dropletIDs[state.ID] = true diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index fc0e9034..156e887a 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -28,7 +28,7 @@ import ( ) type TaskState struct { - ID int `json:"id"` + ID string `json:"id"` Name string `json:"name"` Definition provider.TaskDefinition `json:"definition"` Status provider.TaskStatus `json:"status"` diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go index 68b28dee..b0b61098 100644 --- a/core/provider/digitalocean/task_test.go +++ b/core/provider/digitalocean/task_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "net/http" + "strconv" "testing" "time" @@ -90,7 +91,7 @@ func TestTaskLifecycle(t *testing.T) { task := &Task{ state: &TaskState{ - ID: droplet.ID, + ID: strconv.Itoa(droplet.ID), Name: "test-task", ProviderName: "test-provider", Definition: provider.TaskDefinition{ @@ -165,7 +166,7 @@ func TestTaskRunCommand(t *testing.T) { task := &Task{ state: &TaskState{ - ID: 1, + ID: strconv.Itoa(1), Name: "test-task", ProviderName: "test-provider", Definition: provider.TaskDefinition{ @@ -282,7 +283,7 @@ func 
TestTaskRunCommandWhileStopped(t *testing.T) { task := &Task{ state: &TaskState{ - ID: 1, + ID: strconv.Itoa(1), Name: "test-task", ProviderName: "test-provider", Definition: provider.TaskDefinition{ @@ -353,7 +354,7 @@ func TestTaskGetIP(t *testing.T) { task := &Task{ state: &TaskState{ - ID: droplet.ID, + ID: strconv.Itoa(droplet.ID), Name: "test-task", ProviderName: "test-provider", }, @@ -386,14 +387,14 @@ func TestTaskDestroy(t *testing.T) { provider := &Provider{ state: &ProviderState{ - TaskStates: make(map[int]*TaskState), + TaskStates: make(map[string]*TaskState), }, } providerState := provider.GetState() task := &Task{ state: &TaskState{ - ID: testDroplet.ID, + ID: strconv.Itoa(testDroplet.ID), Name: "test-task", ProviderName: "test-provider", }, @@ -492,7 +493,7 @@ func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { task := &Task{ state: &TaskState{ - ID: 1, + ID: strconv.Itoa(1), Name: "test-task", ProviderName: "test-provider", Definition: provider.TaskDefinition{ @@ -597,7 +598,7 @@ func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) { task := &Task{ state: &TaskState{ - ID: 1, + ID: strconv.Itoa(1), Name: "test-task", ProviderName: "test-provider", Definition: provider.TaskDefinition{ @@ -672,7 +673,7 @@ func TestTaskExposingPort(t *testing.T) { task := &Task{ state: &TaskState{ - ID: droplet.ID, + ID: strconv.Itoa(droplet.ID), Name: "test-task", ProviderName: "test-provider", Definition: provider.TaskDefinition{ @@ -942,7 +943,7 @@ func TestGetStatus(t *testing.T) { task := &Task{ state: &TaskState{ - ID: 123, + ID: strconv.Itoa(123), Name: "test-task", ProviderName: "test-provider", Definition: provider.TaskDefinition{ From f1ab792b9e7e149980482efcf5b870b3c2d1347e Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 18:39:12 +0200 Subject: [PATCH 26/50] get public ip address through ifconfig instead of user input --- .../examples/digitalocean_simapp.go | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/core/provider/digitalocean/examples/digitalocean_simapp.go b/core/provider/digitalocean/examples/digitalocean_simapp.go index 766c065a..b1d45971 100644 --- a/core/provider/digitalocean/examples/digitalocean_simapp.go +++ b/core/provider/digitalocean/examples/digitalocean_simapp.go @@ -2,9 +2,13 @@ package main import ( "context" + "io" + "net/http" "os" + "strings" "github.com/cosmos/cosmos-sdk/crypto/hd" + "github.com/skip-mev/petri/cosmos/v2/chain" "github.com/skip-mev/petri/cosmos/v2/node" @@ -33,8 +37,10 @@ func main() { logger.Fatal("failed to create SSH key pair", zap.Error(err)) } - // Add your IP address below - doProvider, err := digitalocean.NewProvider(ctx, logger, "cosmos-hub", doToken, []string{"INSERT IP ADDRESS HERE"}, sshKeyPair) + externalIP, err := getExternalIP() + logger.Info("External IP", zap.String("address", externalIP)) + + doProvider, err := digitalocean.NewProvider(ctx, logger, "cosmos-hub", doToken, []string{externalIP}, sshKeyPair) if err != nil { logger.Fatal("failed to create DigitalOcean provider", zap.Error(err)) } @@ -106,3 +112,17 @@ func main() { logger.Info("All Digital Ocean resources created have been successfully deleted!") } + +func getExternalIP() (string, error) { + resp, err := http.Get("https://ifconfig.me") + if err != nil { + return "", err + } + defer resp.Body.Close() + + ip, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return strings.TrimSpace(string(ip)), nil +} From 
ac0c8dc23dfe6431db096d5c0685e1f3b7698936 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 21:27:54 +0200 Subject: [PATCH 27/50] missing logger in restorechain --- cosmos/chain/chain.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cosmos/chain/chain.go b/cosmos/chain/chain.go index 4f00e5c4..3e917b82 100644 --- a/cosmos/chain/chain.go +++ b/cosmos/chain/chain.go @@ -150,7 +150,8 @@ func RestoreChain(ctx context.Context, logger *zap.Logger, infraProvider provide } chain := Chain{ - State: packagedState.State, + State: packagedState.State, + logger: logger, } for _, vs := range packagedState.ValidatorStates { From 9737a09496a07dbc924666fe7a168715d14ce743 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 22:38:39 +0200 Subject: [PATCH 28/50] test: add e2e test --- core/tests/e2e/digitalocean/do_test.go | 334 +++++++++++++++++++++++++ core/tests/e2e/docker/docker_test.go | 286 +++++++++++++++++++++ 2 files changed, 620 insertions(+) create mode 100644 core/tests/e2e/digitalocean/do_test.go create mode 100644 core/tests/e2e/docker/docker_test.go diff --git a/core/tests/e2e/digitalocean/do_test.go b/core/tests/e2e/digitalocean/do_test.go new file mode 100644 index 00000000..88d63879 --- /dev/null +++ b/core/tests/e2e/digitalocean/do_test.go @@ -0,0 +1,334 @@ +package e2e + +import ( + "context" + "flag" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + "testing" + "time" + + "github.com/skip-mev/petri/cosmos/v2/node" + + "github.com/cosmos/cosmos-sdk/crypto/hd" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/client" + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/provider/digitalocean" + "github.com/skip-mev/petri/core/v2/types" + cosmoschain "github.com/skip-mev/petri/cosmos/v2/chain" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +var ( + defaultChainConfig = types.ChainConfig{ + Denom: "stake", + Decimals: 6, + NumValidators: 1, + NumNodes: 1, + BinaryName: "/usr/bin/simd", + Image: provider.ImageDefinition{ + Image: "interchainio/simapp:latest", + UID: "1000", + GID: "1000", + }, + GasPrices: "0.0005stake", + Bech32Prefix: "cosmos", + HomeDir: "/gaia", + CoinType: "118", + ChainId: "stake-1", + } + + defaultChainOptions = types.ChainOptions{ + NodeCreator: node.CreateNode, + WalletConfig: types.WalletConfig{ + SigningAlgorithm: string(hd.Secp256k1.Name()), + Bech32Prefix: "cosmos", + HDPath: hd.CreateHDPath(118, 0, 0), + DerivationFn: hd.Secp256k1.Derive(), + GenerationFn: hd.Secp256k1.Generate(), + }, + NodeOptions: types.NodeOptions{ + NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig types.NodeConfig) provider.TaskDefinition { + doConfig := digitalocean.DigitalOceanTaskConfig{ + "size": "s-2vcpu-4gb", + "region": "ams3", + "image_id": os.Getenv("DO_IMAGE_ID"), + } + def.ProviderSpecificConfig = doConfig + return def + }, + }, + } + + numTestChains = flag.Int("num-chains", 3, "number of chains to create for concurrent testing") + numNodes = flag.Int("num-nodes", 1, "number of nodes per chain") + numValidators = flag.Int("num-validators", 1, "number of validators per chain") +) + +func getExternalIP() (string, error) { + resp, err := http.Get("https://ifconfig.me") + if err != nil { + return "", err + } + defer resp.Body.Close() + + ip, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return strings.TrimSpace(string(ip)), nil +} + +func 
TestDOE2E(t *testing.T) { + if !flag.Parsed() { + flag.Parse() + } + + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + doToken := os.Getenv("DO_API_TOKEN") + if doToken == "" { + logger.Fatal("DO_API_TOKEN environment variable not set") + } + + imageID := os.Getenv("DO_IMAGE_ID") + if imageID == "" { + logger.Fatal("DO_IMAGE_ID environment variable not set") + } + + externalIP, err := getExternalIP() + logger.Info("External IP", zap.String("address", externalIP)) + require.NoError(t, err) + + p, err := digitalocean.NewProvider(ctx, logger, "digitalocean_provider", doToken, []string{externalIP}, nil) + require.NoError(t, err) + + defer func() { + dockerClient, err := client.NewClientWithOpts() + if err != nil { + t.Logf("Failed to create Docker client for volume cleanup: %v", err) + return + } + _, err = dockerClient.VolumesPrune(ctx, filters.Args{}) + if err != nil { + t.Logf("Failed to prune volumes: %v", err) + } + }() + + var wg sync.WaitGroup + chainErrors := make(chan error, *numTestChains*2) + chains := make([]*cosmoschain.Chain, *numTestChains) + + // Create first half of chains + for i := 0; i < *numTestChains/2; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + chainConfig := defaultChainConfig + chainConfig.ChainId = fmt.Sprintf("chain-%d", index) + chainConfig.NumNodes = *numNodes + chainConfig.NumValidators = *numValidators + c, err := cosmoschain.CreateChain(ctx, logger, p, chainConfig, defaultChainOptions) + if err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) + return + } + if err := c.Init(ctx, defaultChainOptions); err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) + return + } + chains[index] = c + }(i) + } + wg.Wait() + require.Empty(t, chainErrors) + + serializedProvider, err := p.SerializeProvider(ctx) + require.NoError(t, err) + restoredProvider, err := digitalocean.RestoreProvider(ctx, doToken, serializedProvider, nil, nil) + require.NoError(t, err) + + // Create second half of chains with restored provider + for i := *numTestChains / 2; i < *numTestChains; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + chainConfig := defaultChainConfig + chainConfig.ChainId = fmt.Sprintf("chain-%d", index) + chainConfig.NumNodes = *numNodes + chainConfig.NumValidators = *numValidators + c, err := cosmoschain.CreateChain(ctx, logger, restoredProvider, chainConfig, defaultChainOptions) + if err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) + return + } + if err := c.Init(ctx, defaultChainOptions); err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) + return + } + chains[index] = c + }(i) + } + wg.Wait() + require.Empty(t, chainErrors) + + // Serialize and restore all chains with the restored provider + restoredChains := make([]*cosmoschain.Chain, *numTestChains) + for i := 0; i < *numTestChains; i++ { + chainState, err := chains[i].Serialize(ctx, restoredProvider) + require.NoError(t, err) + + restoredChain, err := cosmoschain.RestoreChain(ctx, logger, restoredProvider, chainState, node.RestoreNode) + require.NoError(t, err) + + require.Equal(t, chains[i].GetConfig(), restoredChain.GetConfig()) + require.Equal(t, len(chains[i].GetValidators()), len(restoredChain.GetValidators())) + + restoredChains[i] = restoredChain + } + + // Test and 
teardown half the chains individually + for i := 0; i < *numTestChains/2; i++ { + originalChain := chains[i] + validators := originalChain.GetValidators() + nodes := originalChain.GetNodes() + + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := validator.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + + testFile := "test.txt" + testContent := []byte("test content") + err = validator.WriteFile(ctx, testFile, testContent) + require.NoError(t, err) + + readContent, err := validator.ReadFile(ctx, testFile) + require.NoError(t, err) + require.Equal(t, testContent, readContent) + } + + for _, node := range nodes { + status, err := node.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := node.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + } + + err = originalChain.WaitForBlocks(ctx, 2) + require.NoError(t, err) + + // Test individual chain teardown + err = originalChain.Teardown(ctx) + require.NoError(t, err) + + // wait for task statuses to update on DO client side + time.Sleep(30 * time.Second) + + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + logger.Info("validator status", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as droplet isn't available") + + _, err = validator.GetIP(ctx) + require.Error(t, err, "validator IP should not be accessible after teardown") + } + + for _, node := range nodes { + status, err := node.GetStatus(ctx) + logger.Info("node status", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as droplet isn't available") + + _, err = node.GetIP(ctx) + require.Error(t, err, "node IP should not be accessible after teardown") + } + } + + // Test the remaining chains but let the provider teardown handle their cleanup + remainingChains := make([]*cosmoschain.Chain, 0) + for i := *numTestChains / 2; i < *numTestChains; i++ { + originalChain := chains[i] + remainingChains = append(remainingChains, originalChain) + validators := originalChain.GetValidators() + nodes := originalChain.GetNodes() + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := validator.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + + testFile := "test.txt" + testContent := []byte("test content") + err = validator.WriteFile(ctx, testFile, testContent) + require.NoError(t, err) + + readContent, err := validator.ReadFile(ctx, testFile) + require.NoError(t, err) + require.Equal(t, testContent, readContent) + } + for _, node := range nodes { + status, err := node.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := node.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + } + + err = originalChain.WaitForBlocks(ctx, 2) + require.NoError(t, err) + } + + require.NoError(t, restoredProvider.Teardown(ctx)) + time.Sleep(30 * time.Second) + + // Verify all remaining chains are properly torn down + for _, chain := range remainingChains { + validators := chain.GetValidators() + nodes := chain.GetNodes() + + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + 
logger.Info("validator status after provider teardown", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as droplet isn't available") + + _, err = validator.GetIP(ctx) + require.Error(t, err, "validator IP should not be accessible after teardown") + } + + for _, node := range nodes { + status, err := node.GetStatus(ctx) + logger.Info("node status after provider teardown", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as droplet isn't available") + + _, err = node.GetIP(ctx) + require.Error(t, err, "node IP should not be accessible after teardown") + } + } +} diff --git a/core/tests/e2e/docker/docker_test.go b/core/tests/e2e/docker/docker_test.go new file mode 100644 index 00000000..dc84133a --- /dev/null +++ b/core/tests/e2e/docker/docker_test.go @@ -0,0 +1,286 @@ +package e2e + +import ( + "context" + "flag" + "fmt" + "sync" + "testing" + + "github.com/skip-mev/petri/cosmos/v2/node" + + "github.com/cosmos/cosmos-sdk/crypto/hd" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/client" + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/provider/docker" + "github.com/skip-mev/petri/core/v2/types" + cosmoschain "github.com/skip-mev/petri/cosmos/v2/chain" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +var ( + defaultChainConfig = types.ChainConfig{ + Denom: "stake", + Decimals: 6, + NumValidators: 1, + NumNodes: 1, + BinaryName: "/usr/bin/simd", + Image: provider.ImageDefinition{ + Image: "interchainio/simapp:latest", + UID: "1000", + GID: "1000", + }, + GasPrices: "0.0005stake", + Bech32Prefix: "cosmos", + HomeDir: "/gaia", + CoinType: "118", + ChainId: "stake-1", + UseGenesisSubCommand: false, + } + + defaultChainOptions = types.ChainOptions{ + NodeCreator: node.CreateNode, + WalletConfig: types.WalletConfig{ + SigningAlgorithm: string(hd.Secp256k1.Name()), + Bech32Prefix: "cosmos", + HDPath: hd.CreateHDPath(118, 0, 0), + DerivationFn: hd.Secp256k1.Derive(), + GenerationFn: hd.Secp256k1.Generate(), + }, + } + + numTestChains = flag.Int("num-chains", 3, "number of chains to create for concurrent testing") + numNodes = flag.Int("num-nodes", 1, "number of nodes per chain") + numValidators = flag.Int("num-validators", 1, "number of validators per chain") +) + +func TestDockerE2E(t *testing.T) { + if !flag.Parsed() { + flag.Parse() + } + + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + p, err := docker.CreateProvider(ctx, logger, "docker_provider") + require.NoError(t, err) + + defer func() { + dockerClient, err := client.NewClientWithOpts() + if err != nil { + t.Logf("Failed to create Docker client for volume cleanup: %v", err) + return + } + _, err = dockerClient.VolumesPrune(ctx, filters.Args{}) + if err != nil { + t.Logf("Failed to prune volumes: %v", err) + } + }() + + var wg sync.WaitGroup + chainErrors := make(chan error, *numTestChains*2) + chains := make([]*cosmoschain.Chain, *numTestChains) + + // Create first half of chains + for i := 0; i < *numTestChains/2; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + chainConfig := defaultChainConfig + chainConfig.ChainId = fmt.Sprintf("chain-%d", index) + chainConfig.NumNodes = *numNodes + chainConfig.NumValidators = *numValidators + c, err := cosmoschain.CreateChain(ctx, logger, p, chainConfig, defaultChainOptions) + if err != nil { + t.Logf("Chain 
creation error: %v", err) + chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) + return + } + if err := c.Init(ctx, defaultChainOptions); err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) + return + } + chains[index] = c + }(i) + } + wg.Wait() + require.Empty(t, chainErrors) + + serializedProvider, err := p.SerializeProvider(ctx) + require.NoError(t, err) + restoredProvider, err := docker.RestoreProvider(ctx, logger, serializedProvider) + require.NoError(t, err) + + // Create second half of chains with restored provider + for i := *numTestChains / 2; i < *numTestChains; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + chainConfig := defaultChainConfig + chainConfig.ChainId = fmt.Sprintf("chain-%d", index) + chainConfig.NumNodes = *numNodes + chainConfig.NumValidators = *numValidators + c, err := cosmoschain.CreateChain(ctx, logger, restoredProvider, chainConfig, defaultChainOptions) + if err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) + return + } + if err := c.Init(ctx, defaultChainOptions); err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) + return + } + chains[index] = c + }(i) + } + wg.Wait() + require.Empty(t, chainErrors) + + // Serialize and restore all chains with the restored provider + restoredChains := make([]*cosmoschain.Chain, *numTestChains) + for i := 0; i < *numTestChains; i++ { + chainState, err := chains[i].Serialize(ctx, restoredProvider) + require.NoError(t, err) + + restoredChain, err := cosmoschain.RestoreChain(ctx, logger, restoredProvider, chainState, node.RestoreNode) + require.NoError(t, err) + + require.Equal(t, chains[i].GetConfig(), restoredChain.GetConfig()) + require.Equal(t, len(chains[i].GetValidators()), len(restoredChain.GetValidators())) + + restoredChains[i] = restoredChain + } + + // Test and teardown half the chains individually + for i := 0; i < *numTestChains/2; i++ { + originalChain := restoredChains[i] + validators := originalChain.GetValidators() + nodes := originalChain.GetNodes() + + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := validator.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + + testFile := "test.txt" + testContent := []byte("test content") + err = validator.WriteFile(ctx, testFile, testContent) + require.NoError(t, err) + + readContent, err := validator.ReadFile(ctx, testFile) + require.NoError(t, err) + require.Equal(t, testContent, readContent) + } + + for _, node := range nodes { + status, err := node.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := node.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + } + + err = originalChain.WaitForBlocks(ctx, 2) + require.NoError(t, err) + + // Test individual chain teardown + err = originalChain.Teardown(ctx) + require.NoError(t, err) + + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + logger.Info("validator status", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as container isn't available") + + _, err = validator.GetIP(ctx) + require.Error(t, err, "validator IP should not be accessible 
after teardown") + } + + for _, node := range nodes { + status, err := node.GetStatus(ctx) + logger.Info("node status", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as container isn't available") + + _, err = node.GetIP(ctx) + require.Error(t, err, "node IP should not be accessible after teardown") + } + } + + // Test the remaining chains but let the provider teardown handle their cleanup + remainingChains := make([]*cosmoschain.Chain, 0) + for i := *numTestChains / 2; i < *numTestChains; i++ { + originalChain := restoredChains[i] + remainingChains = append(remainingChains, originalChain) + validators := originalChain.GetValidators() + nodes := originalChain.GetNodes() + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := validator.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + + testFile := "test.txt" + testContent := []byte("test content") + err = validator.WriteFile(ctx, testFile, testContent) + require.NoError(t, err) + + readContent, err := validator.ReadFile(ctx, testFile) + require.NoError(t, err) + require.Equal(t, testContent, readContent) + } + for _, node := range nodes { + status, err := node.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := node.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + } + + err = originalChain.WaitForBlocks(ctx, 2) + require.NoError(t, err) + } + + require.NoError(t, restoredProvider.Teardown(ctx)) + // Verify all remaining chains are properly torn down + for _, chain := range remainingChains { + validators := chain.GetValidators() + nodes := chain.GetNodes() + + for _, validator := range validators { + status, err := validator.GetStatus(ctx) + logger.Info("validator status after provider teardown", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as container isn't available") + + _, err = validator.GetIP(ctx) + require.Error(t, err, "validator IP should not be accessible after teardown") + } + + for _, node := range nodes { + status, err := node.GetStatus(ctx) + logger.Info("node status after provider teardown", zap.Any("", status)) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as container isn't available") + + _, err = node.GetIP(ctx) + require.Error(t, err, "node IP should not be accessible after teardown") + } + } +} From c0459e08291831092a155841ac9b1b374cb8db3c Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Wed, 22 Jan 2025 01:28:34 +0200 Subject: [PATCH 29/50] refactor --- core/tests/e2e/digitalocean/do_test.go | 181 +++---------------------- core/tests/e2e/docker/docker_test.go | 148 +++----------------- core/tests/e2e/utils.go | 97 +++++++++++++ 3 files changed, 135 insertions(+), 291 deletions(-) create mode 100644 core/tests/e2e/utils.go diff --git a/core/tests/e2e/digitalocean/do_test.go b/core/tests/e2e/digitalocean/do_test.go index 88d63879..20c877a9 100644 --- a/core/tests/e2e/digitalocean/do_test.go +++ b/core/tests/e2e/digitalocean/do_test.go @@ -3,20 +3,15 @@ package e2e import ( "context" "flag" - "fmt" - "io" - "net/http" "os" - "strings" - "sync" "testing" "time" + "github.com/skip-mev/petri/core/v2/tests/e2e" + 
"github.com/skip-mev/petri/cosmos/v2/node" "github.com/cosmos/cosmos-sdk/crypto/hd" - "github.com/docker/docker/api/types/filters" - "github.com/docker/docker/client" "github.com/skip-mev/petri/core/v2/provider" "github.com/skip-mev/petri/core/v2/provider/digitalocean" "github.com/skip-mev/petri/core/v2/types" @@ -71,20 +66,6 @@ var ( numValidators = flag.Int("num-validators", 1, "number of validators per chain") ) -func getExternalIP() (string, error) { - resp, err := http.Get("https://ifconfig.me") - if err != nil { - return "", err - } - defer resp.Body.Close() - - ip, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - return strings.TrimSpace(string(ip)), nil -} - func TestDOE2E(t *testing.T) { if !flag.Parsed() { flag.Parse() @@ -103,85 +84,28 @@ func TestDOE2E(t *testing.T) { logger.Fatal("DO_IMAGE_ID environment variable not set") } - externalIP, err := getExternalIP() + externalIP, err := e2e.GetExternalIP() logger.Info("External IP", zap.String("address", externalIP)) require.NoError(t, err) p, err := digitalocean.NewProvider(ctx, logger, "digitalocean_provider", doToken, []string{externalIP}, nil) require.NoError(t, err) - defer func() { - dockerClient, err := client.NewClientWithOpts() - if err != nil { - t.Logf("Failed to create Docker client for volume cleanup: %v", err) - return - } - _, err = dockerClient.VolumesPrune(ctx, filters.Args{}) - if err != nil { - t.Logf("Failed to prune volumes: %v", err) - } - }() - - var wg sync.WaitGroup - chainErrors := make(chan error, *numTestChains*2) chains := make([]*cosmoschain.Chain, *numTestChains) // Create first half of chains - for i := 0; i < *numTestChains/2; i++ { - wg.Add(1) - go func(index int) { - defer wg.Done() - chainConfig := defaultChainConfig - chainConfig.ChainId = fmt.Sprintf("chain-%d", index) - chainConfig.NumNodes = *numNodes - chainConfig.NumValidators = *numValidators - c, err := cosmoschain.CreateChain(ctx, logger, p, chainConfig, defaultChainOptions) - if err != nil { - t.Logf("Chain creation error: %v", err) - chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) - return - } - if err := c.Init(ctx, defaultChainOptions); err != nil { - t.Logf("Chain creation error: %v", err) - chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) - return - } - chains[index] = c - }(i) - } - wg.Wait() - require.Empty(t, chainErrors) + defaultChainConfig.NumNodes = *numNodes + defaultChainConfig.NumValidators = *numValidators + e2e.CreateChainsConcurrently(ctx, t, logger, p, 0, *numTestChains/2, chains, defaultChainConfig, defaultChainOptions) + // Restore provider before creating second half of chains serializedProvider, err := p.SerializeProvider(ctx) require.NoError(t, err) restoredProvider, err := digitalocean.RestoreProvider(ctx, doToken, serializedProvider, nil, nil) require.NoError(t, err) // Create second half of chains with restored provider - for i := *numTestChains / 2; i < *numTestChains; i++ { - wg.Add(1) - go func(index int) { - defer wg.Done() - chainConfig := defaultChainConfig - chainConfig.ChainId = fmt.Sprintf("chain-%d", index) - chainConfig.NumNodes = *numNodes - chainConfig.NumValidators = *numValidators - c, err := cosmoschain.CreateChain(ctx, logger, restoredProvider, chainConfig, defaultChainOptions) - if err != nil { - t.Logf("Chain creation error: %v", err) - chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) - return - } - if err := c.Init(ctx, defaultChainOptions); err != nil { - t.Logf("Chain creation error: %v", err) - 
chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) - return - } - chains[index] = c - }(i) - } - wg.Wait() - require.Empty(t, chainErrors) + e2e.CreateChainsConcurrently(ctx, t, logger, restoredProvider, *numTestChains/2, *numTestChains, chains, defaultChainConfig, defaultChainOptions) // Serialize and restore all chains with the restored provider restoredChains := make([]*cosmoschain.Chain, *numTestChains) @@ -205,32 +129,11 @@ func TestDOE2E(t *testing.T) { nodes := originalChain.GetNodes() for _, validator := range validators { - status, err := validator.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := validator.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) - - testFile := "test.txt" - testContent := []byte("test content") - err = validator.WriteFile(ctx, testFile, testContent) - require.NoError(t, err) - - readContent, err := validator.ReadFile(ctx, testFile) - require.NoError(t, err) - require.Equal(t, testContent, readContent) + e2e.AssertNodeRunning(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := node.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) + e2e.AssertNodeRunning(t, ctx, node) } err = originalChain.WaitForBlocks(ctx, 2) @@ -240,27 +143,15 @@ func TestDOE2E(t *testing.T) { err = originalChain.Teardown(ctx) require.NoError(t, err) - // wait for task statuses to update on DO client side - time.Sleep(30 * time.Second) + // wait for status to update on DO client side + time.Sleep(15 * time.Second) for _, validator := range validators { - status, err := validator.GetStatus(ctx) - logger.Info("validator status", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as droplet isn't available") - - _, err = validator.GetIP(ctx) - require.Error(t, err, "validator IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - logger.Info("node status", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as droplet isn't available") - - _, err = node.GetIP(ctx) - require.Error(t, err, "node IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, node) } } @@ -272,31 +163,10 @@ func TestDOE2E(t *testing.T) { validators := originalChain.GetValidators() nodes := originalChain.GetNodes() for _, validator := range validators { - status, err := validator.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := validator.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) - - testFile := "test.txt" - testContent := []byte("test content") - err = validator.WriteFile(ctx, testFile, testContent) - require.NoError(t, err) - - readContent, err := validator.ReadFile(ctx, testFile) - require.NoError(t, err) - require.Equal(t, testContent, readContent) + e2e.AssertNodeRunning(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := node.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) + e2e.AssertNodeRunning(t, ctx, node) } err = originalChain.WaitForBlocks(ctx, 2) @@ -304,7 +174,8 @@ 
func TestDOE2E(t *testing.T) { } require.NoError(t, restoredProvider.Teardown(ctx)) - time.Sleep(30 * time.Second) + // wait for status to update on DO client side + time.Sleep(15 * time.Second) // Verify all remaining chains are properly torn down for _, chain := range remainingChains { @@ -312,23 +183,11 @@ func TestDOE2E(t *testing.T) { nodes := chain.GetNodes() for _, validator := range validators { - status, err := validator.GetStatus(ctx) - logger.Info("validator status after provider teardown", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as droplet isn't available") - - _, err = validator.GetIP(ctx) - require.Error(t, err, "validator IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - logger.Info("node status after provider teardown", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as droplet isn't available") - - _, err = node.GetIP(ctx) - require.Error(t, err, "node IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, node) } } } diff --git a/core/tests/e2e/docker/docker_test.go b/core/tests/e2e/docker/docker_test.go index dc84133a..f5e82798 100644 --- a/core/tests/e2e/docker/docker_test.go +++ b/core/tests/e2e/docker/docker_test.go @@ -3,10 +3,10 @@ package e2e import ( "context" "flag" - "fmt" - "sync" "testing" + "github.com/skip-mev/petri/core/v2/tests/e2e" + "github.com/skip-mev/petri/cosmos/v2/node" "github.com/cosmos/cosmos-sdk/crypto/hd" @@ -64,9 +64,6 @@ func TestDockerE2E(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - p, err := docker.CreateProvider(ctx, logger, "docker_provider") - require.NoError(t, err) - defer func() { dockerClient, err := client.NewClientWithOpts() if err != nil { @@ -79,68 +76,25 @@ func TestDockerE2E(t *testing.T) { } }() - var wg sync.WaitGroup - chainErrors := make(chan error, *numTestChains*2) + p, err := docker.CreateProvider(ctx, logger, "docker_provider") + require.NoError(t, err) + chains := make([]*cosmoschain.Chain, *numTestChains) // Create first half of chains - for i := 0; i < *numTestChains/2; i++ { - wg.Add(1) - go func(index int) { - defer wg.Done() - chainConfig := defaultChainConfig - chainConfig.ChainId = fmt.Sprintf("chain-%d", index) - chainConfig.NumNodes = *numNodes - chainConfig.NumValidators = *numValidators - c, err := cosmoschain.CreateChain(ctx, logger, p, chainConfig, defaultChainOptions) - if err != nil { - t.Logf("Chain creation error: %v", err) - chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) - return - } - if err := c.Init(ctx, defaultChainOptions); err != nil { - t.Logf("Chain creation error: %v", err) - chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) - return - } - chains[index] = c - }(i) - } - wg.Wait() - require.Empty(t, chainErrors) + defaultChainConfig.NumNodes = *numNodes + defaultChainConfig.NumValidators = *numValidators + e2e.CreateChainsConcurrently(ctx, t, logger, p, 0, *numTestChains/2, chains, defaultChainConfig, defaultChainOptions) + // Restore provider before creating second half of chains serializedProvider, err := p.SerializeProvider(ctx) require.NoError(t, err) restoredProvider, err := docker.RestoreProvider(ctx, logger, serializedProvider) require.NoError(t, err) // Create second half of chains with 
restored provider - for i := *numTestChains / 2; i < *numTestChains; i++ { - wg.Add(1) - go func(index int) { - defer wg.Done() - chainConfig := defaultChainConfig - chainConfig.ChainId = fmt.Sprintf("chain-%d", index) - chainConfig.NumNodes = *numNodes - chainConfig.NumValidators = *numValidators - c, err := cosmoschain.CreateChain(ctx, logger, restoredProvider, chainConfig, defaultChainOptions) - if err != nil { - t.Logf("Chain creation error: %v", err) - chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) - return - } - if err := c.Init(ctx, defaultChainOptions); err != nil { - t.Logf("Chain creation error: %v", err) - chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) - return - } - chains[index] = c - }(i) - } - wg.Wait() - require.Empty(t, chainErrors) + e2e.CreateChainsConcurrently(ctx, t, logger, restoredProvider, *numTestChains/2, *numTestChains, chains, defaultChainConfig, defaultChainOptions) - // Serialize and restore all chains with the restored provider restoredChains := make([]*cosmoschain.Chain, *numTestChains) for i := 0; i < *numTestChains; i++ { chainState, err := chains[i].Serialize(ctx, restoredProvider) @@ -162,32 +116,11 @@ func TestDockerE2E(t *testing.T) { nodes := originalChain.GetNodes() for _, validator := range validators { - status, err := validator.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := validator.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) - - testFile := "test.txt" - testContent := []byte("test content") - err = validator.WriteFile(ctx, testFile, testContent) - require.NoError(t, err) - - readContent, err := validator.ReadFile(ctx, testFile) - require.NoError(t, err) - require.Equal(t, testContent, readContent) + e2e.AssertNodeRunning(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := node.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) + e2e.AssertNodeRunning(t, ctx, node) } err = originalChain.WaitForBlocks(ctx, 2) @@ -198,23 +131,11 @@ func TestDockerE2E(t *testing.T) { require.NoError(t, err) for _, validator := range validators { - status, err := validator.GetStatus(ctx) - logger.Info("validator status", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as container isn't available") - - _, err = validator.GetIP(ctx) - require.Error(t, err, "validator IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - logger.Info("node status", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as container isn't available") - - _, err = node.GetIP(ctx) - require.Error(t, err, "node IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, node) } } @@ -226,31 +147,10 @@ func TestDockerE2E(t *testing.T) { validators := originalChain.GetValidators() nodes := originalChain.GetNodes() for _, validator := range validators { - status, err := validator.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := validator.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) - - testFile := "test.txt" - testContent := []byte("test content") - err = 
validator.WriteFile(ctx, testFile, testContent) - require.NoError(t, err) - - readContent, err := validator.ReadFile(ctx, testFile) - require.NoError(t, err) - require.Equal(t, testContent, readContent) + e2e.AssertNodeRunning(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - require.NoError(t, err) - require.Equal(t, provider.TASK_RUNNING, status) - - ip, err := node.GetIP(ctx) - require.NoError(t, err) - require.NotEmpty(t, ip) + e2e.AssertNodeRunning(t, ctx, node) } err = originalChain.WaitForBlocks(ctx, 2) @@ -264,23 +164,11 @@ func TestDockerE2E(t *testing.T) { nodes := chain.GetNodes() for _, validator := range validators { - status, err := validator.GetStatus(ctx) - logger.Info("validator status after provider teardown", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "validator task should report undefined as container isn't available") - - _, err = validator.GetIP(ctx) - require.Error(t, err, "validator IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, validator) } for _, node := range nodes { - status, err := node.GetStatus(ctx) - logger.Info("node status after provider teardown", zap.Any("", status)) - require.Error(t, err) - require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node task should report undefined as container isn't available") - - _, err = node.GetIP(ctx) - require.Error(t, err, "node IP should not be accessible after teardown") + e2e.AssertNodeShutdown(t, ctx, node) } } } diff --git a/core/tests/e2e/utils.go b/core/tests/e2e/utils.go new file mode 100644 index 00000000..ac83c07a --- /dev/null +++ b/core/tests/e2e/utils.go @@ -0,0 +1,97 @@ +package e2e + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "sync" + "testing" + + "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/types" + cosmoschain "github.com/skip-mev/petri/cosmos/v2/chain" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func AssertNodeRunning(t *testing.T, ctx context.Context, node types.NodeI) { + status, err := node.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + ip, err := node.GetIP(ctx) + require.NoError(t, err) + require.NotEmpty(t, ip) + + testFile := "test.txt" + testContent := []byte("test content") + err = node.WriteFile(ctx, testFile, testContent) + require.NoError(t, err) + + readContent, err := node.ReadFile(ctx, testFile) + require.NoError(t, err) + require.Equal(t, testContent, readContent) +} + +func AssertNodeShutdown(t *testing.T, ctx context.Context, node types.NodeI) { + status, err := node.GetStatus(ctx) + require.Error(t, err) + require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node status should report as undefined after shutdown") + + _, err = node.GetIP(ctx) + require.Error(t, err, "node IP should not be accessible after teardown") +} + +func GetExternalIP() (string, error) { + resp, err := http.Get("https://ifconfig.me") + if err != nil { + return "", err + } + defer resp.Body.Close() + + ip, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return strings.TrimSpace(string(ip)), nil +} + +// CreateChainsConcurrently creates multiple chains concurrently using the provided configuration +func CreateChainsConcurrently( + ctx context.Context, + t *testing.T, + logger *zap.Logger, + p provider.ProviderI, + startIndex, endIndex int, + chains []*cosmoschain.Chain, + chainConfig types.ChainConfig, + chainOptions 
types.ChainOptions, +) { + var wg sync.WaitGroup + chainErrors := make(chan error, endIndex-startIndex) + + for i := startIndex; i < endIndex; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + config := chainConfig + config.ChainId = fmt.Sprintf("chain-%d", index) + c, err := cosmoschain.CreateChain(ctx, logger, p, config, chainOptions) + if err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err) + return + } + if err := c.Init(ctx, chainOptions); err != nil { + t.Logf("Chain creation error: %v", err) + chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err) + return + } + chains[index] = c + }(i) + } + wg.Wait() + require.Empty(t, chainErrors) +} From bfbb614ca35fd1fb4da3cdadbadc5f96923ac2a7 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Wed, 22 Jan 2025 01:41:52 +0200 Subject: [PATCH 30/50] create -> restore -> create --- core/tests/e2e/digitalocean/do_test.go | 10 +++++----- core/tests/e2e/docker/docker_test.go | 9 +++++---- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/core/tests/e2e/digitalocean/do_test.go b/core/tests/e2e/digitalocean/do_test.go index 20c877a9..9dc2467e 100644 --- a/core/tests/e2e/digitalocean/do_test.go +++ b/core/tests/e2e/digitalocean/do_test.go @@ -104,12 +104,9 @@ func TestDOE2E(t *testing.T) { restoredProvider, err := digitalocean.RestoreProvider(ctx, doToken, serializedProvider, nil, nil) require.NoError(t, err) - // Create second half of chains with restored provider - e2e.CreateChainsConcurrently(ctx, t, logger, restoredProvider, *numTestChains/2, *numTestChains, chains, defaultChainConfig, defaultChainOptions) - - // Serialize and restore all chains with the restored provider + // Restore the existing chains with the restored provider restoredChains := make([]*cosmoschain.Chain, *numTestChains) - for i := 0; i < *numTestChains; i++ { + for i := 0; i < *numTestChains/2; i++ { chainState, err := chains[i].Serialize(ctx, restoredProvider) require.NoError(t, err) @@ -122,6 +119,9 @@ func TestDOE2E(t *testing.T) { restoredChains[i] = restoredChain } + // Create second half of chains with restored provider + e2e.CreateChainsConcurrently(ctx, t, logger, restoredProvider, *numTestChains/2, *numTestChains, restoredChains, defaultChainConfig, defaultChainOptions) + // Test and teardown half the chains individually for i := 0; i < *numTestChains/2; i++ { originalChain := chains[i] diff --git a/core/tests/e2e/docker/docker_test.go b/core/tests/e2e/docker/docker_test.go index f5e82798..c4a324d0 100644 --- a/core/tests/e2e/docker/docker_test.go +++ b/core/tests/e2e/docker/docker_test.go @@ -92,11 +92,9 @@ func TestDockerE2E(t *testing.T) { restoredProvider, err := docker.RestoreProvider(ctx, logger, serializedProvider) require.NoError(t, err) - // Create second half of chains with restored provider - e2e.CreateChainsConcurrently(ctx, t, logger, restoredProvider, *numTestChains/2, *numTestChains, chains, defaultChainConfig, defaultChainOptions) - + // restore the existing chains with the restored provider restoredChains := make([]*cosmoschain.Chain, *numTestChains) - for i := 0; i < *numTestChains; i++ { + for i := 0; i < *numTestChains/2; i++ { chainState, err := chains[i].Serialize(ctx, restoredProvider) require.NoError(t, err) @@ -109,6 +107,9 @@ func TestDockerE2E(t *testing.T) { restoredChains[i] = restoredChain } + // Create second half of chains with restored provider + e2e.CreateChainsConcurrently(ctx, t, logger, 
restoredProvider, *numTestChains/2, *numTestChains, restoredChains, defaultChainConfig, defaultChainOptions) + // Test and teardown half the chains individually for i := 0; i < *numTestChains/2; i++ { originalChain := restoredChains[i] From 55384ee1437a8d529aa5020c8b8e822c76a8be5b Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Wed, 22 Jan 2025 01:56:10 +0200 Subject: [PATCH 31/50] bug fix --- core/tests/e2e/digitalocean/do_test.go | 4 ++-- core/tests/e2e/docker/docker_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/tests/e2e/digitalocean/do_test.go b/core/tests/e2e/digitalocean/do_test.go index 9dc2467e..bc7b5fbe 100644 --- a/core/tests/e2e/digitalocean/do_test.go +++ b/core/tests/e2e/digitalocean/do_test.go @@ -124,7 +124,7 @@ func TestDOE2E(t *testing.T) { // Test and teardown half the chains individually for i := 0; i < *numTestChains/2; i++ { - originalChain := chains[i] + originalChain := restoredChains[i] validators := originalChain.GetValidators() nodes := originalChain.GetNodes() @@ -158,7 +158,7 @@ func TestDOE2E(t *testing.T) { // Test the remaining chains but let the provider teardown handle their cleanup remainingChains := make([]*cosmoschain.Chain, 0) for i := *numTestChains / 2; i < *numTestChains; i++ { - originalChain := chains[i] + originalChain := restoredChains[i] remainingChains = append(remainingChains, originalChain) validators := originalChain.GetValidators() nodes := originalChain.GetNodes() diff --git a/core/tests/e2e/docker/docker_test.go b/core/tests/e2e/docker/docker_test.go index c4a324d0..e6af6753 100644 --- a/core/tests/e2e/docker/docker_test.go +++ b/core/tests/e2e/docker/docker_test.go @@ -92,7 +92,7 @@ func TestDockerE2E(t *testing.T) { restoredProvider, err := docker.RestoreProvider(ctx, logger, serializedProvider) require.NoError(t, err) - // restore the existing chains with the restored provider + // Restore the existing chains with the restored provider restoredChains := make([]*cosmoschain.Chain, *numTestChains) for i := 0; i < *numTestChains/2; i++ { chainState, err := chains[i].Serialize(ctx, restoredProvider) From ae78d385723ac847aec3aae7ad1a064551989ce2 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Sun, 19 Jan 2025 06:44:15 +0200 Subject: [PATCH 32/50] test: e2e docker test --- core/tests/e2e/docker/docker_test.go | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/core/tests/e2e/docker/docker_test.go b/core/tests/e2e/docker/docker_test.go index e6af6753..b88d1d64 100644 --- a/core/tests/e2e/docker/docker_test.go +++ b/core/tests/e2e/docker/docker_test.go @@ -3,6 +3,9 @@ package e2e import ( "context" "flag" + "fmt" + "os" + "sync" "testing" "github.com/skip-mev/petri/core/v2/tests/e2e" @@ -12,7 +15,9 @@ import ( "github.com/cosmos/cosmos-sdk/crypto/hd" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" + gonanoid "github.com/matoous/go-nanoid/v2" "github.com/skip-mev/petri/core/v2/provider" + "github.com/skip-mev/petri/core/v2/provider/digitalocean" "github.com/skip-mev/petri/core/v2/provider/docker" "github.com/skip-mev/petri/core/v2/types" cosmoschain "github.com/skip-mev/petri/cosmos/v2/chain" @@ -49,6 +54,17 @@ var ( DerivationFn: hd.Secp256k1.Derive(), GenerationFn: hd.Secp256k1.Generate(), }, + NodeOptions: types.NodeOptions{ + NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig types.NodeConfig) provider.TaskDefinition { + doConfig := 
digitalocean.DigitalOceanTaskConfig{ + "size": "s-2vcpu-4gb", + "region": "ams3", + "image_id": os.Getenv("DO_IMAGE_ID"), + } + def.ProviderSpecificConfig = doConfig + return def + }, + }, } numTestChains = flag.Int("num-chains", 3, "number of chains to create for concurrent testing") @@ -63,8 +79,13 @@ func TestDockerE2E(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() + providerName := gonanoid.MustGenerate("abcdefghijklqmnoqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", 10) + + p, err := docker.CreateProvider(ctx, logger, providerName) + require.NoError(t, err) defer func() { + require.NoError(t, p.Teardown(ctx)) dockerClient, err := client.NewClientWithOpts() if err != nil { t.Logf("Failed to create Docker client for volume cleanup: %v", err) From 3c5e94fcc7601c71786b10bd374d81d3721359b5 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 14 Jan 2025 23:39:41 +0200 Subject: [PATCH 33/50] genesis functions should use RunCommandWhileStopped --- core/types/chain.go | 5 +++++ cosmos/node/genesis.go | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/core/types/chain.go b/core/types/chain.go index 490ad095..86c8b02e 100644 --- a/core/types/chain.go +++ b/core/types/chain.go @@ -139,3 +139,8 @@ func (c ChainConfig) ValidateBasic() error { return nil } + +func isValidWalletConfig(cfg WalletConfig) bool { + return cfg.Bech32Prefix != "" && cfg.SigningAlgorithm != "" && + cfg.HDPath != nil && cfg.DerivationFn != nil && cfg.GenerationFn != nil +} diff --git a/cosmos/node/genesis.go b/cosmos/node/genesis.go index d6de7aee..7adf3f34 100644 --- a/cosmos/node/genesis.go +++ b/cosmos/node/genesis.go @@ -71,7 +71,7 @@ func (n *Node) AddGenesisAccount(ctx context.Context, address string, genesisAmo command = append(command, "add-genesis-account", address, amount) command = n.BinCommand(command...) - stdout, stderr, exitCode, err := n.RunCommand(ctx, command) + stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) n.logger.Debug("add-genesis-account", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -101,7 +101,7 @@ func (n *Node) GenerateGenTx(ctx context.Context, genesisSelfDelegation types.Co command = n.BinCommand(command...) 
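	// gentx, like the other genesis helpers in this file (add-genesis-account
	// above, collect-gentxs below), mutates on-disk genesis state before the
	// node is started, so it moves from RunCommand, which assumes a running
	// container, to RunCommandWhileStopped below.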
- stdout, stderr, exitCode, err := n.RunCommand(ctx, command) + stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) n.logger.Debug("gentx", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -127,7 +127,7 @@ func (n *Node) CollectGenTxs(ctx context.Context) error { command = append(command, "collect-gentxs") - stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand(command...)) + stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, n.BinCommand(command...)) n.logger.Debug("collect-gentxs", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { From 0990f41b34801b315a84e3061d3a2df8a9b25709 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Thu, 16 Jan 2025 21:37:48 +0200 Subject: [PATCH 34/50] move serializable fields to state --- core/types/chain.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/core/types/chain.go b/core/types/chain.go index 86c8b02e..490ad095 100644 --- a/core/types/chain.go +++ b/core/types/chain.go @@ -139,8 +139,3 @@ func (c ChainConfig) ValidateBasic() error { return nil } - -func isValidWalletConfig(cfg WalletConfig) bool { - return cfg.Bech32Prefix != "" && cfg.SigningAlgorithm != "" && - cfg.HDPath != nil && cfg.DerivationFn != nil && cfg.GenerationFn != nil -} From 9b89f56aedf0a400d41613421c7b73b26e50c01d Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Fri, 17 Jan 2025 04:34:57 +0200 Subject: [PATCH 35/50] chore: remove provider reference from task struct --- core/provider/digitalocean/docker.go | 94 --- core/provider/digitalocean/droplet.go | 2 +- .../digitalocean/mocks/docker_client_mock.go | 377 ---------- core/provider/digitalocean/provider.go | 18 +- core/provider/digitalocean/provider_test.go | 19 +- core/provider/digitalocean/task.go | 31 +- core/provider/digitalocean/task_test.go | 51 +- core/provider/docker/provider.go | 42 +- core/provider/docker/task.go | 120 +++- core/provider/docker/volume.go | 78 +- core/provider/docker_client.go | 194 +++++ core/provider/mocks/docker_client_mock.go | 672 ++++++++++++++++++ 12 files changed, 1058 insertions(+), 640 deletions(-) delete mode 100644 core/provider/digitalocean/mocks/docker_client_mock.go create mode 100644 core/provider/docker_client.go create mode 100644 core/provider/mocks/docker_client_mock.go diff --git a/core/provider/digitalocean/docker.go b/core/provider/digitalocean/docker.go index 5b6c1ab6..e69de29b 100644 --- a/core/provider/digitalocean/docker.go +++ b/core/provider/digitalocean/docker.go @@ -1,94 +0,0 @@ -package digitalocean - -import ( - "context" - "io" - - "github.com/docker/docker/api/types" - "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/image" - "github.com/docker/docker/api/types/network" - dockerclient "github.com/docker/docker/client" - specs "github.com/opencontainers/image-spec/specs-go/v1" -) - -// DockerClient is an interface that abstracts Docker functionality -type DockerClient interface { - Ping(ctx context.Context) (types.Ping, error) - ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error) - ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) - ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) 
(container.CreateResponse, error) - ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) - ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error - ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error - ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) - ContainerExecCreate(ctx context.Context, container string, config container.ExecOptions) (types.IDResponse, error) - ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error) - ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) - ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error - Close() error -} - -type defaultDockerClient struct { - client *dockerclient.Client -} - -func NewDockerClient(host string) (DockerClient, error) { - client, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(host)) - if err != nil { - return nil, err - } - return &defaultDockerClient{client: client}, nil -} - -func (d *defaultDockerClient) Ping(ctx context.Context) (types.Ping, error) { - return d.client.Ping(ctx) -} - -func (d *defaultDockerClient) ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error) { - return d.client.ImageInspectWithRaw(ctx, image) -} - -func (d *defaultDockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) { - return d.client.ImagePull(ctx, ref, options) -} - -func (d *defaultDockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) { - return d.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, containerName) -} - -func (d *defaultDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) { - return d.client.ContainerList(ctx, options) -} - -func (d *defaultDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error { - return d.client.ContainerStart(ctx, containerID, options) -} - -func (d *defaultDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error { - return d.client.ContainerStop(ctx, containerID, options) -} - -func (d *defaultDockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) { - return d.client.ContainerInspect(ctx, containerID) -} - -func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, container string, options container.ExecOptions) (types.IDResponse, error) { - return d.client.ContainerExecCreate(ctx, container, options) -} - -func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, options container.ExecStartOptions) (types.HijackedResponse, error) { - return d.client.ContainerExecAttach(ctx, execID, options) -} - -func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) { - return d.client.ContainerExecInspect(ctx, execID) -} - -func (d *defaultDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error { - return d.client.ContainerRemove(ctx, containerID, options) -} - -func (d 
*defaultDockerClient) Close() error { - return d.client.Close() -} diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 01d0cd88..36ff76d2 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -75,7 +75,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe } if p.dockerClients[ip] == nil { - dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) + dockerClient, err := provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { p.logger.Error("failed to create docker client", zap.Error(err)) return false, err diff --git a/core/provider/digitalocean/mocks/docker_client_mock.go b/core/provider/digitalocean/mocks/docker_client_mock.go deleted file mode 100644 index e27c4705..00000000 --- a/core/provider/digitalocean/mocks/docker_client_mock.go +++ /dev/null @@ -1,377 +0,0 @@ -// Code generated by mockery v2.47.0. DO NOT EDIT. - -package mocks - -import ( - context "context" - - container "github.com/docker/docker/api/types/container" - - image "github.com/docker/docker/api/types/image" - - io "io" - - mock "github.com/stretchr/testify/mock" - - network "github.com/docker/docker/api/types/network" - - types "github.com/docker/docker/api/types" - - v1 "github.com/opencontainers/image-spec/specs-go/v1" -) - -// DockerClient is an autogenerated mock type for the DockerClient type -type DockerClient struct { - mock.Mock -} - -// Close provides a mock function with given fields: -func (_m *DockerClient) Close() error { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Close") - } - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ContainerCreate provides a mock function with given fields: ctx, config, hostConfig, networkingConfig, platform, containerName -func (_m *DockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *v1.Platform, containerName string) (container.CreateResponse, error) { - ret := _m.Called(ctx, config, hostConfig, networkingConfig, platform, containerName) - - if len(ret) == 0 { - panic("no return value specified for ContainerCreate") - } - - var r0 container.CreateResponse - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) (container.CreateResponse, error)); ok { - return rf(ctx, config, hostConfig, networkingConfig, platform, containerName) - } - if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) container.CreateResponse); ok { - r0 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) - } else { - r0 = ret.Get(0).(container.CreateResponse) - } - - if rf, ok := ret.Get(1).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) error); ok { - r1 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ContainerExecAttach provides a mock function with given fields: ctx, execID, config -func (_m *DockerClient) ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error) { - ret := _m.Called(ctx, 
execID, config) - - if len(ret) == 0 { - panic("no return value specified for ContainerExecAttach") - } - - var r0 types.HijackedResponse - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecStartCheck) (types.HijackedResponse, error)); ok { - return rf(ctx, execID, config) - } - if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecStartCheck) types.HijackedResponse); ok { - r0 = rf(ctx, execID, config) - } else { - r0 = ret.Get(0).(types.HijackedResponse) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, types.ExecStartCheck) error); ok { - r1 = rf(ctx, execID, config) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ContainerExecCreate provides a mock function with given fields: ctx, _a1, config -func (_m *DockerClient) ContainerExecCreate(ctx context.Context, _a1 string, config types.ExecConfig) (types.IDResponse, error) { - ret := _m.Called(ctx, _a1, config) - - if len(ret) == 0 { - panic("no return value specified for ContainerExecCreate") - } - - var r0 types.IDResponse - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecConfig) (types.IDResponse, error)); ok { - return rf(ctx, _a1, config) - } - if rf, ok := ret.Get(0).(func(context.Context, string, types.ExecConfig) types.IDResponse); ok { - r0 = rf(ctx, _a1, config) - } else { - r0 = ret.Get(0).(types.IDResponse) - } - - if rf, ok := ret.Get(1).(func(context.Context, string, types.ExecConfig) error); ok { - r1 = rf(ctx, _a1, config) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ContainerExecInspect provides a mock function with given fields: ctx, execID -func (_m *DockerClient) ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error) { - ret := _m.Called(ctx, execID) - - if len(ret) == 0 { - panic("no return value specified for ContainerExecInspect") - } - - var r0 types.ContainerExecInspect - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (types.ContainerExecInspect, error)); ok { - return rf(ctx, execID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerExecInspect); ok { - r0 = rf(ctx, execID) - } else { - r0 = ret.Get(0).(types.ContainerExecInspect) - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, execID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ContainerInspect provides a mock function with given fields: ctx, containerID -func (_m *DockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) { - ret := _m.Called(ctx, containerID) - - if len(ret) == 0 { - panic("no return value specified for ContainerInspect") - } - - var r0 types.ContainerJSON - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string) (types.ContainerJSON, error)); ok { - return rf(ctx, containerID) - } - if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerJSON); ok { - r0 = rf(ctx, containerID) - } else { - r0 = ret.Get(0).(types.ContainerJSON) - } - - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, containerID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ContainerList provides a mock function with given fields: ctx, options -func (_m *DockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) { - ret := _m.Called(ctx, options) - - if len(ret) == 0 { - panic("no return value specified for ContainerList") - } - - var r0 
[]types.Container - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) ([]types.Container, error)); ok { - return rf(ctx, options) - } - if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) []types.Container); ok { - r0 = rf(ctx, options) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]types.Container) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, container.ListOptions) error); ok { - r1 = rf(ctx, options) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// ContainerRemove provides a mock function with given fields: ctx, containerID, options -func (_m *DockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error { - ret := _m.Called(ctx, containerID, options) - - if len(ret) == 0 { - panic("no return value specified for ContainerRemove") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, container.RemoveOptions) error); ok { - r0 = rf(ctx, containerID, options) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ContainerStart provides a mock function with given fields: ctx, containerID, options -func (_m *DockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error { - ret := _m.Called(ctx, containerID, options) - - if len(ret) == 0 { - panic("no return value specified for ContainerStart") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, container.StartOptions) error); ok { - r0 = rf(ctx, containerID, options) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ContainerStop provides a mock function with given fields: ctx, containerID, options -func (_m *DockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error { - ret := _m.Called(ctx, containerID, options) - - if len(ret) == 0 { - panic("no return value specified for ContainerStop") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string, container.StopOptions) error); ok { - r0 = rf(ctx, containerID, options) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// ImageInspectWithRaw provides a mock function with given fields: ctx, _a1 -func (_m *DockerClient) ImageInspectWithRaw(ctx context.Context, _a1 string) (types.ImageInspect, []byte, error) { - ret := _m.Called(ctx, _a1) - - if len(ret) == 0 { - panic("no return value specified for ImageInspectWithRaw") - } - - var r0 types.ImageInspect - var r1 []byte - var r2 error - if rf, ok := ret.Get(0).(func(context.Context, string) (types.ImageInspect, []byte, error)); ok { - return rf(ctx, _a1) - } - if rf, ok := ret.Get(0).(func(context.Context, string) types.ImageInspect); ok { - r0 = rf(ctx, _a1) - } else { - r0 = ret.Get(0).(types.ImageInspect) - } - - if rf, ok := ret.Get(1).(func(context.Context, string) []byte); ok { - r1 = rf(ctx, _a1) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).([]byte) - } - } - - if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { - r2 = rf(ctx, _a1) - } else { - r2 = ret.Error(2) - } - - return r0, r1, r2 -} - -// ImagePull provides a mock function with given fields: ctx, ref, options -func (_m *DockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) { - ret := _m.Called(ctx, ref, options) - - if len(ret) == 0 { - panic("no return value specified for ImagePull") - } - - var r0 io.ReadCloser - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, 
image.PullOptions) (io.ReadCloser, error)); ok { - return rf(ctx, ref, options) - } - if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) io.ReadCloser); ok { - r0 = rf(ctx, ref, options) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(io.ReadCloser) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, image.PullOptions) error); ok { - r1 = rf(ctx, ref, options) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// Ping provides a mock function with given fields: ctx -func (_m *DockerClient) Ping(ctx context.Context) (types.Ping, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Ping") - } - - var r0 types.Ping - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (types.Ping, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) types.Ping); ok { - r0 = rf(ctx) - } else { - r0 = ret.Get(0).(types.Ping) - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// NewDockerClient creates a new instance of DockerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewDockerClient(t interface { - mock.TestingT - Cleanup(func()) -}) *DockerClient { - mock := &DockerClient{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 12fa8529..d5ff2741 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -43,7 +43,11 @@ type Provider struct { logger *zap.Logger doClient DoClient - dockerClients map[string]DockerClient // map of droplet ip address to docker clients + petriTag string + userIPs []string + sshKeyPair *SSHKeyPair + firewallID string + dockerClients map[string]provider.DockerClient // map of droplet ip address to docker clients } // NewProvider creates a provider that implements the Provider interface for DigitalOcean. @@ -55,7 +59,7 @@ func NewProvider(ctx context.Context, logger *zap.Logger, providerName string, t // NewProviderWithClient creates a provider with custom digitalocean/docker client implementation. // This is primarily used for testing. -func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { +func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]provider.DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { var err error if sshKeyPair == nil { sshKeyPair, err = MakeSSHKeyPair() @@ -72,7 +76,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName userIPs = append(userIPs, additionalUserIPS...) 
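	// dockerClients is keyed by droplet IP and is filled lazily: the first
	// task scheduled onto a droplet dials that droplet's Docker daemon,
	// roughly as in the CreateTask/CreateDroplet hunks below (a sketch;
	// dockerPort is the daemon port constant used elsewhere in this package):
	//
	//	if p.dockerClients[ip] == nil {
	//		client, err := provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
	//		...
	//		p.dockerClients[ip] = client
	//	}
	//
	// so starting from an empty map here is a valid initial state.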
if dockerClients == nil { - dockerClients = make(map[string]DockerClient) + dockerClients = make(map[string]provider.DockerClient) } petriTag := fmt.Sprintf("petri-droplet-%s", util.RandomString(5)) @@ -152,7 +156,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin dockerClient := p.dockerClients[ip] if dockerClient == nil { - dockerClient, err = NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) + dockerClient, err = provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { return nil, err } @@ -222,7 +226,8 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin return &Task{ state: taskState, - provider: p, + removeTask: p.removeTask, + sshKeyPair: p.state.SSHKeyPair, logger: p.logger.With(zap.String("task", definition.Name)), doClient: p.doClient, dockerClient: dockerClient, @@ -318,7 +323,8 @@ func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.Tas } task := &Task{ - state: &taskState, + state: &taskState, + removeTask: p.removeTask, } if err := p.initializeDeserializedTask(ctx, task); err != nil { diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index bfe04c08..3bcca6f5 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -3,6 +3,7 @@ package digitalocean import ( "context" "fmt" + "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" "io" "strings" "sync" @@ -22,14 +23,14 @@ import ( "go.uber.org/zap" "github.com/skip-mev/petri/core/v2/provider" - "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" + dockerMocks "github.com/skip-mev/petri/core/v2/provider/mocks" "github.com/skip-mev/petri/core/v2/util" ) -func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoClient, *mocks.DockerClient) { +func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoClient, *dockerMocks.DockerClient) { logger := zap.NewExample() mockDO := mocks.NewDoClient(t) - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDocker.On("Ping", ctx).Return(types.Ping{}, nil) mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) @@ -59,7 +60,7 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) - mockDockerClients := map[string]DockerClient{ + mockDockerClients := map[string]provider.DockerClient{ "10.0.0.1": mockDocker, } @@ -135,14 +136,14 @@ func TestCreateTask_ValidTask(t *testing.T) { func setupValidationTestProvider(t *testing.T, ctx context.Context) *Provider { logger := zap.NewExample() mockDO := mocks.NewDoClient(t) - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil) mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) - mockDockerClients := map[string]DockerClient{ + mockDockerClients := map[string]provider.DockerClient{ "10.0.0.1": mockDocker, } @@ -259,12 +260,12 @@ func 
TestConcurrentTaskCreationAndCleanup(t *testing.T) { defer cancel() logger, _ := zap.NewDevelopment() - mockDockerClients := make(map[string]DockerClient) + mockDockerClients := make(map[string]provider.DockerClient) mockDO := mocks.NewDoClient(t) for i := 0; i < 10; i++ { ip := fmt.Sprintf("10.0.0.%d", i+1) - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDockerClients[ip] = mockDocker mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Once() @@ -447,7 +448,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { mockDO.AssertExpectations(t) for _, client := range mockDockerClients { - client.(*mocks.DockerClient).AssertExpectations(t) + client.(*dockerMocks.DockerClient).AssertExpectations(t) } } diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index 156e887a..f2f84878 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "io" "net" "path" "sync" @@ -13,11 +12,9 @@ import ( "golang.org/x/crypto/ssh" "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/mount" dockerclient "github.com/docker/docker/client" "github.com/docker/docker/pkg/stdcopy" - "github.com/pkg/errors" "github.com/pkg/sftp" "github.com/spf13/afero" "github.com/spf13/afero/sftpfs" @@ -36,15 +33,18 @@ type TaskState struct { SSHKeyPair *SSHKeyPair `json:"ssh_key_pair"` } +// RemoveTaskFunc is a callback function type for removing a task from its provider +type RemoveTaskFunc func(ctx context.Context, taskID int) error + type Task struct { state *TaskState stateMu sync.Mutex - provider *Provider + removeTask RemoveTaskFunc logger *zap.Logger sshClient *ssh.Client doClient DoClient - dockerClient DockerClient + dockerClient provider.DockerClient } var _ provider.TaskI = (*Task)(nil) @@ -128,8 +128,7 @@ func (t *Task) Destroy(ctx context.Context) error { return err } - // TODO(nadim-az): remove reference to provider in Task struct - if err := t.provider.removeTask(ctx, t.GetState().ID); err != nil { + if err := t.removeTask(ctx, t.GetState().ID); err != nil { return err } return nil @@ -450,7 +449,7 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string return stdout.String(), stderr.String(), exitCode, nil } -func startContainerWithBlock(ctx context.Context, dockerClient DockerClient, containerID string) error { +func startContainerWithBlock(ctx context.Context, dockerClient provider.DockerClient, containerID string) error { // start container if err := dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { return err @@ -481,19 +480,3 @@ func startContainerWithBlock(ctx context.Context, dockerClient DockerClient, con } } } - -func pullImage(ctx context.Context, dockerClient DockerClient, logger *zap.Logger, img string) error { - logger.Info("pulling image", zap.String("image", img)) - resp, err := dockerClient.ImagePull(ctx, img, image.PullOptions{}) - if err != nil { - return errors.Wrap(err, "failed to pull docker image") - } - - defer resp.Close() - // throw away the image pull stdout response - _, err = io.Copy(io.Discard, resp) - if err != nil { - return errors.Wrap(err, "failed to pull docker image") - } - return nil -} diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go index b0b61098..c111692e 100644 --- a/core/provider/digitalocean/task_test.go +++ 
b/core/provider/digitalocean/task_test.go @@ -16,6 +16,8 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/mount" "github.com/docker/docker/api/types/network" + dockerMocks "github.com/skip-mev/petri/core/v2/provider/mocks" + specs "github.com/opencontainers/image-spec/specs-go/v1" "github.com/skip-mev/petri/core/v2/provider" "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" @@ -56,7 +58,7 @@ func TestTaskLifecycle(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) droplet := &godo.Droplet{ @@ -129,7 +131,7 @@ func TestTaskRunCommand(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) @@ -215,7 +217,7 @@ func TestTaskRunCommandWhileStopped(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) createResp := container.CreateResponse{ID: testContainerID} @@ -333,7 +335,7 @@ func TestTaskGetIP(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) expectedIP := "1.2.3.4" @@ -378,7 +380,7 @@ func TestTaskDestroy(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) mockDO.On("GetDroplet", ctx, testDroplet.ID).Return(testDroplet, nil) @@ -401,7 +403,10 @@ func TestTaskDestroy(t *testing.T) { logger: logger, dockerClient: mockDocker, doClient: mockDO, - provider: provider, + removeTask: func(ctx context.Context, taskID int) error { + delete(provider.state.TaskStates, taskID) + return nil + }, } providerState.TaskStates[task.GetState().ID] = task.state @@ -418,7 +423,7 @@ func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) @@ -525,7 +530,7 @@ func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) @@ -629,7 +634,7 @@ func TestTaskExposingPort(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) droplet := &godo.Droplet{ @@ -728,7 +733,7 @@ func TestGetStatus(t *testing.T) { name string dropletStatus string containerState string - setupMocks func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) + setupMocks func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) expectedStatus provider.TaskStatus expectError bool }{ @@ -736,7 +741,7 @@ func TestGetStatus(t *testing.T) { name: "droplet not active", dropletStatus: "off", 
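			// a powered-off droplet should short-circuit to TASK_STOPPED:
			// setupMocks wires only GetDroplet, so no Docker call is expected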
containerState: "", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletOff.ID).Return(testDropletOff, nil) }, expectedStatus: provider.TASK_STOPPED, @@ -746,7 +751,7 @@ func TestGetStatus(t *testing.T) { name: "container running", dropletStatus: "active", containerState: "running", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -766,7 +771,7 @@ func TestGetStatus(t *testing.T) { name: "container paused", dropletStatus: "active", containerState: "paused", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -786,7 +791,7 @@ func TestGetStatus(t *testing.T) { name: "container stopped state", dropletStatus: "active", containerState: "exited", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -806,7 +811,7 @@ func TestGetStatus(t *testing.T) { name: "container removing", dropletStatus: "active", containerState: "removing", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -826,7 +831,7 @@ func TestGetStatus(t *testing.T) { name: "container dead", dropletStatus: "active", containerState: "dead", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -846,7 +851,7 @@ func TestGetStatus(t *testing.T) { name: "container created", dropletStatus: "active", containerState: "created", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -866,7 +871,7 @@ func TestGetStatus(t *testing.T) { name: "unknown container status", dropletStatus: "active", containerState: "unknown_status", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -886,7 +891,7 @@ func TestGetStatus(t *testing.T) { name: "no containers found", dropletStatus: "active", containerState: "", 
- setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -899,7 +904,7 @@ func TestGetStatus(t *testing.T) { name: "container inspect error", dropletStatus: "active", containerState: "", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -913,7 +918,7 @@ func TestGetStatus(t *testing.T) { name: "getDroplet error", dropletStatus: "", containerState: "", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, 123).Return(nil, fmt.Errorf("failed to get droplet")) }, expectedStatus: provider.TASK_STATUS_UNDEFINED, @@ -923,7 +928,7 @@ func TestGetStatus(t *testing.T) { name: "containerList error", dropletStatus: "active", containerState: "", - setupMocks: func(mockDocker *mocks.DockerClient, mockDO *mocks.DoClient) { + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) mockDocker.On("ContainerList", ctx, container.ListOptions{ Limit: 1, @@ -936,7 +941,7 @@ func TestGetStatus(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO := mocks.NewDoClient(t) tt.setupMocks(mockDocker, mockDO) diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go index f0636305..bc2793e2 100644 --- a/core/provider/docker/provider.go +++ b/core/provider/docker/provider.go @@ -38,7 +38,7 @@ type Provider struct { state *ProviderState stateMu sync.Mutex - dockerClient *client.Client + dockerClient provider.DockerClient dockerNetworkAllocator *ipallocator.Range networkMu sync.Mutex logger *zap.Logger @@ -120,7 +120,7 @@ func RestoreProvider(ctx context.Context, logger *zap.Logger, state []byte) (*Pr logger: logger, } - dockerClient, err := client.NewClientWithOpts() + dockerClient, err := provider.NewDockerClient("") if err != nil { return nil, err } @@ -167,13 +167,14 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin } taskState := &TaskState{ - Name: definition.Name, - Definition: definition, + Name: definition.Name, + Definition: definition, + BuilderImageName: p.state.BuilderImageName, } logger := p.logger.Named("docker_provider") - if err := p.pullImage(ctx, definition.Image.Image); err != nil { + if err := provider.PullImage(ctx, p.dockerClient, logger, definition.Image.Image); err != nil { return nil, err } @@ -260,6 +261,8 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin taskState.Id = createdContainer.ID taskState.Status = provider.TASK_STOPPED + taskState.NetworkName = p.state.NetworkName + taskState.ProviderName = p.state.Name taskState.IpAddress = ip p.stateMu.Lock() @@ -268,8 +271,10 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin p.state.TaskStates[taskState.Id] = taskState return &Task{ - state: taskState, - 
provider: p, + state: taskState, + logger: p.logger.With(zap.String("task", definition.Name)), + dockerClient: p.dockerClient, + removeTask: p.removeTask, }, nil } @@ -307,8 +312,10 @@ func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.Tas } task := &Task{ - provider: p, - state: &taskState, + state: &taskState, + logger: p.logger.With(zap.String("task", taskState.Name)), + dockerClient: p.dockerClient, + removeTask: p.removeTask, } if err := task.ensureTask(ctx); err != nil { @@ -327,23 +334,6 @@ func (p *Provider) removeTask(_ context.Context, taskID string) error { return nil } -func (p *Provider) pullImage(ctx context.Context, imageName string) error { - _, _, err := p.dockerClient.ImageInspectWithRaw(ctx, imageName) - if err != nil { - p.logger.Info("image not found, pulling", zap.String("image", imageName)) - resp, err := p.dockerClient.ImagePull(ctx, imageName, image.PullOptions{}) - if err != nil { - return err - } - defer resp.Close() - - // throw away the image pull stdout response - _, err = io.Copy(io.Discard, resp) - return err - } - return nil -} - func (p *Provider) Teardown(ctx context.Context) error { p.logger.Info("tearing down Docker provider") diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index ea569f8c..3b9d797e 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -8,6 +8,7 @@ import ( "time" "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" "github.com/docker/docker/pkg/stdcopy" "github.com/docker/go-connections/nat" "github.com/skip-mev/petri/core/v3/provider" @@ -21,7 +22,10 @@ type TaskState struct { Volume *VolumeState `json:"volumes"` Definition provider.TaskDefinition `json:"definition"` Status provider.TaskStatus `json:"status"` - IpAddress string `json:"ip_address"` + IpAddress string `json:"ip_address"` + BuilderImageName string `json:"builder_image_name"` + ProviderName string `json:"provider_name"` + NetworkName string `json:"network_name"` } type VolumeState struct { @@ -30,18 +34,19 @@ type VolumeState struct { } type Task struct { - state *TaskState - stateMu sync.Mutex - provider *Provider + state *TaskState + stateMu sync.Mutex + logger *zap.Logger + dockerClient provider.DockerClient + removeTask func(ctx context.Context, taskID string) error } var _ provider.TaskI = (*Task)(nil) func (t *Task) Start(ctx context.Context) error { - t.provider.logger.Info("starting task", zap.String("id", t.state.Id)) - - err := t.provider.dockerClient.ContainerStart(ctx, t.state.Id, container.StartOptions{}) + t.logger.Info("starting task", zap.String("id", t.state.Id)) + err := t.dockerClient.ContainerStart(ctx, t.state.Id, container.StartOptions{}) if err != nil { return err } @@ -59,10 +64,9 @@ func (t *Task) Start(ctx context.Context) error { } func (t *Task) Stop(ctx context.Context) error { - t.provider.logger.Info("stopping task", zap.String("id", t.state.Id)) - - err := t.provider.dockerClient.ContainerStop(ctx, t.state.Id, container.StopOptions{}) + t.logger.Info("stopping task", zap.String("id", t.state.Id)) + err := t.dockerClient.ContainerStop(ctx, t.state.Id, container.StopOptions{}) if err != nil { return err } @@ -80,9 +84,9 @@ func (t *Task) Stop(ctx context.Context) error { } func (t *Task) Destroy(ctx context.Context) error { - t.provider.logger.Info("destroying task", zap.String("id", t.state.Id)) + t.logger.Info("destroying task", zap.String("id", t.state.Id)) - err := t.provider.dockerClient.ContainerRemove(ctx, t.state.Id, 
container.RemoveOptions{ + err := t.dockerClient.ContainerRemove(ctx, t.state.Id, container.RemoveOptions{ Force: true, RemoveVolumes: true, }) @@ -91,7 +95,7 @@ func (t *Task) Destroy(ctx context.Context) error { return err } - if err := t.provider.removeTask(ctx, t.state.Id); err != nil { + if err := t.removeTask(ctx, t.state.Id); err != nil { return err } @@ -99,16 +103,14 @@ func (t *Task) Destroy(ctx context.Context) error { } func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) { - t.provider.logger.Debug("getting external address", zap.String("id", t.state.Id)) - - dockerContainer, err := t.provider.dockerClient.ContainerInspect(ctx, t.state.Id) + t.logger.Debug("getting external address", zap.String("id", t.state.Id)) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, t.state.Id) if err != nil { return "", fmt.Errorf("failed to inspect container: %w", err) } portBindings, ok := dockerContainer.NetworkSettings.Ports[nat.Port(fmt.Sprintf("%s/tcp", port))] - if !ok || len(portBindings) == 0 { return "", fmt.Errorf("port %s not found", port) } @@ -117,15 +119,14 @@ func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, err } func (t *Task) GetIP(ctx context.Context) (string, error) { - t.provider.logger.Debug("getting IP", zap.String("id", t.state.Id)) + t.logger.Debug("getting IP", zap.String("id", t.state.Id)) - dockerContainer, err := t.provider.dockerClient.ContainerInspect(ctx, t.state.Id) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, t.state.Id) if err != nil { return "", err } - ip := dockerContainer.NetworkSettings.Networks[t.provider.state.NetworkName].IPAMConfig.IPv4Address - + ip := dockerContainer.NetworkSettings.Networks[t.state.NetworkName].IPAMConfig.IPv4Address return ip, nil } @@ -149,7 +150,7 @@ func (t *Task) WaitForStatus(ctx context.Context, interval time.Duration, desire } func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { - containerJSON, err := t.provider.dockerClient.ContainerInspect(ctx, t.state.Id) + containerJSON, err := t.dockerClient.ContainerInspect(ctx, t.state.Id) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } @@ -162,7 +163,7 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { case "paused": return provider.TASK_PAUSED, nil case "restarting": - return provider.TASK_RUNNING, nil // todo(zygimantass): is this sane? 
+ return provider.TASK_RESTARTING, nil case "removing": return provider.TASK_STOPPED, nil case "exited": @@ -192,15 +193,15 @@ func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, in } func (t *Task) runCommand(ctx context.Context, cmd []string) (string, string, int, error) { - t.provider.logger.Debug("running command", zap.String("id", t.state.Id), zap.Strings("command", cmd)) + t.logger.Debug("running command", zap.String("id", t.state.Id), zap.Strings("command", cmd)) - exec, err := t.provider.dockerClient.ContainerExecCreate(ctx, t.state.Id, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, t.state.Id, container.ExecOptions{ AttachStdout: true, AttachStderr: true, Cmd: cmd, }) if err != nil { - if buf, err := t.provider.dockerClient.ContainerLogs(ctx, t.state.Id, container.LogsOptions{ + if buf, err := t.dockerClient.ContainerLogs(ctx, t.state.Id, container.LogsOptions{ ShowStdout: true, ShowStderr: true, }); err == nil { @@ -213,7 +214,7 @@ func (t *Task) runCommand(ctx context.Context, cmd []string) (string, string, in return "", "", 0, err } - resp, err := t.provider.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) + resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { return "", "", 0, err } @@ -232,7 +233,7 @@ loop: case <-ctx.Done(): return "", "", lastExitCode, ctx.Err() case <-ticker.C: - execInspect, err := t.provider.dockerClient.ContainerExecInspect(ctx, exec.ID) + execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID) if err != nil { return "", "", lastExitCode, err } @@ -247,7 +248,7 @@ loop: } if err != nil { - t.provider.logger.Error("failed to wait for exec", zap.Error(err), zap.String("id", t.state.Id)) + t.logger.Error("failed to wait for exec", zap.Error(err), zap.String("id", t.state.Id)) return "", "", lastExitCode, err } @@ -266,7 +267,7 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } - t.provider.logger.Debug("running command while stopped", zap.String("id", t.state.Id), zap.Strings("command", cmd)) + t.logger.Debug("running command while stopped", zap.String("id", t.state.Id), zap.Strings("command", cmd)) status, err := t.GetStatus(ctx) if err != nil { @@ -282,24 +283,65 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", definition.Name, util.RandomString(5), time.Now().Unix()) definition.Ports = []string{} - task, err := t.provider.CreateTask(ctx, definition) - if err != nil { - return "", "", 0, err + containerConfig := &container.Config{ + Image: definition.Image.Image, + Entrypoint: definition.Entrypoint, + Cmd: definition.Command, + Tty: false, + Hostname: definition.Name, + Labels: map[string]string{ + providerLabelName: t.state.ProviderName, + }, + Env: convertEnvMapToList(definition.Environment), + } + + var mounts []mount.Mount + if t.state.Volume != nil { + mounts = []mount.Mount{ + { + Type: mount.TypeVolume, + Source: t.state.Volume.Name, + Target: definition.DataDir, + }, + } } - err = task.Start(ctx) - defer task.Destroy(ctx) // nolint:errcheck + hostConfig := &container.HostConfig{ + NetworkMode: container.NetworkMode("host"), + Mounts: mounts, + } + resp, err := t.dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, definition.ContainerName) if err != nil { return "", "", 
0, err } - stdout, stderr, exitCode, err := task.RunCommand(ctx, cmd) + tempTask := &Task{ + state: &TaskState{ + Id: resp.ID, + Name: definition.Name, + Definition: definition, + Status: provider.TASK_STOPPED, + ProviderName: t.state.ProviderName, + NetworkName: t.state.NetworkName, + }, + logger: t.logger.With(zap.String("temp_task", definition.Name)), + dockerClient: t.dockerClient, + removeTask: t.removeTask, + } + + err = tempTask.Start(ctx) if err != nil { return "", "", 0, err } - return stdout, stderr, exitCode, nil + defer func() { + if err := tempTask.Destroy(ctx); err != nil { + t.logger.Error("failed to destroy temporary task", zap.Error(err)) + } + }() + + return tempTask.RunCommand(ctx, cmd) } func (t *Task) GetState() TaskState { @@ -311,7 +353,7 @@ func (t *Task) GetState() TaskState { func (t *Task) ensureTask(ctx context.Context) error { state := t.GetState() - dockerContainer, err := t.provider.dockerClient.ContainerInspect(ctx, state.Id) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, state.Id) if err != nil { return fmt.Errorf("failed to inspect container: %w", err) } @@ -343,7 +385,7 @@ func (t *Task) ensureVolume(ctx context.Context) error { return nil } - volume, err := t.provider.dockerClient.VolumeInspect(ctx, t.state.Volume.Name) + volume, err := t.dockerClient.VolumeInspect(ctx, t.state.Volume.Name) if err != nil { return fmt.Errorf("failed to inspect volume: %w", err) } diff --git a/core/provider/docker/volume.go b/core/provider/docker/volume.go index 1cd0333c..29d40f82 100644 --- a/core/provider/docker/volume.go +++ b/core/provider/docker/volume.go @@ -56,7 +56,6 @@ func (p *Provider) DestroyVolume(ctx context.Context, id string) error { func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string) error { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return fmt.Errorf("no volumes found for container %s", state.Id) @@ -64,7 +63,7 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", state.Id), zap.String("path", relPath)) + logger := t.logger.With(zap.String("volume", state.Id), zap.String("path", relPath)) logger.Debug("writing file") @@ -72,16 +71,16 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string containerName := fmt.Sprintf("petri-writefile-%d", time.Now().UnixNano()) - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { return err } logger.Debug("creating writefile container") - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, + Image: state.BuilderImageName, Entrypoint: []string{"sh", "-c"}, Cmd: []string{ @@ -94,7 +93,7 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string }, Labels: map[string]string{ - providerLabelName: providerState.Name, + providerLabelName: t.state.ProviderName, }, // Use root user to avoid permission issues when reading files from the volume. 
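The temporary-executor hunk above constructs the throwaway container directly against the injected DockerClient and hands the temp task the same removeTask callback, so neither task type needs a provider back-pointer any more. On the provider side that callback is just a closure over the guarded state map. A minimal sketch of the wiring, using only fields this patch gives the docker Provider and Task (newDockerTask is a hypothetical helper name, not code from the change):

    // Sketch: build a Task with narrow dependencies instead of *Provider.
    func (p *Provider) newDockerTask(state *TaskState) *Task {
        return &Task{
            state:        state,
            logger:       p.logger.With(zap.String("task", state.Name)),
            dockerClient: p.dockerClient,
            // The closure captures only the state-map access the task needs.
            removeTask: func(_ context.Context, taskID string) error {
                p.stateMu.Lock()
                defer p.stateMu.Unlock()
                delete(p.state.TaskStates, taskID)
                return nil
            },
        }
    }

This mirrors the RemoveTaskFunc indirection the DigitalOcean task gains earlier in the patch.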
@@ -121,12 +120,12 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string return } - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { // auto-removed, but not detected as autoremoved return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed to remove writefile container", zap.String("id", cc.ID), zap.Error(err)) @@ -143,7 +142,7 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string defer file.Close() - if err := t.provider.dockerClient.CopyToContainer( + if err := t.dockerClient.CopyToContainer( ctx, cc.ID, mountPath, @@ -154,11 +153,11 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string } logger.Debug("starting writefile container") - if err := t.provider.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { + if err := t.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { return fmt.Errorf("starting write-file container: %w", err) } - waitCh, errCh := t.provider.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) + waitCh, errCh := t.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) select { case <-ctx.Done(): return ctx.Err() @@ -182,7 +181,6 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string // taken from strangelove-ventures/interchain-test func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) error { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return fmt.Errorf("no volumes found for container %s", state.Id) @@ -190,7 +188,7 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) + logger := t.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) logger.Debug("writing file") @@ -198,16 +196,16 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er containerName := fmt.Sprintf("petri-writefile-%d", time.Now().UnixNano()) - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { return err } logger.Debug("creating writefile container") - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, + Image: state.BuilderImageName, Entrypoint: []string{"sh", "-c"}, Cmd: []string{ @@ -220,7 +218,7 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er }, Labels: map[string]string{ - providerLabelName: providerState.Name, + providerLabelName: state.ProviderName, }, // Use root user to avoid permission issues when reading files from the volume. 
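WriteTar here and WriteFile in the next hunk share one recipe: create a short-lived container that mounts the task's volume, stream a tar archive into it with CopyToContainer, start it, and wait for it to exit. Condensed against the DockerClient seam, the shape is roughly the sketch below; the busybox image, the bind-mount syntax, and the no-op command are illustrative assumptions, since the patch context elides those details:

    // Rough shape of the write-via-helper-container recipe (a sketch, with
    // assumed details: busybox image, bind-mounted volume, no-op command).
    func writeViaHelper(ctx context.Context, cli provider.DockerClient, volumeName, mountPath string, archive io.Reader) error {
        cc, err := cli.ContainerCreate(ctx,
            &container.Config{
                Image:      "busybox:latest",
                Entrypoint: []string{"sh", "-c"},
                Cmd:        []string{"true"}, // nothing to run; we only need the mount
            },
            &container.HostConfig{Binds: []string{volumeName + ":" + mountPath}},
            nil, nil, fmt.Sprintf("petri-writefile-%d", time.Now().UnixNano()))
        if err != nil {
            return err
        }
        // Copy the tar stream into the created (not yet started) container.
        if err := cli.CopyToContainer(ctx, cc.ID, mountPath, archive, container.CopyToContainerOptions{}); err != nil {
            return err
        }
        if err := cli.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil {
            return err
        }
        waitCh, errCh := cli.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning)
        select {
        case res := <-waitCh:
            if res.StatusCode != 0 {
                return fmt.Errorf("helper container exited with status %d", res.StatusCode)
            }
            return nil
        case err := <-errCh:
            return err
        case <-ctx.Done():
            return ctx.Err()
        }
    }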
@@ -247,12 +245,12 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er return } - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { // auto-removed, but not detected as autoremoved return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed to remove writefile container", zap.String("id", cc.ID), zap.Error(err)) @@ -282,7 +280,7 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er logger.Debug("copying file to container") - if err := t.provider.dockerClient.CopyToContainer( + if err := t.dockerClient.CopyToContainer( ctx, cc.ID, mountPath, @@ -293,11 +291,11 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er } logger.Debug("starting writefile container") - if err := t.provider.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { + if err := t.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { return fmt.Errorf("starting write-file container: %w", err) } - waitCh, errCh := t.provider.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) + waitCh, errCh := t.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) select { case <-ctx.Done(): return ctx.Err() @@ -320,7 +318,6 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return nil, fmt.Errorf("no volumes found for container %s", state.Id) @@ -328,25 +325,25 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) + logger := t.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) const mountPath = "/mnt/dockervolume" containerName := fmt.Sprintf("petri-getfile-%d", time.Now().UnixNano()) - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { return nil, err } logger.Debug("creating getfile container") - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, + Image: state.BuilderImageName, Labels: map[string]string{ - providerLabelName: providerState.Name, + providerLabelName: state.ProviderName, }, // Use root user to avoid permission issues when reading files from the volume. 
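Since every file helper now goes through the injected DockerClient, the same calls behave identically whether the client points at a local socket or a remote droplet. A hypothetical round-trip through the helpers above (the path and contents are made up for illustration):

    // Hypothetical usage: write a file into the task volume, read it back.
    if err := task.WriteFile(ctx, "config/app.toml", []byte("minimum-gas-prices = \"0.0025stake\"\n")); err != nil {
        return err
    }
    bz, err := task.ReadFile(ctx, "config/app.toml")
    if err != nil {
        return err
    }
    fmt.Println(string(bz)) // the same bytes that were written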
@@ -367,12 +364,12 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { logger.Debug("created getfile container", zap.String("id", cc.ID)) defer func() { - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { // auto-removed, but not detected as autoremoved return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed cleaning up the getfile container", zap.Error(err)) @@ -380,7 +377,7 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { }() logger.Debug("copying from container") - rc, _, err := t.provider.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) + rc, _, err := t.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) if err != nil { return nil, fmt.Errorf("copying from container: %w", err) } @@ -410,7 +407,6 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return fmt.Errorf("no volumes found for container %s", state.Id) @@ -418,7 +414,7 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", volumeName), zap.String("path", relPath), zap.String("localPath", localPath)) + logger := t.logger.With(zap.String("volume", volumeName), zap.String("path", relPath), zap.String("localPath", localPath)) const mountPath = "/mnt/dockervolume" @@ -426,17 +422,17 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error logger.Debug("creating getdir container") - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { return err } - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, + Image: state.BuilderImageName, Labels: map[string]string{ - providerLabelName: providerState.Name, + providerLabelName: state.ProviderName, }, // Use root user to avoid permission issues when reading files from the volume. 
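ReadFile above and DownloadDir in the next hunk both lean on CopyFromContainer returning a tar stream; turning that stream into local files is a standard archive/tar walk. A minimal sketch of the extraction step, assuming trusted archive contents (a hardened version would reject path traversal in entry names):

    // Minimal sketch: unpack a CopyFromContainer tar stream under localPath.
    func untar(r io.Reader, localPath string) error {
        tr := tar.NewReader(r)
        for {
            hdr, err := tr.Next()
            if err == io.EOF {
                return nil // end of archive
            }
            if err != nil {
                return err
            }
            dst := filepath.Join(localPath, hdr.Name)
            switch hdr.Typeflag {
            case tar.TypeDir:
                if err := os.MkdirAll(dst, 0o755); err != nil {
                    return err
                }
            case tar.TypeReg:
                if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil {
                    return err
                }
                f, err := os.Create(dst)
                if err != nil {
                    return err
                }
                if _, err := io.Copy(f, tr); err != nil {
                    f.Close()
                    return err
                }
                f.Close()
            }
        }
    }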
@@ -455,11 +451,11 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error } defer func() { - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed cleaning up the getdir container", zap.Error(err)) @@ -467,7 +463,7 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error }() logger.Debug("copying from container") - reader, _, err := t.provider.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) + reader, _, err := t.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) if err != nil { return err } @@ -513,7 +509,7 @@ func (p *Provider) SetVolumeOwner(ctx context.Context, volumeName, uid, gid stri containerName := fmt.Sprintf("petri-setowner-%d", time.Now().UnixNano()) - if err := p.pullImage(ctx, p.GetState().BuilderImageName); err != nil { + if err := provider.PullImage(ctx, p.dockerClient, p.logger, p.GetState().BuilderImageName); err != nil { return err } diff --git a/core/provider/docker_client.go b/core/provider/docker_client.go new file mode 100644 index 00000000..7804a97d --- /dev/null +++ b/core/provider/docker_client.go @@ -0,0 +1,194 @@ +package provider + +import ( + "context" + "fmt" + "io" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/image" + "github.com/docker/docker/api/types/network" + "github.com/docker/docker/api/types/volume" + dockerclient "github.com/docker/docker/client" + specs "github.com/opencontainers/image-spec/specs-go/v1" + "go.uber.org/zap" +) + +// DockerClient is a unified interface for interacting with Docker +// It combines functionality needed by both the Docker and DigitalOcean providers +type DockerClient interface { + // Container Operations + ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) + ContainerStart(ctx context.Context, container string, options container.StartOptions) error + ContainerStop(ctx context.Context, container string, options container.StopOptions) error + ContainerRemove(ctx context.Context, container string, options container.RemoveOptions) error + ContainerInspect(ctx context.Context, container string) (types.ContainerJSON, error) + ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) + ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.WaitResponse, <-chan error) + + // Container Exec Operations + ContainerExecCreate(ctx context.Context, container string, config container.ExecOptions) (types.IDResponse, error) + ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error) + ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) + + // Container File Operations + CopyToContainer(ctx context.Context, container, path string, content io.Reader, options container.CopyToContainerOptions) error + CopyFromContainer(ctx 
context.Context, container, srcPath string) (io.ReadCloser, container.PathStat, error)
+    ContainerLogs(ctx context.Context, container string, options container.LogsOptions) (io.ReadCloser, error)
+
+    // Image Operations
+    ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error)
+    ImagePull(ctx context.Context, refStr string, options image.PullOptions) (io.ReadCloser, error)
+
+    // Volume Operations
+    VolumeCreate(ctx context.Context, options volume.CreateOptions) (volume.Volume, error)
+    VolumeInspect(ctx context.Context, volumeID string) (volume.Volume, error)
+    VolumeList(ctx context.Context, options volume.ListOptions) (volume.ListResponse, error)
+    VolumeRemove(ctx context.Context, volumeID string, force bool) error
+
+    // Network Operations
+    NetworkCreate(ctx context.Context, name string, options network.CreateOptions) (network.CreateResponse, error)
+    NetworkInspect(ctx context.Context, networkID string, options network.InspectOptions) (network.Inspect, error)
+    NetworkRemove(ctx context.Context, networkID string) error
+
+    // System Operations
+    Ping(ctx context.Context) (types.Ping, error)
+    Close() error
+}
+
+// defaultDockerClient is the default implementation of the DockerClient interface
+type defaultDockerClient struct {
+    client *dockerclient.Client
+}
+
+func NewDockerClient(host string) (DockerClient, error) {
+    // If host is empty, use the default Docker socket
+    if host == "" {
+        client, err := dockerclient.NewClientWithOpts()
+        if err != nil {
+            return nil, err
+        }
+        return &defaultDockerClient{client: client}, nil
+    }
+
+    host = fmt.Sprintf("tcp://%s:2375", host)
+
+    client, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(host))
+    if err != nil {
+        return nil, err
+    }
+    return &defaultDockerClient{client: client}, nil
+}
+
+func (d *defaultDockerClient) Ping(ctx context.Context) (types.Ping, error) {
+    return d.client.Ping(ctx)
+}
+
+func (d *defaultDockerClient) ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error) {
+    return d.client.ImageInspectWithRaw(ctx, image)
+}
+
+func (d *defaultDockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) {
+    return d.client.ImagePull(ctx, ref, options)
+}
+
+func (d *defaultDockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) {
+    return d.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, containerName)
+}
+
+func (d *defaultDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) {
+    return d.client.ContainerList(ctx, options)
+}
+
+func (d *defaultDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
+    return d.client.ContainerStart(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
+    return d.client.ContainerStop(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) {
+    return d.client.ContainerInspect(ctx, containerID)
+}
+
+func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, containerID string, config container.ExecOptions) (types.IDResponse, error) {
+    return d.client.ContainerExecCreate(ctx, containerID, config)
+}
+
+func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error) {
+    return d.client.ContainerExecAttach(ctx, execID, config)
+}
+
+func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) {
+    return d.client.ContainerExecInspect(ctx, execID)
+}
+
+func (d *defaultDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
+    return d.client.ContainerRemove(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.WaitResponse, <-chan error) {
+    return d.client.ContainerWait(ctx, containerID, condition)
+}
+
+func (d *defaultDockerClient) ContainerLogs(ctx context.Context, containerID string, options container.LogsOptions) (io.ReadCloser, error) {
+    return d.client.ContainerLogs(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) CopyToContainer(ctx context.Context, containerID, path string, content io.Reader, options container.CopyToContainerOptions) error {
+    return d.client.CopyToContainer(ctx, containerID, path, content, options)
+}
+
+func (d *defaultDockerClient) CopyFromContainer(ctx context.Context, containerID, srcPath string) (io.ReadCloser, container.PathStat, error) {
+    return d.client.CopyFromContainer(ctx, containerID, srcPath)
+}
+
+func (d *defaultDockerClient) VolumeCreate(ctx context.Context, options volume.CreateOptions) (volume.Volume, error) {
+    return d.client.VolumeCreate(ctx, options)
+}
+
+func (d *defaultDockerClient) VolumeInspect(ctx context.Context, volumeID string) (volume.Volume, error) {
+    return d.client.VolumeInspect(ctx, volumeID)
+}
+
+func (d *defaultDockerClient) VolumeList(ctx context.Context, options volume.ListOptions) (volume.ListResponse, error) {
+    return d.client.VolumeList(ctx, options)
+}
+
+func (d *defaultDockerClient) VolumeRemove(ctx context.Context, volumeID string, force bool) error {
+    return d.client.VolumeRemove(ctx, volumeID, force)
+}
+
+func (d *defaultDockerClient) NetworkCreate(ctx context.Context, name string, options network.CreateOptions) (network.CreateResponse, error) {
+    return d.client.NetworkCreate(ctx, name, options)
+}
+
+func (d *defaultDockerClient) NetworkInspect(ctx context.Context, networkID string, options network.InspectOptions) (network.Inspect, error) {
+    return d.client.NetworkInspect(ctx, networkID, options)
+}
+
+func (d *defaultDockerClient) NetworkRemove(ctx context.Context, networkID string) error {
+    return d.client.NetworkRemove(ctx, networkID)
+}
+
+func (d *defaultDockerClient) Close() error {
+    return d.client.Close()
+}
+
+func PullImage(ctx context.Context, dockerClient DockerClient, logger *zap.Logger, img string) error {
+    logger.Info("pulling image", zap.String("image", img))
+    resp, err := dockerClient.ImagePull(ctx, img, image.PullOptions{})
+    if err != nil {
+        return fmt.Errorf("failed to pull docker image: %w", err)
+    }
+
+    defer resp.Close()
+    // throw away the image pull stdout response
+    _, err = io.Copy(io.Discard, resp)
+    if err != nil {
+        return fmt.Errorf("failed to read image pull response: %w", err)
+    }
+    return nil
+}
diff --git a/core/provider/mocks/docker_client_mock.go b/core/provider/mocks/docker_client_mock.go
new file mode 100644
index 00000000..2aab7612
--- /dev/null
+++ b/core/provider/mocks/docker_client_mock.go
@@ -0,0 +1,672 @@ +// Code generated by mockery v2.47.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + container "github.com/docker/docker/api/types/container" + + image "github.com/docker/docker/api/types/image" + + io "io" + + mock "github.com/stretchr/testify/mock" + + network "github.com/docker/docker/api/types/network" + + types "github.com/docker/docker/api/types" + + v1 "github.com/opencontainers/image-spec/specs-go/v1" + + volume "github.com/docker/docker/api/types/volume" +) + +// DockerClient is an autogenerated mock type for the DockerClient type +type DockerClient struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *DockerClient) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerCreate provides a mock function with given fields: ctx, config, hostConfig, networkingConfig, platform, containerName +func (_m *DockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *v1.Platform, containerName string) (container.CreateResponse, error) { + ret := _m.Called(ctx, config, hostConfig, networkingConfig, platform, containerName) + + if len(ret) == 0 { + panic("no return value specified for ContainerCreate") + } + + var r0 container.CreateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) (container.CreateResponse, error)); ok { + return rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } + if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) container.CreateResponse); ok { + r0 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } else { + r0 = ret.Get(0).(container.CreateResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) error); ok { + r1 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecAttach provides a mock function with given fields: ctx, execID, config +func (_m *DockerClient) ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error) { + ret := _m.Called(ctx, execID, config) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecAttach") + } + + var r0 types.HijackedResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecStartOptions) (types.HijackedResponse, error)); ok { + return rf(ctx, execID, config) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecStartOptions) types.HijackedResponse); ok { + r0 = rf(ctx, execID, config) + } else { + r0 = ret.Get(0).(types.HijackedResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.ExecStartOptions) error); ok { + r1 = rf(ctx, execID, config) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecCreate provides a mock function with given fields: ctx, _a1, config +func (_m *DockerClient) ContainerExecCreate(ctx context.Context, _a1 string, config 
container.ExecOptions) (types.IDResponse, error) { + ret := _m.Called(ctx, _a1, config) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecCreate") + } + + var r0 types.IDResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecOptions) (types.IDResponse, error)); ok { + return rf(ctx, _a1, config) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecOptions) types.IDResponse); ok { + r0 = rf(ctx, _a1, config) + } else { + r0 = ret.Get(0).(types.IDResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.ExecOptions) error); ok { + r1 = rf(ctx, _a1, config) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecInspect provides a mock function with given fields: ctx, execID +func (_m *DockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) { + ret := _m.Called(ctx, execID) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecInspect") + } + + var r0 container.ExecInspect + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (container.ExecInspect, error)); ok { + return rf(ctx, execID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) container.ExecInspect); ok { + r0 = rf(ctx, execID) + } else { + r0 = ret.Get(0).(container.ExecInspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, execID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerInspect provides a mock function with given fields: ctx, _a1 +func (_m *DockerClient) ContainerInspect(ctx context.Context, _a1 string) (types.ContainerJSON, error) { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for ContainerInspect") + } + + var r0 types.ContainerJSON + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.ContainerJSON, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerJSON); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(types.ContainerJSON) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerList provides a mock function with given fields: ctx, options +func (_m *DockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerList") + } + + var r0 []types.Container + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) ([]types.Container, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) []types.Container); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Container) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, container.ListOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerLogs provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerLogs(ctx context.Context, _a1 string, options container.LogsOptions) (io.ReadCloser, error) { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerLogs") + } + + var r0 io.ReadCloser + var r1 error + if rf, ok := 
ret.Get(0).(func(context.Context, string, container.LogsOptions) (io.ReadCloser, error)); ok { + return rf(ctx, _a1, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.LogsOptions) io.ReadCloser); ok { + r0 = rf(ctx, _a1, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.LogsOptions) error); ok { + r1 = rf(ctx, _a1, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerRemove provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerRemove(ctx context.Context, _a1 string, options container.RemoveOptions) error { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.RemoveOptions) error); ok { + r0 = rf(ctx, _a1, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerStart provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerStart(ctx context.Context, _a1 string, options container.StartOptions) error { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerStart") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.StartOptions) error); ok { + r0 = rf(ctx, _a1, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerStop provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerStop(ctx context.Context, _a1 string, options container.StopOptions) error { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerStop") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.StopOptions) error); ok { + r0 = rf(ctx, _a1, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerWait provides a mock function with given fields: ctx, containerID, condition +func (_m *DockerClient) ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.WaitResponse, <-chan error) { + ret := _m.Called(ctx, containerID, condition) + + if len(ret) == 0 { + panic("no return value specified for ContainerWait") + } + + var r0 <-chan container.WaitResponse + var r1 <-chan error + if rf, ok := ret.Get(0).(func(context.Context, string, container.WaitCondition) (<-chan container.WaitResponse, <-chan error)); ok { + return rf(ctx, containerID, condition) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.WaitCondition) <-chan container.WaitResponse); ok { + r0 = rf(ctx, containerID, condition) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan container.WaitResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.WaitCondition) <-chan error); ok { + r1 = rf(ctx, containerID, condition) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(<-chan error) + } + } + + return r0, r1 +} + +// CopyFromContainer provides a mock function with given fields: ctx, _a1, srcPath +func (_m *DockerClient) CopyFromContainer(ctx context.Context, _a1 string, srcPath string) (io.ReadCloser, container.PathStat, error) { + ret := _m.Called(ctx, _a1, srcPath) + + if len(ret) == 0 { + panic("no return value specified for CopyFromContainer") + } + + var r0 io.ReadCloser + var r1 
container.PathStat + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (io.ReadCloser, container.PathStat, error)); ok { + return rf(ctx, _a1, srcPath) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) io.ReadCloser); ok { + r0 = rf(ctx, _a1, srcPath) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) container.PathStat); ok { + r1 = rf(ctx, _a1, srcPath) + } else { + r1 = ret.Get(1).(container.PathStat) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, string) error); ok { + r2 = rf(ctx, _a1, srcPath) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// CopyToContainer provides a mock function with given fields: ctx, _a1, path, content, options +func (_m *DockerClient) CopyToContainer(ctx context.Context, _a1 string, path string, content io.Reader, options container.CopyToContainerOptions) error { + ret := _m.Called(ctx, _a1, path, content, options) + + if len(ret) == 0 { + panic("no return value specified for CopyToContainer") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, io.Reader, container.CopyToContainerOptions) error); ok { + r0 = rf(ctx, _a1, path, content, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ImageInspectWithRaw provides a mock function with given fields: ctx, imageID +func (_m *DockerClient) ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error) { + ret := _m.Called(ctx, imageID) + + if len(ret) == 0 { + panic("no return value specified for ImageInspectWithRaw") + } + + var r0 types.ImageInspect + var r1 []byte + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.ImageInspect, []byte, error)); ok { + return rf(ctx, imageID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.ImageInspect); ok { + r0 = rf(ctx, imageID) + } else { + r0 = ret.Get(0).(types.ImageInspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) []byte); ok { + r1 = rf(ctx, imageID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]byte) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + r2 = rf(ctx, imageID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// ImagePull provides a mock function with given fields: ctx, refStr, options +func (_m *DockerClient) ImagePull(ctx context.Context, refStr string, options image.PullOptions) (io.ReadCloser, error) { + ret := _m.Called(ctx, refStr, options) + + if len(ret) == 0 { + panic("no return value specified for ImagePull") + } + + var r0 io.ReadCloser + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) (io.ReadCloser, error)); ok { + return rf(ctx, refStr, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) io.ReadCloser); ok { + r0 = rf(ctx, refStr, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, image.PullOptions) error); ok { + r1 = rf(ctx, refStr, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NetworkCreate provides a mock function with given fields: ctx, name, options +func (_m *DockerClient) NetworkCreate(ctx context.Context, name string, options network.CreateOptions) (network.CreateResponse, error) { + ret := _m.Called(ctx, name, options) + + if len(ret) == 0 { + panic("no return value 
specified for NetworkCreate") + } + + var r0 network.CreateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, network.CreateOptions) (network.CreateResponse, error)); ok { + return rf(ctx, name, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, network.CreateOptions) network.CreateResponse); ok { + r0 = rf(ctx, name, options) + } else { + r0 = ret.Get(0).(network.CreateResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, network.CreateOptions) error); ok { + r1 = rf(ctx, name, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NetworkInspect provides a mock function with given fields: ctx, networkID, options +func (_m *DockerClient) NetworkInspect(ctx context.Context, networkID string, options network.InspectOptions) (network.Inspect, error) { + ret := _m.Called(ctx, networkID, options) + + if len(ret) == 0 { + panic("no return value specified for NetworkInspect") + } + + var r0 network.Inspect + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, network.InspectOptions) (network.Inspect, error)); ok { + return rf(ctx, networkID, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, network.InspectOptions) network.Inspect); ok { + r0 = rf(ctx, networkID, options) + } else { + r0 = ret.Get(0).(network.Inspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, network.InspectOptions) error); ok { + r1 = rf(ctx, networkID, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NetworkRemove provides a mock function with given fields: ctx, networkID +func (_m *DockerClient) NetworkRemove(ctx context.Context, networkID string) error { + ret := _m.Called(ctx, networkID) + + if len(ret) == 0 { + panic("no return value specified for NetworkRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, networkID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Ping provides a mock function with given fields: ctx +func (_m *DockerClient) Ping(ctx context.Context) (types.Ping, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Ping") + } + + var r0 types.Ping + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (types.Ping, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) types.Ping); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(types.Ping) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeCreate provides a mock function with given fields: ctx, options +func (_m *DockerClient) VolumeCreate(ctx context.Context, options volume.CreateOptions) (volume.Volume, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for VolumeCreate") + } + + var r0 volume.Volume + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, volume.CreateOptions) (volume.Volume, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, volume.CreateOptions) volume.Volume); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Get(0).(volume.Volume) + } + + if rf, ok := ret.Get(1).(func(context.Context, volume.CreateOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeInspect provides a mock function with given fields: ctx, volumeID +func (_m *DockerClient) VolumeInspect(ctx 
context.Context, volumeID string) (volume.Volume, error) { + ret := _m.Called(ctx, volumeID) + + if len(ret) == 0 { + panic("no return value specified for VolumeInspect") + } + + var r0 volume.Volume + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (volume.Volume, error)); ok { + return rf(ctx, volumeID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) volume.Volume); ok { + r0 = rf(ctx, volumeID) + } else { + r0 = ret.Get(0).(volume.Volume) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, volumeID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeList provides a mock function with given fields: ctx, options +func (_m *DockerClient) VolumeList(ctx context.Context, options volume.ListOptions) (volume.ListResponse, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for VolumeList") + } + + var r0 volume.ListResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, volume.ListOptions) (volume.ListResponse, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, volume.ListOptions) volume.ListResponse); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Get(0).(volume.ListResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, volume.ListOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeRemove provides a mock function with given fields: ctx, volumeID, force +func (_m *DockerClient) VolumeRemove(ctx context.Context, volumeID string, force bool) error { + ret := _m.Called(ctx, volumeID, force) + + if len(ret) == 0 { + panic("no return value specified for VolumeRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok { + r0 = rf(ctx, volumeID, force) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewDockerClient creates a new instance of DockerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewDockerClient(t interface { + mock.TestingT + Cleanup(func()) +}) *DockerClient { + mock := &DockerClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} From f3b70e371233b91480097b1a267292adec972938 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Fri, 17 Jan 2025 04:46:58 +0200 Subject: [PATCH 36/50] fix bad rebase --- core/provider/digitalocean/docker.go | 0 core/provider/digitalocean/provider.go | 3 +-- core/provider/digitalocean/provider_test.go | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) delete mode 100644 core/provider/digitalocean/docker.go diff --git a/core/provider/digitalocean/docker.go b/core/provider/digitalocean/docker.go deleted file mode 100644 index e69de29b..00000000 diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index d5ff2741..3bd71311 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -338,7 +338,6 @@ func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) e taskState := task.GetState() task.logger = p.logger.With(zap.String("task", taskState.Name)) task.doClient = p.doClient - task.provider = p droplet, err := task.getDroplet(ctx) if err != nil { @@ -351,7 +350,7 @@ func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) e } if p.dockerClients[ip] == nil { - dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) + dockerClient, err := provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { return fmt.Errorf("failed to create docker client: %w", err) } diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index 3bcca6f5..631f1299 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -248,7 +248,6 @@ func TestSerializeAndRestoreTask(t *testing.T) { assert.NotNil(t, t2State.SSHKeyPair) assert.NotNil(t, t2.doClient) assert.NotNil(t, t2.dockerClient) - assert.NotNil(t, t2.provider) mockDO.AssertExpectations(t) mockDocker.AssertExpectations(t) From f2c49c171dc30725a38ec0c1833cda4ea8ebcdce Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Fri, 17 Jan 2025 04:51:21 +0200 Subject: [PATCH 37/50] lint --- core/provider/digitalocean/provider.go | 4 ---- core/provider/docker_client.go | 8 ++++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 3bd71311..46c73200 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -43,10 +43,6 @@ type Provider struct { logger *zap.Logger doClient DoClient - petriTag string - userIPs []string - sshKeyPair *SSHKeyPair - firewallID string dockerClients map[string]provider.DockerClient // map of droplet ip address to docker clients } diff --git a/core/provider/docker_client.go b/core/provider/docker_client.go index 7804a97d..d71b0682 100644 --- a/core/provider/docker_client.go +++ b/core/provider/docker_client.go @@ -113,15 +113,15 @@ func (d *defaultDockerClient) ContainerInspect(ctx context.Context, containerID return d.client.ContainerInspect(ctx, containerID) } -func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, container string, config types.ExecConfig) (types.IDResponse, error) { +func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, 
container string, config container.ExecOptions) (types.IDResponse, error) { return d.client.ContainerExecCreate(ctx, container, config) } -func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, config types.ExecStartCheck) (types.HijackedResponse, error) { +func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error) { return d.client.ContainerExecAttach(ctx, execID, config) } -func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (types.ContainerExecInspect, error) { +func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) { return d.client.ContainerExecInspect(ctx, execID) } @@ -141,7 +141,7 @@ func (d *defaultDockerClient) CopyToContainer(ctx context.Context, container, pa return d.client.CopyToContainer(ctx, container, path, content, options) } -func (d *defaultDockerClient) CopyFromContainer(ctx context.Context, container, srcPath string) (io.ReadCloser, types.ContainerPathStat, error) { +func (d *defaultDockerClient) CopyFromContainer(ctx context.Context, container, srcPath string) (io.ReadCloser, container.PathStat, error) { return d.client.CopyFromContainer(ctx, container, srcPath) } From 763152b6323fc6de6952d643913f2cc27b2f48a3 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Fri, 17 Jan 2025 19:45:11 +0200 Subject: [PATCH 38/50] fix tests --- core/provider/digitalocean/provider.go | 6 +++--- core/provider/digitalocean/provider_test.go | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 46c73200..8547286b 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -239,7 +239,7 @@ func (p *Provider) SerializeProvider(context.Context) ([]byte, error) { return bz, err } -func RestoreProvider(ctx context.Context, token string, state []byte, doClient DoClient, dockerClients map[string]DockerClient) (*Provider, error) { +func RestoreProvider(ctx context.Context, token string, state []byte, doClient DoClient, dockerClients map[string]provider.DockerClient) (*Provider, error) { if doClient == nil && token == "" { return nil, errors.New("a valid token or digital ocean client must be passed when restoring the provider") } @@ -251,7 +251,7 @@ func RestoreProvider(ctx context.Context, token string, state []byte, doClient D } if dockerClients == nil { - dockerClients = make(map[string]DockerClient) + dockerClients = make(map[string]provider.DockerClient) } digitalOceanProvider := &Provider{ @@ -283,7 +283,7 @@ func RestoreProvider(ctx context.Context, token string, state []byte, doClient D } if digitalOceanProvider.dockerClients[ip] == nil { - dockerClient, err := NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) + dockerClient, err := provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { return nil, fmt.Errorf("failed to create docker client: %w", err) } diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index 631f1299..47baccda 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -3,13 +3,14 @@ package digitalocean import ( "context" "fmt" - "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" "io" "strings" "sync" "testing" 
"time" + "github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks" + "github.com/digitalocean/godo" "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" @@ -54,6 +55,7 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC }, NetworkMode: container.NetworkMode("host"), }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), "test-container").Return(container.CreateResponse{ID: "test-container"}, nil) + mockDocker.On("Close").Return(nil) mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil) mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) @@ -131,6 +133,9 @@ func TestCreateTask_ValidTask(t *testing.T) { assert.NoError(t, err) assert.Equal(t, task.GetDefinition(), taskDef) assert.NotNil(t, task) + + err = task.Destroy(ctx) + assert.NoError(t, err) } func setupValidationTestProvider(t *testing.T, ctx context.Context) *Provider { @@ -249,6 +254,9 @@ func TestSerializeAndRestoreTask(t *testing.T) { assert.NotNil(t, t2.doClient) assert.NotNil(t, t2.dockerClient) + err = t2.Destroy(ctx) + assert.NoError(t, err) + mockDO.AssertExpectations(t) mockDocker.AssertExpectations(t) } @@ -454,14 +462,14 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { func TestProviderSerialization(t *testing.T) { ctx := context.Background() mockDO := mocks.NewDoClient(t) - mockDocker := mocks.NewDockerClient(t) + mockDocker := dockerMocks.NewDockerClient(t) mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "petri-droplet-test"}, nil) mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) - mockDockerClients := map[string]DockerClient{ + mockDockerClients := map[string]provider.DockerClient{ "10.0.0.1": mockDocker, } @@ -522,10 +530,10 @@ func TestProviderSerialization(t *testing.T) { mockDO2 := mocks.NewDoClient(t) mockDO2.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil).Maybe() - mockDocker2 := mocks.NewDockerClient(t) + mockDocker2 := dockerMocks.NewDockerClient(t) mockDocker2.On("Ping", ctx).Return(types.Ping{}, nil).Maybe() - mockDockerClients2 := map[string]DockerClient{ + mockDockerClients2 := map[string]provider.DockerClient{ "10.0.0.1": mockDocker2, } From 93dcdeb06d1ed9c089b5b81b2413636536fc3737 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Sun, 19 Jan 2025 06:06:03 +0200 Subject: [PATCH 39/50] use network name --- core/provider/docker/task.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index 3b9d797e..e927d03b 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -307,7 +307,7 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string } hostConfig := &container.HostConfig{ - NetworkMode: container.NetworkMode("host"), + NetworkMode: container.NetworkMode(t.state.NetworkName), Mounts: mounts, } From 1e13c94b4c5dacf497eb599aef146a674c8f0a21 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Sun, 19 Jan 2025 07:03:41 +0200 Subject: [PATCH 40/50] add back AllocatedIPs to state --- core/provider/docker/provider.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/core/provider/docker/provider.go 
b/core/provider/docker/provider.go index bc2793e2..17dec0ea 100644 --- a/core/provider/docker/provider.go +++ b/core/provider/docker/provider.go @@ -26,10 +26,11 @@ type ProviderState struct { Name string `json:"name"` - NetworkID string `json:"network_id"` - NetworkName string `json:"network_name"` - NetworkCIDR string `json:"network_cidr"` - NetworkGateway string `json:"network_gateway"` + NetworkID string `json:"network_id"` + NetworkName string `json:"network_name"` + NetworkCIDR string `json:"network_cidr"` + NetworkGateway string `json:"network_gateway"` + AllocatedIPs []string `json:"allocated_ips"` BuilderImageName string `json:"builder_image_name"` } From c6668a1a0e90e6b9abf064885daa587d64d0d973 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Mon, 20 Jan 2025 20:03:41 +0200 Subject: [PATCH 41/50] bad rebase --- cosmos/node/genesis.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cosmos/node/genesis.go b/cosmos/node/genesis.go index 7adf3f34..d6de7aee 100644 --- a/cosmos/node/genesis.go +++ b/cosmos/node/genesis.go @@ -71,7 +71,7 @@ func (n *Node) AddGenesisAccount(ctx context.Context, address string, genesisAmo command = append(command, "add-genesis-account", address, amount) command = n.BinCommand(command...) - stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) + stdout, stderr, exitCode, err := n.RunCommand(ctx, command) n.logger.Debug("add-genesis-account", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -101,7 +101,7 @@ func (n *Node) GenerateGenTx(ctx context.Context, genesisSelfDelegation types.Co command = n.BinCommand(command...) - stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, command) + stdout, stderr, exitCode, err := n.RunCommand(ctx, command) n.logger.Debug("gentx", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { @@ -127,7 +127,7 @@ func (n *Node) CollectGenTxs(ctx context.Context) error { command = append(command, "collect-gentxs") - stdout, stderr, exitCode, err := n.RunCommandWhileStopped(ctx, n.BinCommand(command...)) + stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand(command...)) n.logger.Debug("collect-gentxs", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode)) if err != nil { From d70948ea254b904d621031f0a9349b12880a91f3 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 01:52:08 +0200 Subject: [PATCH 42/50] refactor image pull logic --- core/provider/digitalocean/provider.go | 5 ++- core/provider/digitalocean/provider_test.go | 10 +++--- core/provider/digitalocean/task.go | 2 +- core/provider/docker/provider.go | 6 ++-- core/provider/docker/provider_test.go | 12 +++---- core/provider/docker/volume.go | 11 +++--- core/provider/docker_client.go | 40 ++++++++++----------- core/provider/mocks/docker_client_mock.go | 30 ++++++---------- 8 files changed, 51 insertions(+), 65 deletions(-) diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 8547286b..636574ca 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "strconv" + "github.com/docker/docker/api/types/image" "strings" "sync" "time" @@ -161,8 +162,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin _, _, err = 
dockerClient.ImageInspectWithRaw(ctx, definition.Image.Image) if err != nil { p.logger.Info("image not found, pulling", zap.String("image", definition.Image.Image)) - err = pullImage(ctx, dockerClient, p.logger, definition.Image.Image) - if err != nil { + if err = dockerClient.ImagePull(ctx, p.logger, definition.Image.Image, image.PullOptions{}); err != nil { return nil, err } } @@ -223,7 +223,6 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin return &Task{ state: taskState, removeTask: p.removeTask, - sshKeyPair: p.state.SSHKeyPair, logger: p.logger.With(zap.String("task", definition.Name)), doClient: p.doClient, dockerClient: dockerClient, diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go index 47baccda..3dfa99e5 100644 --- a/core/provider/digitalocean/provider_test.go +++ b/core/provider/digitalocean/provider_test.go @@ -3,8 +3,6 @@ package digitalocean import ( "context" "fmt" - "io" - "strings" "sync" "testing" "time" @@ -35,7 +33,7 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC mockDocker.On("Ping", ctx).Return(types.Ping{}, nil) mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) - mockDocker.On("ImagePull", ctx, "ubuntu:latest", image.PullOptions{}).Return(io.NopCloser(strings.NewReader("")), nil) + mockDocker.On("ImagePull", ctx, mock.AnythingOfType("*zap.Logger"), "ubuntu:latest", image.PullOptions{}).Return(nil) mockDocker.On("ContainerCreate", ctx, &container.Config{ Image: "ubuntu:latest", Entrypoint: []string{"/bin/bash"}, @@ -277,7 +275,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) { mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Once() mockDocker.On("ImageInspectWithRaw", ctx, "nginx:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once() - mockDocker.On("ImagePull", ctx, "nginx:latest", image.PullOptions{}).Return(io.NopCloser(strings.NewReader("")), nil).Once() + mockDocker.On("ImagePull", ctx, mock.AnythingOfType("*zap.Logger"), "nginx:latest", image.PullOptions{}).Return(nil).Once() mockDocker.On("ContainerCreate", ctx, mock.MatchedBy(func(config *container.Config) bool { return config.Image == "nginx:latest" }), mock.Anything, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: fmt.Sprintf("container-%d", i)}, nil).Once() @@ -492,9 +490,9 @@ func TestProviderSerialization(t *testing.T) { mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, nil) mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil).Maybe() - mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Maybe() + mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Once() mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) - mockDocker.On("ImagePull", ctx, "ubuntu:latest", image.PullOptions{}).Return(io.NopCloser(strings.NewReader("")), nil) + mockDocker.On("ImagePull", ctx, mock.AnythingOfType("*zap.Logger"), "ubuntu:latest", image.PullOptions{}).Return(nil) mockDocker.On("ContainerCreate", ctx, mock.MatchedBy(func(config *container.Config) bool { return config.Image == "ubuntu:latest" && config.Hostname == "test-task" && diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index f2f84878..2fc510f7 100644 --- a/core/provider/digitalocean/task.go +++ 
b/core/provider/digitalocean/task.go @@ -289,7 +289,7 @@ func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, in return t.runCommand(ctx, cmd) } -func waitForExec(ctx context.Context, dockerClient DockerClient, execID string) (int, error) { +func waitForExec(ctx context.Context, dockerClient provider.DockerClient, execID string) (int, error) { lastExitCode := 0 ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go index 17dec0ea..127d391a 100644 --- a/core/provider/docker/provider.go +++ b/core/provider/docker/provider.go @@ -17,8 +17,6 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/mount" - - "github.com/docker/docker/client" ) type ProviderState struct { @@ -48,7 +46,7 @@ type Provider struct { var _ provider.ProviderI = (*Provider)(nil) func CreateProvider(ctx context.Context, logger *zap.Logger, providerName string) (*Provider, error) { - dockerClient, err := client.NewClientWithOpts() + dockerClient, err := provider.NewDockerClient("") if err != nil { return nil, err } @@ -175,7 +173,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin logger := p.logger.Named("docker_provider") - if err := provider.PullImage(ctx, p.dockerClient, logger, definition.Image.Image); err != nil { + if err := p.dockerClient.ImagePull(ctx, p.logger, p.GetState().BuilderImageName, image.PullOptions{}); err != nil { return nil, err } diff --git a/core/provider/docker/provider_test.go b/core/provider/docker/provider_test.go index d8207737..cbd9ceeb 100644 --- a/core/provider/docker/provider_test.go +++ b/core/provider/docker/provider_test.go @@ -3,15 +3,16 @@ package docker_test import ( "context" "fmt" + "sync" + "testing" + "time" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/network" "github.com/docker/docker/client" gonanoid "github.com/matoous/go-nanoid/v2" "github.com/skip-mev/petri/core/v3/provider/docker" "go.uber.org/zap/zaptest" - "sync" - "testing" - "time" "github.com/skip-mev/petri/core/v3/provider" "github.com/stretchr/testify/assert" @@ -230,12 +231,11 @@ func TestConcurrentTaskCreation(t *testing.T) { for task := range tasks { taskState := task.GetState() - client, _ := client.NewClientWithOpts() - containerJSON, err := client.ContainerInspect(ctx, taskState.Id) + dockerClient, _ := provider.NewDockerClient("") + containerJSON, err := dockerClient.ContainerInspect(ctx, taskState.Id) require.NoError(t, err) ip := containerJSON.NetworkSettings.Networks[providerState.NetworkName].IPAddress - fmt.Println(ip) assert.False(t, ips[ip], "Duplicate IP found: %s", ip) ips[ip] = true } diff --git a/core/provider/docker/volume.go b/core/provider/docker/volume.go index 29d40f82..2e1ee0c0 100644 --- a/core/provider/docker/volume.go +++ b/core/provider/docker/volume.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "fmt" + "github.com/docker/docker/api/types/image" "io" "os" "path" @@ -71,7 +72,7 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string containerName := fmt.Sprintf("petri-writefile-%d", time.Now().UnixNano()) - if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return err } @@ -196,7 +197,7 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er 
containerName := fmt.Sprintf("petri-writefile-%d", time.Now().UnixNano()) - if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return err } @@ -331,7 +332,7 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { containerName := fmt.Sprintf("petri-getfile-%d", time.Now().UnixNano()) - if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return nil, err } @@ -422,7 +423,7 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error logger.Debug("creating getdir container") - if err := provider.PullImage(ctx, t.dockerClient, t.logger, state.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return err } @@ -509,7 +510,7 @@ func (p *Provider) SetVolumeOwner(ctx context.Context, volumeName, uid, gid stri containerName := fmt.Sprintf("petri-setowner-%d", time.Now().UnixNano()) - if err := provider.PullImage(ctx, p.dockerClient, p.logger, p.GetState().BuilderImageName); err != nil { + if err := p.dockerClient.ImagePull(ctx, p.logger, p.GetState().BuilderImageName, image.PullOptions{}); err != nil { return err } diff --git a/core/provider/docker_client.go b/core/provider/docker_client.go index d71b0682..4abd29c4 100644 --- a/core/provider/docker_client.go +++ b/core/provider/docker_client.go @@ -3,6 +3,7 @@ package provider import ( "context" "fmt" + "go.uber.org/zap" "io" "github.com/docker/docker/api/types" @@ -12,7 +13,6 @@ import ( "github.com/docker/docker/api/types/volume" dockerclient "github.com/docker/docker/client" specs "github.com/opencontainers/image-spec/specs-go/v1" - "go.uber.org/zap" ) // DockerClient is a unified interface for interacting with Docker @@ -39,7 +39,7 @@ type DockerClient interface { // Image Operations ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error) - ImagePull(ctx context.Context, refStr string, options image.PullOptions) (io.ReadCloser, error) + ImagePull(ctx context.Context, logger *zap.Logger, refStr string, options image.PullOptions) error // Volume Operations VolumeCreate(ctx context.Context, options volume.CreateOptions) (volume.Volume, error) @@ -89,8 +89,24 @@ func (d *defaultDockerClient) ImageInspectWithRaw(ctx context.Context, image str return d.client.ImageInspectWithRaw(ctx, image) } -func (d *defaultDockerClient) ImagePull(ctx context.Context, ref string, options image.PullOptions) (io.ReadCloser, error) { - return d.client.ImagePull(ctx, ref, options) +func (d *defaultDockerClient) ImagePull(ctx context.Context, logger *zap.Logger, ref string, options image.PullOptions) error { + _, _, err := d.client.ImageInspectWithRaw(ctx, ref) + if err != nil { + logger.Info("pulling image", zap.String("image", ref)) + resp, err := d.client.ImagePull(ctx, ref, options) + if err != nil { + return fmt.Errorf("failed to pull docker image: %w", err) + } + + defer resp.Close() + // throw away the image pull stdout response + _, err = io.Copy(io.Discard, resp) + if err != nil { + return fmt.Errorf("failed to pull docker image: %w", err) + } + return nil + } + return nil } func (d *defaultDockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig 
*container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) { @@ -176,19 +192,3 @@ func (d *defaultDockerClient) NetworkRemove(ctx context.Context, networkID strin func (d *defaultDockerClient) Close() error { return d.client.Close() } - -func PullImage(ctx context.Context, dockerClient DockerClient, logger *zap.Logger, img string) error { - logger.Info("pulling image", zap.String("image", img)) - resp, err := dockerClient.ImagePull(ctx, img, image.PullOptions{}) - if err != nil { - return fmt.Errorf("failed to pull docker image: %w", err) - } - - defer resp.Close() - // throw away the image pull stdout response - _, err = io.Copy(io.Discard, resp) - if err != nil { - return fmt.Errorf("failed to pull docker image: %w", err) - } - return nil -} diff --git a/core/provider/mocks/docker_client_mock.go b/core/provider/mocks/docker_client_mock.go index 2aab7612..f4e3fe70 100644 --- a/core/provider/mocks/docker_client_mock.go +++ b/core/provider/mocks/docker_client_mock.go @@ -20,6 +20,8 @@ import ( v1 "github.com/opencontainers/image-spec/specs-go/v1" volume "github.com/docker/docker/api/types/volume" + + zap "go.uber.org/zap" ) // DockerClient is an autogenerated mock type for the DockerClient type @@ -423,34 +425,22 @@ func (_m *DockerClient) ImageInspectWithRaw(ctx context.Context, imageID string) return r0, r1, r2 } -// ImagePull provides a mock function with given fields: ctx, refStr, options -func (_m *DockerClient) ImagePull(ctx context.Context, refStr string, options image.PullOptions) (io.ReadCloser, error) { - ret := _m.Called(ctx, refStr, options) +// ImagePull provides a mock function with given fields: ctx, logger, refStr, options +func (_m *DockerClient) ImagePull(ctx context.Context, logger *zap.Logger, refStr string, options image.PullOptions) error { + ret := _m.Called(ctx, logger, refStr, options) if len(ret) == 0 { panic("no return value specified for ImagePull") } - var r0 io.ReadCloser - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) (io.ReadCloser, error)); ok { - return rf(ctx, refStr, options) - } - if rf, ok := ret.Get(0).(func(context.Context, string, image.PullOptions) io.ReadCloser); ok { - r0 = rf(ctx, refStr, options) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(io.ReadCloser) - } - } - - if rf, ok := ret.Get(1).(func(context.Context, string, image.PullOptions) error); ok { - r1 = rf(ctx, refStr, options) + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *zap.Logger, string, image.PullOptions) error); ok { + r0 = rf(ctx, logger, refStr, options) } else { - r1 = ret.Error(1) + r0 = ret.Error(0) } - return r0, r1 + return r0 } // NetworkCreate provides a mock function with given fields: ctx, name, options From 9a5568b608673d3ded15afde672fc6b026be5263 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 02:14:32 +0200 Subject: [PATCH 43/50] use getState everywhere --- core/provider/docker/network.go | 25 ++++++----- core/provider/docker/provider.go | 17 +++---- core/provider/docker/task.go | 76 ++++++++++++++++++-------------- core/provider/docker/volume.go | 2 +- 4 files changed, 66 insertions(+), 54 deletions(-) diff --git a/core/provider/docker/network.go b/core/provider/docker/network.go index f18d42cf..2616e3b6 100644 --- a/core/provider/docker/network.go +++ b/core/provider/docker/network.go @@ -15,11 +15,13 @@ import ( const providerLabelName = "petri-provider" 
func (p *Provider) initNetwork(ctx context.Context) (network.Inspect, error) { - p.logger.Info("creating network", zap.String("name", p.state.NetworkName)) + state := p.GetState() + + p.logger.Info("creating network", zap.String("name", state.NetworkName)) subnet1 := rand.Intn(255) subnet2 := rand.Intn(255) - networkResponse, err := p.dockerClient.NetworkCreate(ctx, p.state.NetworkName, network.CreateOptions{ + networkResponse, err := p.dockerClient.NetworkCreate(ctx, state.NetworkName, network.CreateOptions{ Scope: "local", Driver: "bridge", Options: map[string]string{ // https://docs.docker.com/engine/reference/commandline/network_create/#bridge-driver-options @@ -32,7 +34,7 @@ func (p *Provider) initNetwork(ctx context.Context) (network.Inspect, error) { Attachable: false, Ingress: false, Labels: map[string]string{ - providerLabelName: p.state.Name, + providerLabelName: state.Name, }, IPAM: &network.IPAM{ Driver: "default", @@ -59,33 +61,34 @@ func (p *Provider) initNetwork(ctx context.Context) (network.Inspect, error) { // ensureNetwork checks if the network exists and has the expected configuration. func (p *Provider) ensureNetwork(ctx context.Context) error { - network, err := p.dockerClient.NetworkInspect(ctx, p.state.NetworkID, network.InspectOptions{}) + state := p.GetState() + network, err := p.dockerClient.NetworkInspect(ctx, state.NetworkID, network.InspectOptions{}) if err != nil { return err } - if network.ID != p.state.NetworkID { - return fmt.Errorf("network ID mismatch: %s != %s", network.ID, p.state.NetworkID) + if network.ID != state.NetworkID { + return fmt.Errorf("network ID mismatch: %s != %s", network.ID, state.NetworkID) } - if network.Name != p.state.NetworkName { - return fmt.Errorf("network name mismatch: %s != %s", network.Name, p.state.NetworkName) + if network.Name != state.NetworkName { + return fmt.Errorf("network name mismatch: %s != %s", network.Name, state.NetworkName) } if len(network.IPAM.Config) != 1 { return fmt.Errorf("unexpected number of IPAM configs: %d", len(network.IPAM.Config)) } - if network.IPAM.Config[0].Subnet != p.state.NetworkCIDR { - return fmt.Errorf("network CIDR mismatch: %s != %s", network.IPAM.Config[0].Subnet, p.state.NetworkCIDR) + if network.IPAM.Config[0].Subnet != state.NetworkCIDR { + return fmt.Errorf("network CIDR mismatch: %s != %s", network.IPAM.Config[0].Subnet, state.NetworkCIDR) } return nil } func (p *Provider) destroyNetwork(ctx context.Context) error { - return p.dockerClient.NetworkRemove(ctx, p.state.NetworkID) + return p.dockerClient.NetworkRemove(ctx, p.GetState().NetworkID) } // openListenerOnFreePort opens the next free port diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go index 127d391a..53913117 100644 --- a/core/provider/docker/provider.go +++ b/core/provider/docker/provider.go @@ -164,16 +164,17 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin if err := definition.ValidateBasic(); err != nil { return &Task{}, fmt.Errorf("failed to validate task definition: %w", err) } + state := p.GetState() taskState := &TaskState{ Name: definition.Name, Definition: definition, - BuilderImageName: p.state.BuilderImageName, + BuilderImageName: state.BuilderImageName, } logger := p.logger.Named("docker_provider") - if err := p.dockerClient.ImagePull(ctx, p.logger, p.GetState().BuilderImageName, image.PullOptions{}); err != nil { + if err := p.dockerClient.ImagePull(ctx, p.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return nil, err } @@ 
-236,17 +237,17 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin Tty: false, Hostname: definition.Name, Labels: map[string]string{ - providerLabelName: p.state.Name, + providerLabelName: state.Name, }, Env: convertEnvMapToList(definition.Environment), ExposedPorts: portSet, }, &container.HostConfig{ Mounts: mounts, PortBindings: portBindings, - NetworkMode: container.NetworkMode(p.state.NetworkName), + NetworkMode: container.NetworkMode(state.NetworkName), }, &network.NetworkingConfig{ EndpointsConfig: map[string]*network.EndpointSettings{ - p.state.NetworkName: { + state.NetworkName: { IPAMConfig: &network.EndpointIPAMConfig{ IPv4Address: ip, }, @@ -260,8 +261,8 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin taskState.Id = createdContainer.ID taskState.Status = provider.TASK_STOPPED - taskState.NetworkName = p.state.NetworkName - taskState.ProviderName = p.state.Name + taskState.NetworkName = state.NetworkName + taskState.ProviderName = state.Name taskState.IpAddress = ip p.stateMu.Lock() @@ -336,7 +337,7 @@ func (p *Provider) removeTask(_ context.Context, taskID string) error { func (p *Provider) Teardown(ctx context.Context) error { p.logger.Info("tearing down Docker provider") - for _, task := range p.state.TaskStates { + for _, task := range p.GetState().TaskStates { if err := p.dockerClient.ContainerRemove(ctx, task.Id, container.RemoveOptions{ Force: true, }); err != nil { diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index e927d03b..027dbd05 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -17,12 +17,12 @@ import ( ) type TaskState struct { - Id string `json:"id"` - Name string `json:"name"` - Volume *VolumeState `json:"volumes"` - Definition provider.TaskDefinition `json:"definition"` - Status provider.TaskStatus `json:"status"` - IpAddress string `json:"ip_address"` + Id string `json:"id"` + Name string `json:"name"` + Volume *VolumeState `json:"volumes"` + Definition provider.TaskDefinition `json:"definition"` + Status provider.TaskStatus `json:"status"` + IpAddress string `json:"ip_address"` BuilderImageName string `json:"builder_image_name"` ProviderName string `json:"provider_name"` NetworkName string `json:"network_name"` @@ -44,9 +44,10 @@ type Task struct { var _ provider.TaskI = (*Task)(nil) func (t *Task) Start(ctx context.Context) error { - t.logger.Info("starting task", zap.String("id", t.state.Id)) + state := t.GetState() + t.logger.Info("starting task", zap.String("id", state.Id)) - err := t.dockerClient.ContainerStart(ctx, t.state.Id, container.StartOptions{}) + err := t.dockerClient.ContainerStart(ctx, state.Id, container.StartOptions{}) if err != nil { return err } @@ -64,9 +65,10 @@ func (t *Task) Start(ctx context.Context) error { } func (t *Task) Stop(ctx context.Context) error { - t.logger.Info("stopping task", zap.String("id", t.state.Id)) + state := t.GetState() + t.logger.Info("stopping task", zap.String("id", state.Id)) - err := t.dockerClient.ContainerStop(ctx, t.state.Id, container.StopOptions{}) + err := t.dockerClient.ContainerStop(ctx, state.Id, container.StopOptions{}) if err != nil { return err } @@ -84,9 +86,10 @@ func (t *Task) Stop(ctx context.Context) error { } func (t *Task) Destroy(ctx context.Context) error { - t.logger.Info("destroying task", zap.String("id", t.state.Id)) + state := t.GetState() + t.logger.Info("destroying task", zap.String("id", state.Id)) - err := t.dockerClient.ContainerRemove(ctx, t.state.Id, 
container.RemoveOptions{ + err := t.dockerClient.ContainerRemove(ctx, state.Id, container.RemoveOptions{ Force: true, RemoveVolumes: true, }) @@ -95,7 +98,7 @@ func (t *Task) Destroy(ctx context.Context) error { return err } - if err := t.removeTask(ctx, t.state.Id); err != nil { + if err := t.removeTask(ctx, state.Id); err != nil { return err } @@ -103,9 +106,10 @@ func (t *Task) Destroy(ctx context.Context) error { } func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) { - t.logger.Debug("getting external address", zap.String("id", t.state.Id)) + state := t.GetState() + t.logger.Debug("getting external address", zap.String("id", state.Id)) - dockerContainer, err := t.dockerClient.ContainerInspect(ctx, t.state.Id) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, state.Id) if err != nil { return "", fmt.Errorf("failed to inspect container: %w", err) } @@ -119,14 +123,15 @@ func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, err } func (t *Task) GetIP(ctx context.Context) (string, error) { - t.logger.Debug("getting IP", zap.String("id", t.state.Id)) + state := t.GetState() + t.logger.Debug("getting IP", zap.String("id", state.Id)) - dockerContainer, err := t.dockerClient.ContainerInspect(ctx, t.state.Id) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, state.Id) if err != nil { return "", err } - ip := dockerContainer.NetworkSettings.Networks[t.state.NetworkName].IPAMConfig.IPv4Address + ip := dockerContainer.NetworkSettings.Networks[state.NetworkName].IPAMConfig.IPv4Address return ip, nil } @@ -150,7 +155,7 @@ func (t *Task) WaitForStatus(ctx context.Context, interval time.Duration, desire } func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { - containerJSON, err := t.dockerClient.ContainerInspect(ctx, t.state.Id) + containerJSON, err := t.dockerClient.ContainerInspect(ctx, t.GetState().Id) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } @@ -193,15 +198,16 @@ func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, in } func (t *Task) runCommand(ctx context.Context, cmd []string) (string, string, int, error) { - t.logger.Debug("running command", zap.String("id", t.state.Id), zap.Strings("command", cmd)) + state := t.GetState() + t.logger.Debug("running command", zap.String("id", state.Id), zap.Strings("command", cmd)) - exec, err := t.dockerClient.ContainerExecCreate(ctx, t.state.Id, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, state.Id, container.ExecOptions{ AttachStdout: true, AttachStderr: true, Cmd: cmd, }) if err != nil { - if buf, err := t.dockerClient.ContainerLogs(ctx, t.state.Id, container.LogsOptions{ + if buf, err := t.dockerClient.ContainerLogs(ctx, state.Id, container.LogsOptions{ ShowStdout: true, ShowStderr: true, }); err == nil { @@ -248,7 +254,7 @@ loop: } if err != nil { - t.logger.Error("failed to wait for exec", zap.Error(err), zap.String("id", t.state.Id)) + t.logger.Error("failed to wait for exec", zap.Error(err), zap.String("id", t.GetState().Id)) return "", "", lastExitCode, err } @@ -262,12 +268,13 @@ loop: } func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { + state := t.GetState() definition := t.GetState().Definition if err := definition.ValidateBasic(); err != nil { return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } - t.logger.Debug("running command while stopped", zap.String("id", t.state.Id), 
zap.Strings("command", cmd)) + t.logger.Debug("running command while stopped", zap.String("id", state.Id), zap.Strings("command", cmd)) status, err := t.GetStatus(ctx) if err != nil { @@ -290,24 +297,24 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string Tty: false, Hostname: definition.Name, Labels: map[string]string{ - providerLabelName: t.state.ProviderName, + providerLabelName: state.ProviderName, }, Env: convertEnvMapToList(definition.Environment), } var mounts []mount.Mount - if t.state.Volume != nil { + if state.Volume != nil { mounts = []mount.Mount{ { Type: mount.TypeVolume, - Source: t.state.Volume.Name, + Source: state.Volume.Name, Target: definition.DataDir, }, } } hostConfig := &container.HostConfig{ - NetworkMode: container.NetworkMode(t.state.NetworkName), + NetworkMode: container.NetworkMode(state.NetworkName), Mounts: mounts, } @@ -322,8 +329,8 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string Name: definition.Name, Definition: definition, Status: provider.TASK_STOPPED, - ProviderName: t.state.ProviderName, - NetworkName: t.state.NetworkName, + ProviderName: state.ProviderName, + NetworkName: state.NetworkName, }, logger: t.logger.With(zap.String("temp_task", definition.Name)), dockerClient: t.dockerClient, @@ -381,17 +388,18 @@ func (t *Task) ensureTask(ctx context.Context) error { } func (t *Task) ensureVolume(ctx context.Context) error { - if t.state.Volume == nil { + state := t.GetState() + if state.Volume == nil { return nil } - volume, err := t.dockerClient.VolumeInspect(ctx, t.state.Volume.Name) + volume, err := t.dockerClient.VolumeInspect(ctx, state.Volume.Name) if err != nil { return fmt.Errorf("failed to inspect volume: %w", err) } - if volume.Name != t.state.Volume.Name { - return fmt.Errorf("volume name mismatch, expected: %s, got: %s", t.state.Volume.Name, volume.Name) + if volume.Name != state.Volume.Name { + return fmt.Errorf("volume name mismatch, expected: %s, got: %s", state.Volume.Name, volume.Name) } return nil diff --git a/core/provider/docker/volume.go b/core/provider/docker/volume.go index 2e1ee0c0..04d236a1 100644 --- a/core/provider/docker/volume.go +++ b/core/provider/docker/volume.go @@ -94,7 +94,7 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string }, Labels: map[string]string{ - providerLabelName: t.state.ProviderName, + providerLabelName: t.GetState().ProviderName, }, // Use root user to avoid permission issues when reading files from the volume. 
From 6b600c7d1e0b47d0d7cbcb8f37191897a46a6b91 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 02:25:44 +0200 Subject: [PATCH 44/50] remove provider name field from task state --- core/provider/docker/provider.go | 10 ++++------ core/provider/docker/task.go | 17 ++++++----------- core/provider/docker/volume.go | 17 ----------------- 3 files changed, 10 insertions(+), 34 deletions(-) diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go index 53913117..c6e2a515 100644 --- a/core/provider/docker/provider.go +++ b/core/provider/docker/provider.go @@ -24,11 +24,10 @@ type ProviderState struct { Name string `json:"name"` - NetworkID string `json:"network_id"` - NetworkName string `json:"network_name"` - NetworkCIDR string `json:"network_cidr"` - NetworkGateway string `json:"network_gateway"` - AllocatedIPs []string `json:"allocated_ips"` + NetworkID string `json:"network_id"` + NetworkName string `json:"network_name"` + NetworkCIDR string `json:"network_cidr"` + NetworkGateway string `json:"network_gateway"` BuilderImageName string `json:"builder_image_name"` } @@ -262,7 +261,6 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin taskState.Id = createdContainer.ID taskState.Status = provider.TASK_STOPPED taskState.NetworkName = state.NetworkName - taskState.ProviderName = state.Name taskState.IpAddress = ip p.stateMu.Lock() diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index 027dbd05..76c4e3cb 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -24,7 +24,6 @@ type TaskState struct { Status provider.TaskStatus `json:"status"` IpAddress string `json:"ip_address"` BuilderImageName string `json:"builder_image_name"` - ProviderName string `json:"provider_name"` NetworkName string `json:"network_name"` } @@ -296,10 +295,7 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string Cmd: definition.Command, Tty: false, Hostname: definition.Name, - Labels: map[string]string{ - providerLabelName: state.ProviderName, - }, - Env: convertEnvMapToList(definition.Environment), + Env: convertEnvMapToList(definition.Environment), } var mounts []mount.Mount @@ -325,12 +321,11 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string tempTask := &Task{ state: &TaskState{ - Id: resp.ID, - Name: definition.Name, - Definition: definition, - Status: provider.TASK_STOPPED, - ProviderName: state.ProviderName, - NetworkName: state.NetworkName, + Id: resp.ID, + Name: definition.Name, + Definition: definition, + Status: provider.TASK_STOPPED, + NetworkName: state.NetworkName, }, logger: t.logger.With(zap.String("temp_task", definition.Name)), dockerClient: t.dockerClient, diff --git a/core/provider/docker/volume.go b/core/provider/docker/volume.go index 04d236a1..68b13f97 100644 --- a/core/provider/docker/volume.go +++ b/core/provider/docker/volume.go @@ -93,10 +93,6 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string mountPath, }, - Labels: map[string]string{ - providerLabelName: t.GetState().ProviderName, - }, - // Use root user to avoid permission issues when reading files from the volume. User: "0:0", }, @@ -218,10 +214,6 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er mountPath, }, - Labels: map[string]string{ - providerLabelName: state.ProviderName, - }, - // Use root user to avoid permission issues when reading files from the volume. 
User: "0:0", }, @@ -343,10 +335,6 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { &container.Config{ Image: state.BuilderImageName, - Labels: map[string]string{ - providerLabelName: state.ProviderName, - }, - // Use root user to avoid permission issues when reading files from the volume. User: "0", }, @@ -431,11 +419,6 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error ctx, &container.Config{ Image: state.BuilderImageName, - - Labels: map[string]string{ - providerLabelName: state.ProviderName, - }, - // Use root user to avoid permission issues when reading files from the volume. User: "0", }, From 29f1a588decf41c06436dc8ff0c1eae51a4022ef Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 02:57:50 +0200 Subject: [PATCH 45/50] add test for remove task --- core/provider/docker/task.go | 1 - core/provider/docker/task_test.go | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index 76c4e3cb..a55a7d22 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -329,7 +329,6 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string }, logger: t.logger.With(zap.String("temp_task", definition.Name)), dockerClient: t.dockerClient, - removeTask: t.removeTask, } err = tempTask.Start(ctx) diff --git a/core/provider/docker/task_test.go b/core/provider/docker/task_test.go index a222bafd..237d00e9 100644 --- a/core/provider/docker/task_test.go +++ b/core/provider/docker/task_test.go @@ -54,8 +54,10 @@ func TestTaskLifecycle(t *testing.T) { err = task.Stop(ctx) require.NoError(t, err) + require.Equal(t, 1, len(p.GetState().TaskStates)) err = task.Destroy(ctx) require.NoError(t, err) + require.Equal(t, 0, len(p.GetState().TaskStates)) dockerTask, ok := task.(*docker.Task) require.True(t, ok) From d3a70f157dfc9318bd3db3909bb00feef0528c8b Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Tue, 21 Jan 2025 14:23:41 +0200 Subject: [PATCH 46/50] move docker client to provider/clients --- core/provider/{ => clients}/docker_client.go | 2 +- core/provider/digitalocean/droplet.go | 3 +- core/provider/digitalocean/provider.go | 36 +++++++++----------- core/provider/digitalocean/provider_test.go | 11 +++--- core/provider/digitalocean/task.go | 7 ++-- core/provider/docker/provider.go | 6 ++-- core/provider/docker/provider_test.go | 3 +- core/provider/docker/task.go | 5 ++- 8 files changed, 39 insertions(+), 34 deletions(-) rename core/provider/{ => clients}/docker_client.go (99%) diff --git a/core/provider/docker_client.go b/core/provider/clients/docker_client.go similarity index 99% rename from core/provider/docker_client.go rename to core/provider/clients/docker_client.go index 4abd29c4..44c3394a 100644 --- a/core/provider/docker_client.go +++ b/core/provider/clients/docker_client.go @@ -1,4 +1,4 @@ -package provider +package clients import ( "context" diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go index 36ff76d2..7c63a9b4 100644 --- a/core/provider/digitalocean/droplet.go +++ b/core/provider/digitalocean/droplet.go @@ -3,6 +3,7 @@ package digitalocean import ( "context" "fmt" + "github.com/skip-mev/petri/core/v2/provider/clients" "time" "github.com/pkg/errors" @@ -75,7 +76,7 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe } if p.dockerClients[ip] == nil { - dockerClient, 
err := provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) + dockerClient, err := clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) if err != nil { p.logger.Error("failed to create docker client", zap.Error(err)) return false, err diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go index 636574ca..2ee76bee 100644 --- a/core/provider/digitalocean/provider.go +++ b/core/provider/digitalocean/provider.go @@ -7,6 +7,7 @@ import ( "fmt" "strconv" "github.com/docker/docker/api/types/image" + "github.com/skip-mev/petri/core/v2/provider/clients" "strings" "sync" "time" @@ -44,7 +45,7 @@ type Provider struct { logger *zap.Logger doClient DoClient - dockerClients map[string]provider.DockerClient // map of droplet ip address to docker clients + dockerClients map[string]clients.DockerClient // map of droplet ip address to docker clients } // NewProvider creates a provider that implements the Provider interface for DigitalOcean. @@ -56,7 +57,7 @@ func NewProvider(ctx context.Context, logger *zap.Logger, providerName string, t // NewProviderWithClient creates a provider with custom digitalocean/docker client implementation. // This is primarily used for testing. -func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]provider.DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { +func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]clients.DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) { var err error if sshKeyPair == nil { sshKeyPair, err = MakeSSHKeyPair() @@ -73,7 +74,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName userIPs = append(userIPs, additionalUserIPS...) 
 	if dockerClients == nil {
-		dockerClients = make(map[string]provider.DockerClient)
+		dockerClients = make(map[string]clients.DockerClient)
 	}
 
 	petriTag := fmt.Sprintf("petri-droplet-%s", util.RandomString(5))
@@ -81,18 +82,15 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName
 		logger:        logger.Named("digitalocean_provider"),
 		doClient:      doClient,
 		dockerClients: dockerClients,
+		state: &ProviderState{
+			TaskStates: make(map[string]*TaskState),
+			UserIPs:    userIPs,
+			Name:       providerName,
+			SSHKeyPair: sshKeyPair,
+			PetriTag:   petriTag,
+		},
 	}
 
-	pState := &ProviderState{
-		TaskStates: make(map[string]*TaskState),
-		UserIPs:    userIPs,
-		Name:       providerName,
-		SSHKeyPair: sshKeyPair,
-		PetriTag:   petriTag,
-	}
-
-	digitalOceanProvider.state = pState
-
 	_, err = digitalOceanProvider.createTag(ctx, petriTag)
 	if err != nil {
 		return nil, err
@@ -103,7 +101,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName
 		return nil, fmt.Errorf("failed to create firewall: %w", err)
 	}
 
-	pState.FirewallID = firewall.ID
+	digitalOceanProvider.state.FirewallID = firewall.ID
 
 	//TODO(Zygimantass): TOCTOU issue
 	if key, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
@@ -153,7 +151,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
 	dockerClient := p.dockerClients[ip]
 
 	if dockerClient == nil {
-		dockerClient, err = provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
+		dockerClient, err = clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
 		if err != nil {
 			return nil, err
 		}
@@ -238,7 +236,7 @@ func (p *Provider) SerializeProvider(context.Context) ([]byte, error) {
 	return bz, err
 }
 
-func RestoreProvider(ctx context.Context, token string, state []byte, doClient DoClient, dockerClients map[string]provider.DockerClient) (*Provider, error) {
+func RestoreProvider(ctx context.Context, token string, state []byte, doClient DoClient, dockerClients map[string]clients.DockerClient) (*Provider, error) {
 	if doClient == nil && token == "" {
 		return nil, errors.New("a valid token or digital ocean client must be passed when restoring the provider")
 	}
@@ -250,7 +248,7 @@ func RestoreProvider(ctx context.Context, token string, state []byte, doClient D
 	}
 
 	if dockerClients == nil {
-		dockerClients = make(map[string]provider.DockerClient)
+		dockerClients = make(map[string]clients.DockerClient)
 	}
 
 	digitalOceanProvider := &Provider{
@@ -282,7 +280,7 @@ func RestoreProvider(ctx context.Context, token string, state []byte, doClient D
 	}
 
 	if digitalOceanProvider.dockerClients[ip] == nil {
-		dockerClient, err := provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
+		dockerClient, err := clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
 		if err != nil {
 			return nil, fmt.Errorf("failed to create docker client: %w", err)
 		}
@@ -345,7 +343,7 @@ func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) e
 	}
 
 	if p.dockerClients[ip] == nil {
-		dockerClient, err := provider.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
+		dockerClient, err := clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
 		if err != nil {
 			return fmt.Errorf("failed to create docker client: %w", err)
 		}
diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go
index 3dfa99e5..9e6f72b9 100644
--- a/core/provider/digitalocean/provider_test.go
+++ b/core/provider/digitalocean/provider_test.go
@@ -3,6 +3,7 @@ package digitalocean
 import (
 	"context"
 	"fmt"
+	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"sync"
 	"testing"
 	"time"
@@ -60,7 +61,7 @@ func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoC
 	mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil)
 	mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil)
 
-	mockDockerClients := map[string]provider.DockerClient{
+	mockDockerClients := map[string]clients.DockerClient{
 		"10.0.0.1": mockDocker,
 	}
 
@@ -146,7 +147,7 @@ func setupValidationTestProvider(t *testing.T, ctx context.Context) *Provider {
 	mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil)
 	mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil)
 
-	mockDockerClients := map[string]provider.DockerClient{
+	mockDockerClients := map[string]clients.DockerClient{
 		"10.0.0.1": mockDocker,
 	}
 
@@ -265,7 +266,7 @@ func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
 	defer cancel()
 	logger, _ := zap.NewDevelopment()
 
-	mockDockerClients := make(map[string]provider.DockerClient)
+	mockDockerClients := make(map[string]clients.DockerClient)
 	mockDO := mocks.NewDoClient(t)
 
 	for i := 0; i < 10; i++ {
@@ -467,7 +468,7 @@ func TestProviderSerialization(t *testing.T) {
 	mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil)
 	mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil)
 
-	mockDockerClients := map[string]provider.DockerClient{
+	mockDockerClients := map[string]clients.DockerClient{
 		"10.0.0.1": mockDocker,
 	}
 
@@ -531,7 +532,7 @@ func TestProviderSerialization(t *testing.T) {
 	mockDocker2 := dockerMocks.NewDockerClient(t)
 	mockDocker2.On("Ping", ctx).Return(types.Ping{}, nil).Maybe()
 
-	mockDockerClients2 := map[string]provider.DockerClient{
+	mockDockerClients2 := map[string]clients.DockerClient{
 		"10.0.0.1": mockDocker2,
 	}
 
diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go
index 2fc510f7..79b01eb4 100644
--- a/core/provider/digitalocean/task.go
+++ b/core/provider/digitalocean/task.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"context"
 	"fmt"
+	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"net"
 	"path"
 	"sync"
@@ -44,7 +45,7 @@ type Task struct {
 	logger       *zap.Logger
 	sshClient    *ssh.Client
 	doClient     DoClient
-	dockerClient provider.DockerClient
+	dockerClient clients.DockerClient
 }
 
 var _ provider.TaskI = (*Task)(nil)
@@ -289,7 +290,7 @@ func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, in
 	return t.runCommand(ctx, cmd)
 }
 
-func waitForExec(ctx context.Context, dockerClient provider.DockerClient, execID string) (int, error) {
+func waitForExec(ctx context.Context, dockerClient clients.DockerClient, execID string) (int, error) {
 	lastExitCode := 0
 	ticker := time.NewTicker(100 * time.Millisecond)
 	defer ticker.Stop()
@@ -449,7 +450,7 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string
 	return stdout.String(), stderr.String(), exitCode, nil
 }
 
-func startContainerWithBlock(ctx context.Context, dockerClient provider.DockerClient, containerID string) error {
+func startContainerWithBlock(ctx context.Context, dockerClient clients.DockerClient, containerID string) error {
 	// start container
 	if err := dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil {
 		return err
diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go
index c6e2a515..73fd4993 100644
--- a/core/provider/docker/provider.go
+++ b/core/provider/docker/provider.go
@@ -36,7 +36,7 @@ type Provider struct {
 	state   *ProviderState
 	stateMu sync.Mutex
 
-	dockerClient           provider.DockerClient
+	dockerClient           clients.DockerClient
 	dockerNetworkAllocator *ipallocator.Range
 	networkMu              sync.Mutex
 	logger                 *zap.Logger
@@ -45,7 +45,7 @@ type Provider struct {
 var _ provider.ProviderI = (*Provider)(nil)
 
 func CreateProvider(ctx context.Context, logger *zap.Logger, providerName string) (*Provider, error) {
-	dockerClient, err := provider.NewDockerClient("")
+	dockerClient, err := clients.NewDockerClient("")
 	if err != nil {
 		return nil, err
 	}
@@ -118,7 +118,7 @@ func RestoreProvider(ctx context.Context, logger *zap.Logger, state []byte) (*Pr
 		logger: logger,
 	}
 
-	dockerClient, err := provider.NewDockerClient("")
+	dockerClient, err := clients.NewDockerClient("")
 	if err != nil {
 		return nil, err
 	}
diff --git a/core/provider/docker/provider_test.go b/core/provider/docker/provider_test.go
index cbd9ceeb..1bd9696c 100644
--- a/core/provider/docker/provider_test.go
+++ b/core/provider/docker/provider_test.go
@@ -3,6 +3,7 @@ package docker_test
 import (
 	"context"
 	"fmt"
+	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"sync"
 	"testing"
 	"time"
@@ -231,7 +232,7 @@ func TestConcurrentTaskCreation(t *testing.T) {
 	for task := range tasks {
 		taskState := task.GetState()
 
-		dockerClient, _ := provider.NewDockerClient("")
+		dockerClient, _ := clients.NewDockerClient("")
 		containerJSON, err := dockerClient.ContainerInspect(ctx, taskState.Id)
 		require.NoError(t, err)
 
diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go
index a55a7d22..98b0b3d2 100644
--- a/core/provider/docker/task.go
+++ b/core/provider/docker/task.go
@@ -7,6 +7,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/skip-mev/petri/core/v2/provider/clients"
+
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/mount"
 	"github.com/docker/docker/pkg/stdcopy"
@@ -36,7 +38,7 @@ type Task struct {
 	state        *TaskState
 	stateMu      sync.Mutex
 	logger       *zap.Logger
-	dockerClient provider.DockerClient
+	dockerClient clients.DockerClient
 	removeTask   func(ctx context.Context, taskID string) error
 }
 
@@ -329,6 +331,7 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string
 		},
 		logger:       t.logger.With(zap.String("temp_task", definition.Name)),
 		dockerClient: t.dockerClient,
+		removeTask:   t.removeTask,
 	}
 
 	err = tempTask.Start(ctx)
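
The preceding patch funnels all Docker engine access through the new clients package. The call sites above follow one pattern: a cached client per host, dialed lazily over TCP. Here is a minimal, self-contained sketch of that caching logic — the map type and the NewDockerClient calls are taken from the hunks above, while the dockerPort value and the function/variable names are illustrative assumptions, not Petri code:

package main

import (
	"fmt"

	"github.com/skip-mev/petri/core/v2/provider/clients"
)

// dockerPort is an assumed value (2375 is the conventional unencrypted
// Docker API port); the real constant lives elsewhere in the package.
const dockerPort = "2375"

// clientFor returns a cached DockerClient for a droplet IP, dialing
// tcp://<ip>:<port> on first use — the same pattern as CreateTask above.
func clientFor(cache map[string]clients.DockerClient, ip string) (clients.DockerClient, error) {
	if c := cache[ip]; c != nil {
		return c, nil
	}
	c, err := clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
	if err != nil {
		return nil, err
	}
	cache[ip] = c
	return c, nil
}

func main() {
	cache := make(map[string]clients.DockerClient)
	if _, err := clientFor(cache, "10.0.0.1"); err != nil {
		fmt.Println("dial failed:", err)
	}
}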

From e96f1845ca457695e4c6932874e7df4411a1d159 Mon Sep 17 00:00:00 2001
From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com>
Date: Tue, 21 Jan 2025 21:27:06 +0200
Subject: [PATCH 47/50] removeTaskFunc moved to top-level provider package

---
 core/provider/digitalocean/task.go      | 5 +----
 core/provider/digitalocean/task_test.go | 2 +-
 core/provider/docker/task.go            | 2 +-
 core/provider/provider.go               | 3 +++
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go
index 79b01eb4..76e517a3 100644
--- a/core/provider/digitalocean/task.go
+++ b/core/provider/digitalocean/task.go
@@ -34,14 +34,11 @@ type TaskState struct {
 	SSHKeyPair *SSHKeyPair `json:"ssh_key_pair"`
 }
 
-// RemoveTaskFunc is a callback function type for removing a task from its provider
-type RemoveTaskFunc func(ctx context.Context, taskID int) error
-
 type Task struct {
 	state   *TaskState
 	stateMu sync.Mutex
 
-	removeTask RemoveTaskFunc
+	removeTask provider.RemoveTaskFunc
 	logger       *zap.Logger
 	sshClient    *ssh.Client
 	doClient     DoClient
 	dockerClient clients.DockerClient
 }
 
 var _ provider.TaskI = (*Task)(nil)
diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go
index c111692e..30ab3367 100644
--- a/core/provider/digitalocean/task_test.go
+++ b/core/provider/digitalocean/task_test.go
@@ -403,7 +403,7 @@ func TestTaskDestroy(t *testing.T) {
 		logger:       logger,
 		dockerClient: mockDocker,
 		doClient:     mockDO,
-		removeTask: func(ctx context.Context, taskID int) error {
+		removeTask: func(ctx context.Context, taskID string) error {
 			delete(provider.state.TaskStates, taskID)
 			return nil
 		},
diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go
index 98b0b3d2..ef29567a 100644
--- a/core/provider/docker/task.go
+++ b/core/provider/docker/task.go
@@ -39,7 +39,7 @@ type Task struct {
 	stateMu      sync.Mutex
 	logger       *zap.Logger
 	dockerClient clients.DockerClient
-	removeTask   func(ctx context.Context, taskID string) error
+	removeTask   provider.RemoveTaskFunc
 }
 
 var _ provider.TaskI = (*Task)(nil)
diff --git a/core/provider/provider.go b/core/provider/provider.go
index 9292be95..3ac73c2e 100644
--- a/core/provider/provider.go
+++ b/core/provider/provider.go
@@ -10,6 +10,9 @@ import (
 // TaskStatus defines the status of a task's underlying workload
 type TaskStatus int
 
+// RemoveTaskFunc is a callback function type for removing a task from its provider
+type RemoveTaskFunc func(ctx context.Context, taskID string) error
+
 const (
 	TASK_STATUS_UNDEFINED TaskStatus = iota
 	TASK_RUNNING
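
This patch hoists the removal callback into the shared provider package (and standardizes taskID as a string, matching the TaskStates map key). The shape of the pattern is worth seeing end to end; below is a minimal, self-contained sketch. Only the RemoveTaskFunc signature is taken verbatim from the diff — toyProvider and toyTask are illustrative stand-ins, not Petri types:

package main

import (
	"context"
	"fmt"
	"sync"
)

// RemoveTaskFunc is a callback function type for removing a task from its
// provider (signature as introduced in core/provider/provider.go above).
type RemoveTaskFunc func(ctx context.Context, taskID string) error

type toyProvider struct {
	mu    sync.Mutex
	tasks map[string]struct{}
}

type toyTask struct {
	id         string
	removeTask RemoveTaskFunc
}

// createTask registers the task and hands it a closure that deregisters it,
// mirroring how the patches wire removeTask into both task implementations.
func (p *toyProvider) createTask(id string) *toyTask {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.tasks[id] = struct{}{}
	return &toyTask{
		id: id,
		removeTask: func(ctx context.Context, taskID string) error {
			p.mu.Lock()
			defer p.mu.Unlock()
			delete(p.tasks, taskID)
			return nil
		},
	}
}

// destroy tears down the workload, then removes itself from provider state.
func (t *toyTask) destroy(ctx context.Context) error {
	return t.removeTask(ctx, t.id)
}

func main() {
	p := &toyProvider{tasks: make(map[string]struct{})}
	task := p.createTask("task-1")
	_ = task.destroy(context.Background())
	fmt.Println(len(p.tasks)) // 0
}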

From ec0afde541caa047fd555377f5da7abf816b1a25 Mon Sep 17 00:00:00 2001
From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com>
Date: Wed, 22 Jan 2025 14:49:41 +0200
Subject: [PATCH 48/50] move tests to cosmos package

---
 cosmos/go.mod                                    | 10 +++++++++-
 cosmos/go.sum                                    | 15 +++++++++++++++
 .../tests/e2e/digitalocean/do_test.go            |  3 +--
 {core => cosmos}/tests/e2e/docker/docker_test.go |  2 --
 {core => cosmos}/tests/e2e/utils.go              |  0
 5 files changed, 25 insertions(+), 5 deletions(-)
 rename {core => cosmos}/tests/e2e/digitalocean/do_test.go (99%)
 rename {core => cosmos}/tests/e2e/docker/docker_test.go (99%)
 rename {core => cosmos}/tests/e2e/utils.go (100%)

diff --git a/cosmos/go.mod b/cosmos/go.mod
index bf0bca24..81aa6561 100644
--- a/cosmos/go.mod
+++ b/cosmos/go.mod
@@ -15,6 +15,7 @@ require (
 	github.com/cometbft/cometbft v0.38.12
 	github.com/cosmos/cosmos-sdk v0.50.10
 	github.com/cosmos/go-bip39 v1.0.0
+	github.com/docker/docker v27.1.1+incompatible
 	github.com/golangci/golangci-lint v1.56.2
 	github.com/icza/dyno v0.0.0-20230330125955-09f820a8d9c0
 	github.com/matoous/go-nanoid/v2 v2.1.0
@@ -97,8 +98,8 @@ require (
 	github.com/dgraph-io/badger/v2 v2.2007.4 // indirect
 	github.com/dgraph-io/ristretto v0.1.1 // indirect
 	github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect
+	github.com/digitalocean/godo v1.108.0 // indirect
 	github.com/distribution/reference v0.5.0 // indirect
-	github.com/docker/docker v27.1.1+incompatible // indirect
 	github.com/docker/go-connections v0.4.0 // indirect
 	github.com/docker/go-units v0.5.0 // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
@@ -146,6 +147,7 @@ require (
 	github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4 // indirect
 	github.com/google/btree v1.1.2 // indirect
 	github.com/google/go-cmp v0.6.0 // indirect
+	github.com/google/go-querystring v1.1.0 // indirect
 	github.com/gordonklaus/ineffassign v0.1.0 // indirect
 	github.com/gorilla/websocket v1.5.3 // indirect
 	github.com/gostaticanalysis/analysisutil v0.7.1 // indirect
@@ -154,8 +156,10 @@ require (
 	github.com/gostaticanalysis/nilerr v0.1.1 // indirect
 	github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect
 	github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect
+	github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
 	github.com/hashicorp/go-immutable-radix v1.3.1 // indirect
 	github.com/hashicorp/go-metrics v0.5.3 // indirect
+	github.com/hashicorp/go-retryablehttp v0.7.4 // indirect
 	github.com/hashicorp/go-version v1.6.0 // indirect
 	github.com/hashicorp/golang-lru v1.0.2 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
@@ -173,6 +177,7 @@ require (
 	github.com/kisielk/gotool v1.0.0 // indirect
 	github.com/kkHAIKE/contextcheck v1.1.4 // indirect
 	github.com/klauspost/compress v1.17.9 // indirect
+	github.com/kr/fs v0.1.0 // indirect
 	github.com/kr/pretty v0.3.1 // indirect
 	github.com/kr/text v0.2.0 // indirect
 	github.com/kulti/thelper v0.6.3 // indirect
@@ -209,6 +214,7 @@ require (
 	github.com/opencontainers/image-spec v1.1.0-rc2 // indirect
 	github.com/petermattis/goid v0.0.0-20231207134359-e60b3f734c67 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
+	github.com/pkg/sftp v1.13.6 // indirect
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/polyfloyd/go-errorlint v1.4.8 // indirect
 	github.com/prometheus/client_golang v1.20.1 // indirect
@@ -284,9 +290,11 @@ require (
 	golang.org/x/exp/typeparams v0.0.0-20231219180239-dc181d75b848 // indirect
 	golang.org/x/mod v0.18.0 // indirect
 	golang.org/x/net v0.31.0 // indirect
+	golang.org/x/oauth2 v0.23.0 // indirect
 	golang.org/x/sys v0.28.0 // indirect
 	golang.org/x/term v0.26.0 // indirect
 	golang.org/x/text v0.20.0 // indirect
+	golang.org/x/time v0.5.0 // indirect
 	golang.org/x/tools v0.22.0 // indirect
 	google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
 	google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 // indirect
diff --git a/cosmos/go.sum b/cosmos/go.sum
index aade3e39..bcc9861a 100644
--- a/cosmos/go.sum
+++ b/cosmos/go.sum
@@ -242,6 +242,8 @@ github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkz
 github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y=
 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
+github.com/digitalocean/godo v1.108.0 h1:fWyMENvtxpCpva1UbKzOFnyAS04N1FNuBWWfPeTGquQ=
+github.com/digitalocean/godo v1.108.0/go.mod h1:R6EmmWI8CT1+fCtjWY9UCB+L5uufuZH13wk3YhxycCs=
 github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
 github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
 github.com/docker/docker v27.1.1+incompatible h1:hO/M4MtV36kzKldqnA37IWhebRA+LnqqcqDja6kVaKY=
@@ -425,6 +427,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
+github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
 github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
@@ -478,6 +482,9 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKA
 github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU=
 github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0=
 github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
+github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
+github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
+github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ=
 github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c=
 github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
 github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60=
@@ -488,6 +495,8 @@ github.com/hashicorp/go-metrics v0.5.3/go.mod h1:KEjodfebIOuBYSAe/bHTm+HChmKSxAO
 github.com/hashicorp/go-plugin v1.5.2 h1:aWv8eimFqWlsEiMrYZdPYl+FdHaBJSN4AWwGWfT1G2Y=
 github.com/hashicorp/go-plugin v1.5.2/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4=
 github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
+github.com/hashicorp/go-retryablehttp v0.7.4 h1:ZQgVdpTdAL7WpMIwLzCfbalOcSUdkDZnpUv3/+BxzFA=
+github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8=
 github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
 github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
 github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
@@ -556,6 +565,8 @@ github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2
 github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
+github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -696,6 +707,8 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo=
+github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -1037,6 +1050,8 @@ golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4Iltr
 golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
+golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
+golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
diff --git a/core/tests/e2e/digitalocean/do_test.go b/cosmos/tests/e2e/digitalocean/do_test.go
similarity index 99%
rename from core/tests/e2e/digitalocean/do_test.go
rename to cosmos/tests/e2e/digitalocean/do_test.go
index bc7b5fbe..ea8b64d1 100644
--- a/core/tests/e2e/digitalocean/do_test.go
+++ b/cosmos/tests/e2e/digitalocean/do_test.go
@@ -3,12 +3,11 @@ package e2e
 import (
 	"context"
 	"flag"
+	"github.com/skip-mev/petri/cosmos/v2/tests/e2e"
 	"os"
 	"testing"
 	"time"
 
-	"github.com/skip-mev/petri/core/v2/tests/e2e"
-
 	"github.com/skip-mev/petri/cosmos/v2/node"
 
 	"github.com/cosmos/cosmos-sdk/crypto/hd"
diff --git a/core/tests/e2e/docker/docker_test.go b/cosmos/tests/e2e/docker/docker_test.go
similarity index 99%
rename from core/tests/e2e/docker/docker_test.go
rename to cosmos/tests/e2e/docker/docker_test.go
index b88d1d64..629e40e3 100644
--- a/core/tests/e2e/docker/docker_test.go
+++ b/cosmos/tests/e2e/docker/docker_test.go
@@ -8,8 +8,6 @@ import (
 	"sync"
 	"testing"
 
-	"github.com/skip-mev/petri/core/v2/tests/e2e"
-
 	"github.com/skip-mev/petri/cosmos/v2/node"
 
 	"github.com/cosmos/cosmos-sdk/crypto/hd"
diff --git a/core/tests/e2e/utils.go b/cosmos/tests/e2e/utils.go
similarity index 100%
rename from core/tests/e2e/utils.go
rename to cosmos/tests/e2e/utils.go

From d930bbc14689d193a7948974c3177fa61c871238 Mon Sep 17 00:00:00 2001
From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com>
Date: Thu, 23 Jan 2025 02:39:39 +0200
Subject: [PATCH 49/50] upgrade to v3

---
 core/go.mod                                 |  3 +--
 core/go.sum                                 |  2 --
 core/provider/digitalocean/droplet.go       |  3 ++-
 core/provider/digitalocean/provider.go      |  9 +++++----
 core/provider/digitalocean/provider_test.go | 11 ++++++-----
 core/provider/digitalocean/task.go          |  3 ++-
 core/provider/digitalocean/task_test.go     |  6 +++---
 core/provider/docker/provider.go            |  2 +-
 core/provider/docker/provider_test.go       |  3 ++-
 core/provider/docker/task.go                |  2 +-
 .../examples/digitalocean_simapp.go         | 12 ++++++------
 cosmos/tests/e2e/digitalocean/do_test.go    | 13 +++++++------
 cosmos/tests/e2e/docker/docker_test.go      | 15 +++++++--------
 cosmos/tests/e2e/utils.go                   |  6 +++---
 14 files changed, 46 insertions(+), 44 deletions(-)
 rename {core/provider/digitalocean => cosmos}/examples/digitalocean_simapp.go (92%)

diff --git a/core/go.mod b/core/go.mod
index 9fd4218f..39799c58 100644
--- a/core/go.mod
+++ b/core/go.mod
@@ -14,9 +14,9 @@ require (
 	github.com/go-rod/rod v0.114.6
 	github.com/golangci/golangci-lint v1.56.2
 	github.com/matoous/go-nanoid/v2 v2.1.0
+	github.com/opencontainers/image-spec v1.1.0-rc2
 	github.com/pkg/errors v0.9.1
 	github.com/pkg/sftp v1.13.6
-	github.com/puzpuzpuz/xsync/v3 v3.4.0
 	github.com/spf13/afero v1.11.0
 	github.com/stretchr/testify v1.9.0
 	go.uber.org/zap v1.26.0
@@ -199,7 +199,6 @@ require (
 	github.com/oasisprotocol/curve25519-voi v0.0.0-20230904125328-1f23a7beb09a // indirect
 	github.com/olekukonko/tablewriter v0.0.5 // indirect
 	github.com/opencontainers/go-digest v1.0.0 // indirect
-	github.com/opencontainers/image-spec v1.1.0-rc2 // indirect
 	github.com/pelletier/go-toml/v2 v2.2.2 // indirect
 	github.com/petermattis/goid v0.0.0-20231207134359-e60b3f734c67 // indirect
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
diff --git a/core/go.sum b/core/go.sum
index 91d0c9e2..1f840f8b 100644
--- a/core/go.sum
+++ b/core/go.sum
@@ -742,8 +742,6 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
 github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
 github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
 github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
-github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
-github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
 github.com/quasilyte/go-ruleguard v0.4.0 h1:DyM6r+TKL+xbKB4Nm7Afd1IQh9kEUKQs2pboWGKtvQo=
 github.com/quasilyte/go-ruleguard v0.4.0/go.mod h1:Eu76Z/R8IXtViWUIHkE3p8gdH3/PKk1eh3YGfaEof10=
 github.com/quasilyte/gogrep v0.5.0 h1:eTKODPXbI8ffJMN+W2aE0+oL0z/nh8/5eNdiO34SOAo=
diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go
index 7c63a9b4..d27934b2 100644
--- a/core/provider/digitalocean/droplet.go
+++ b/core/provider/digitalocean/droplet.go
@@ -3,9 +3,10 @@ package digitalocean
 import (
 	"context"
 	"fmt"
-	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"time"
 
+	"github.com/skip-mev/petri/core/v3/provider/clients"
+
 	"github.com/pkg/errors"
 
 	"github.com/digitalocean/godo"
diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go
index 2ee76bee..2e0b6961 100644
--- a/core/provider/digitalocean/provider.go
+++ b/core/provider/digitalocean/provider.go
@@ -6,12 +6,13 @@ import (
 	"errors"
 	"fmt"
 	"strconv"
-	"github.com/docker/docker/api/types/image"
-	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"strings"
 	"sync"
 	"time"
 
+	"github.com/docker/docker/api/types/image"
+	"github.com/skip-mev/petri/core/v3/provider/clients"
+
 	"github.com/digitalocean/godo"
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/mount"
@@ -19,8 +20,8 @@ import (
 
 	"go.uber.org/zap"
 
-	"github.com/skip-mev/petri/core/v2/provider"
-	"github.com/skip-mev/petri/core/v2/util"
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/util"
 )
 
 var _ provider.ProviderI = (*Provider)(nil)
diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go
index 9e6f72b9..baa8815b 100644
--- a/core/provider/digitalocean/provider_test.go
+++ b/core/provider/digitalocean/provider_test.go
@@ -3,12 +3,13 @@ package digitalocean
 import (
 	"context"
 	"fmt"
-	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"sync"
 	"testing"
 	"time"
 
-	"github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks"
+	"github.com/skip-mev/petri/core/v3/provider/clients"
+
+	"github.com/skip-mev/petri/core/v3/provider/digitalocean/mocks"
 
 	"github.com/digitalocean/godo"
 	"github.com/docker/docker/api/types"
@@ -22,9 +23,9 @@ import (
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
 
-	"github.com/skip-mev/petri/core/v2/provider"
-	dockerMocks "github.com/skip-mev/petri/core/v2/provider/mocks"
-	"github.com/skip-mev/petri/core/v2/util"
+	"github.com/skip-mev/petri/core/v3/provider"
+	dockerMocks "github.com/skip-mev/petri/core/v3/provider/mocks"
+	"github.com/skip-mev/petri/core/v3/util"
 )
 
 func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoClient, *dockerMocks.DockerClient) {
diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go
index 76e517a3..b3edeb64 100644
--- a/core/provider/digitalocean/task.go
+++ b/core/provider/digitalocean/task.go
@@ -4,12 +4,13 @@ import (
 	"bytes"
 	"context"
 	"fmt"
-	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"net"
 	"path"
 	"sync"
 	"time"
 
+	"github.com/skip-mev/petri/core/v3/provider/clients"
+
 	"golang.org/x/crypto/ssh"
 
 	"github.com/docker/docker/api/types/container"
diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go
index 30ab3367..07bc58d0 100644
--- a/core/provider/digitalocean/task_test.go
+++ b/core/provider/digitalocean/task_test.go
@@ -16,11 +16,11 @@ import (
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/mount"
 	"github.com/docker/docker/api/types/network"
-	dockerMocks "github.com/skip-mev/petri/core/v2/provider/mocks"
+	dockerMocks "github.com/skip-mev/petri/core/v3/provider/mocks"
 	specs "github.com/opencontainers/image-spec/specs-go/v1"
 
-	"github.com/skip-mev/petri/core/v2/provider"
-	"github.com/skip-mev/petri/core/v2/provider/digitalocean/mocks"
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/provider/digitalocean/mocks"
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go
index 73fd4993..ffe69cd4 100644
--- a/core/provider/docker/provider.go
+++ b/core/provider/docker/provider.go
@@ -4,12 +4,12 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"io"
 	"net"
 	"sync"
 
 	"github.com/docker/docker/api/types/image"
 	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/provider/clients"
 
 	"github.com/cilium/ipam/service/ipallocator"
 	"github.com/docker/docker/api/types/network"
diff --git a/core/provider/docker/provider_test.go b/core/provider/docker/provider_test.go
index 1bd9696c..01b25622 100644
--- a/core/provider/docker/provider_test.go
+++ b/core/provider/docker/provider_test.go
@@ -3,11 +3,12 @@ package docker_test
 import (
 	"context"
 	"fmt"
-	"github.com/skip-mev/petri/core/v2/provider/clients"
 	"sync"
 	"testing"
 	"time"
 
+	"github.com/skip-mev/petri/core/v3/provider/clients"
+
 	"github.com/docker/docker/api/types/filters"
 	"github.com/docker/docker/api/types/network"
 	"github.com/docker/docker/client"
diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go
index ef29567a..6ee04da8 100644
--- a/core/provider/docker/task.go
+++ b/core/provider/docker/task.go
@@ -7,7 +7,7 @@ import (
 	"sync"
 	"time"
 
-	"github.com/skip-mev/petri/core/v2/provider/clients"
+	"github.com/skip-mev/petri/core/v3/provider/clients"
 
 	"github.com/docker/docker/api/types/container"
 	"github.com/docker/docker/api/types/mount"
diff --git a/core/provider/digitalocean/examples/digitalocean_simapp.go b/cosmos/examples/digitalocean_simapp.go
similarity index 92%
rename from core/provider/digitalocean/examples/digitalocean_simapp.go
rename to cosmos/examples/digitalocean_simapp.go
index b1d45971..6d3d7789 100644
--- a/core/provider/digitalocean/examples/digitalocean_simapp.go
+++ b/cosmos/examples/digitalocean_simapp.go
@@ -1,7 +1,8 @@
-package main
+package examples
 
 import (
 	"context"
+	"github.com/skip-mev/petri/core/v3/provider/digitalocean"
 	"io"
 	"net/http"
 	"os"
@@ -9,12 +10,11 @@ import (
 
 	"github.com/cosmos/cosmos-sdk/crypto/hd"
 
-	"github.com/skip-mev/petri/cosmos/v2/chain"
-	"github.com/skip-mev/petri/cosmos/v2/node"
+	"github.com/skip-mev/petri/cosmos/v3/chain"
+	"github.com/skip-mev/petri/cosmos/v3/node"
 
-	"github.com/skip-mev/petri/core/v2/provider"
-	"github.com/skip-mev/petri/core/v2/provider/digitalocean"
-	petritypes "github.com/skip-mev/petri/core/v2/types"
+	"github.com/skip-mev/petri/core/v3/provider"
+	petritypes "github.com/skip-mev/petri/core/v3/types"
 	"go.uber.org/zap"
 )
diff --git a/cosmos/tests/e2e/digitalocean/do_test.go b/cosmos/tests/e2e/digitalocean/do_test.go
index ea8b64d1..1f6e6a70 100644
--- a/cosmos/tests/e2e/digitalocean/do_test.go
+++ b/cosmos/tests/e2e/digitalocean/do_test.go
@@ -3,18 +3,19 @@ package e2e
 import (
 	"context"
 	"flag"
-	"github.com/skip-mev/petri/cosmos/v2/tests/e2e"
 	"os"
 	"testing"
 	"time"
 
-	"github.com/skip-mev/petri/cosmos/v2/node"
+	"github.com/skip-mev/petri/cosmos/v3/tests/e2e"
+
+	"github.com/skip-mev/petri/cosmos/v3/node"
 
 	"github.com/cosmos/cosmos-sdk/crypto/hd"
 
-	"github.com/skip-mev/petri/core/v2/provider"
-	"github.com/skip-mev/petri/core/v2/provider/digitalocean"
-	"github.com/skip-mev/petri/core/v2/types"
-	cosmoschain "github.com/skip-mev/petri/cosmos/v2/chain"
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/provider/digitalocean"
+	"github.com/skip-mev/petri/core/v3/types"
+	cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
 )
diff --git a/cosmos/tests/e2e/docker/docker_test.go b/cosmos/tests/e2e/docker/docker_test.go
index 629e40e3..44b78856 100644
--- a/cosmos/tests/e2e/docker/docker_test.go
+++ b/cosmos/tests/e2e/docker/docker_test.go
@@ -3,22 +3,21 @@ package e2e
 import (
 	"context"
 	"flag"
-	"fmt"
 	"os"
-	"sync"
 	"testing"
 
-	"github.com/skip-mev/petri/cosmos/v2/node"
+	"github.com/skip-mev/petri/cosmos/v3/node"
+	"github.com/skip-mev/petri/cosmos/v3/tests/e2e"
 
 	"github.com/cosmos/cosmos-sdk/crypto/hd"
 	"github.com/docker/docker/api/types/filters"
 	"github.com/docker/docker/client"
 	gonanoid "github.com/matoous/go-nanoid/v2"
-	"github.com/skip-mev/petri/core/v2/provider"
-	"github.com/skip-mev/petri/core/v2/provider/digitalocean"
-	"github.com/skip-mev/petri/core/v2/provider/docker"
-	"github.com/skip-mev/petri/core/v2/types"
-	cosmoschain "github.com/skip-mev/petri/cosmos/v2/chain"
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/provider/digitalocean"
+	"github.com/skip-mev/petri/core/v3/provider/docker"
+	"github.com/skip-mev/petri/core/v3/types"
+	cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
 )
diff --git a/cosmos/tests/e2e/utils.go b/cosmos/tests/e2e/utils.go
index ac83c07a..9f392c95 100644
--- a/cosmos/tests/e2e/utils.go
+++ b/cosmos/tests/e2e/utils.go
@@ -9,9 +9,9 @@ import (
 	"sync"
 	"testing"
 
-	"github.com/skip-mev/petri/core/v2/provider"
-	"github.com/skip-mev/petri/core/v2/types"
-	cosmoschain "github.com/skip-mev/petri/cosmos/v2/chain"
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/types"
+	cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain"
 	"github.com/stretchr/testify/require"
 	"go.uber.org/zap"
 )
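
Why this patch touches so many files: Go's semantic import versioning makes the major version part of the import path itself, so bumping core and cosmos to v3 means rewriting the module directive and every import site in one change. A minimal sketch of the mechanics — the module lines below are inferred from the import paths in the diffs, not copied from the repository:

// core/go.mod before:  module github.com/skip-mev/petri/core/v2
// core/go.mod after:   module github.com/skip-mev/petri/core/v3

package e2e

import (
	// old path: "github.com/skip-mev/petri/core/v2/provider"
	"github.com/skip-mev/petri/core/v3/provider" // new path, same package name
)

// The package identifier is unchanged, so call sites compile as-is once
// the import path is rewritten.
var _ provider.TaskStatus = provider.TASK_RUNNING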
"github.com/stretchr/testify/require" "go.uber.org/zap" ) From 67a425c8cd718f9a5e39940567e3f6c9fa9fac70 Mon Sep 17 00:00:00 2001 From: nadimabdelaziz <nadeem.abdelaziz99@gmail.com> Date: Thu, 23 Jan 2025 02:50:52 +0200 Subject: [PATCH 50/50] bad rebase --- cosmos/tests/e2e/digitalocean/do_test.go | 9 ++++---- cosmos/tests/e2e/docker/docker_test.go | 27 ++++-------------------- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/cosmos/tests/e2e/digitalocean/do_test.go b/cosmos/tests/e2e/digitalocean/do_test.go index 1f6e6a70..b2a27618 100644 --- a/cosmos/tests/e2e/digitalocean/do_test.go +++ b/cosmos/tests/e2e/digitalocean/do_test.go @@ -7,15 +7,14 @@ import ( "testing" "time" - "github.com/skip-mev/petri/cosmos/v3/tests/e2e" - - "github.com/skip-mev/petri/cosmos/v3/node" - - "github.com/cosmos/cosmos-sdk/crypto/hd" "github.com/skip-mev/petri/core/v3/provider" "github.com/skip-mev/petri/core/v3/provider/digitalocean" "github.com/skip-mev/petri/core/v3/types" cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain" + "github.com/skip-mev/petri/cosmos/v3/node" + "github.com/skip-mev/petri/cosmos/v3/tests/e2e" + + "github.com/cosmos/cosmos-sdk/crypto/hd" "github.com/stretchr/testify/require" "go.uber.org/zap" ) diff --git a/cosmos/tests/e2e/docker/docker_test.go b/cosmos/tests/e2e/docker/docker_test.go index 44b78856..4e58bc00 100644 --- a/cosmos/tests/e2e/docker/docker_test.go +++ b/cosmos/tests/e2e/docker/docker_test.go @@ -3,21 +3,18 @@ package e2e import ( "context" "flag" - "os" "testing" + "github.com/skip-mev/petri/core/v3/provider" + "github.com/skip-mev/petri/core/v3/provider/docker" + "github.com/skip-mev/petri/core/v3/types" + cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain" "github.com/skip-mev/petri/cosmos/v3/node" "github.com/skip-mev/petri/cosmos/v3/tests/e2e" "github.com/cosmos/cosmos-sdk/crypto/hd" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" - gonanoid "github.com/matoous/go-nanoid/v2" - "github.com/skip-mev/petri/core/v3/provider" - "github.com/skip-mev/petri/core/v3/provider/digitalocean" - "github.com/skip-mev/petri/core/v3/provider/docker" - "github.com/skip-mev/petri/core/v3/types" - cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain" "github.com/stretchr/testify/require" "go.uber.org/zap" ) @@ -51,17 +48,6 @@ var ( DerivationFn: hd.Secp256k1.Derive(), GenerationFn: hd.Secp256k1.Generate(), }, - NodeOptions: types.NodeOptions{ - NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig types.NodeConfig) provider.TaskDefinition { - doConfig := digitalocean.DigitalOceanTaskConfig{ - "size": "s-2vcpu-4gb", - "region": "ams3", - "image_id": os.Getenv("DO_IMAGE_ID"), - } - def.ProviderSpecificConfig = doConfig - return def - }, - }, } numTestChains = flag.Int("num-chains", 3, "number of chains to create for concurrent testing") @@ -76,13 +62,8 @@ func TestDockerE2E(t *testing.T) { ctx := context.Background() logger, _ := zap.NewDevelopment() - providerName := gonanoid.MustGenerate("abcdefghijklqmnoqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890", 10) - - p, err := docker.CreateProvider(ctx, logger, providerName) - require.NoError(t, err) defer func() { - require.NoError(t, p.Teardown(ctx)) dockerClient, err := client.NewClientWithOpts() if err != nil { t.Logf("Failed to create Docker client for volume cleanup: %v", err)