diff --git a/core/go.mod b/core/go.mod
index 9fd4218..39799c5 100644
--- a/core/go.mod
+++ b/core/go.mod
@@ -14,9 +14,9 @@ require (
 	github.com/go-rod/rod v0.114.6
 	github.com/golangci/golangci-lint v1.56.2
 	github.com/matoous/go-nanoid/v2 v2.1.0
+	github.com/opencontainers/image-spec v1.1.0-rc2
 	github.com/pkg/errors v0.9.1
 	github.com/pkg/sftp v1.13.6
-	github.com/puzpuzpuz/xsync/v3 v3.4.0
 	github.com/spf13/afero v1.11.0
 	github.com/stretchr/testify v1.9.0
 	go.uber.org/zap v1.26.0
@@ -199,7 +199,6 @@ require (
 	github.com/oasisprotocol/curve25519-voi v0.0.0-20230904125328-1f23a7beb09a // indirect
 	github.com/olekukonko/tablewriter v0.0.5 // indirect
 	github.com/opencontainers/go-digest v1.0.0 // indirect
-	github.com/opencontainers/image-spec v1.1.0-rc2 // indirect
 	github.com/pelletier/go-toml/v2 v2.2.2 // indirect
 	github.com/petermattis/goid v0.0.0-20231207134359-e60b3f734c67 // indirect
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
diff --git a/core/go.sum b/core/go.sum
index 91d0c9e..1f840f8 100644
--- a/core/go.sum
+++ b/core/go.sum
@@ -742,8 +742,6 @@ github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1
 github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
 github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
 github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
-github.com/puzpuzpuz/xsync/v3 v3.4.0 h1:DuVBAdXuGFHv8adVXjWWZ63pJq+NRXOWVXlKDBZ+mJ4=
-github.com/puzpuzpuz/xsync/v3 v3.4.0/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA=
 github.com/quasilyte/go-ruleguard v0.4.0 h1:DyM6r+TKL+xbKB4Nm7Afd1IQh9kEUKQs2pboWGKtvQo=
 github.com/quasilyte/go-ruleguard v0.4.0/go.mod h1:Eu76Z/R8IXtViWUIHkE3p8gdH3/PKk1eh3YGfaEof10=
 github.com/quasilyte/gogrep v0.5.0 h1:eTKODPXbI8ffJMN+W2aE0+oL0z/nh8/5eNdiO34SOAo=
diff --git a/core/provider/clients/docker_client.go b/core/provider/clients/docker_client.go
new file mode 100644
index 0000000..44c3394
--- /dev/null
+++ b/core/provider/clients/docker_client.go
@@ -0,0 +1,194 @@
+package clients
+
+import (
+    "context"
+    "fmt"
+    "io"
+
+    "github.com/docker/docker/api/types"
+    "github.com/docker/docker/api/types/container"
+    "github.com/docker/docker/api/types/image"
+    "github.com/docker/docker/api/types/network"
+    "github.com/docker/docker/api/types/volume"
+    dockerclient "github.com/docker/docker/client"
+    specs "github.com/opencontainers/image-spec/specs-go/v1"
+    "go.uber.org/zap"
+)
+
+// DockerClient is a unified interface for interacting with Docker.
+// It combines the functionality needed by both the Docker and DigitalOcean providers.
+type DockerClient interface {
+    // Container Operations
+    ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error)
+    ContainerStart(ctx context.Context, container string, options container.StartOptions) error
+    ContainerStop(ctx context.Context, container string, options container.StopOptions) error
+    ContainerRemove(ctx context.Context, container string, options container.RemoveOptions) error
+    ContainerInspect(ctx context.Context, container string) (types.ContainerJSON, error)
+    ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error)
+    ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.WaitResponse, <-chan error)
+
+    // Container Exec Operations
+    ContainerExecCreate(ctx context.Context, container string, config container.ExecOptions) (types.IDResponse, error)
+    ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error)
+    ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error)
+
+    // Container File Operations
+    CopyToContainer(ctx context.Context, container, path string, content io.Reader, options container.CopyToContainerOptions) error
+    CopyFromContainer(ctx context.Context, container, srcPath string) (io.ReadCloser, container.PathStat, error)
+    ContainerLogs(ctx context.Context, container string, options container.LogsOptions) (io.ReadCloser, error)
+
+    // Image Operations
+    ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error)
+    ImagePull(ctx context.Context, logger *zap.Logger, refStr string, options image.PullOptions) error
+
+    // Volume Operations
+    VolumeCreate(ctx context.Context, options volume.CreateOptions) (volume.Volume, error)
+    VolumeInspect(ctx context.Context, volumeID string) (volume.Volume, error)
+    VolumeList(ctx context.Context, options volume.ListOptions) (volume.ListResponse, error)
+    VolumeRemove(ctx context.Context, volumeID string, force bool) error
+
+    // Network Operations
+    NetworkCreate(ctx context.Context, name string, options network.CreateOptions) (network.CreateResponse, error)
+    NetworkInspect(ctx context.Context, networkID string, options network.InspectOptions) (network.Inspect, error)
+    NetworkRemove(ctx context.Context, networkID string) error
+
+    // System Operations
+    Ping(ctx context.Context) (types.Ping, error)
+    Close() error
+}
+
+// defaultDockerClient is the default implementation of the DockerClient interface.
+type defaultDockerClient struct {
+    client *dockerclient.Client
+}
+
+// NewDockerClient creates a client for the Docker daemon at host, which must be a
+// full Docker host address (e.g. tcp://1.2.3.4:2375), matching what the callers in
+// the DigitalOcean provider pass in. If host is empty, the default Docker socket
+// is used instead.
+func NewDockerClient(host string) (DockerClient, error) {
+    if host == "" {
+        client, err := dockerclient.NewClientWithOpts()
+        if err != nil {
+            return nil, err
+        }
+        return &defaultDockerClient{client: client}, nil
+    }
+
+    client, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(host))
+    if err != nil {
+        return nil, err
+    }
+    return &defaultDockerClient{client: client}, nil
+}
+
+func (d *defaultDockerClient) Ping(ctx context.Context) (types.Ping, error) {
+    return d.client.Ping(ctx)
+}
+
+func (d *defaultDockerClient) ImageInspectWithRaw(ctx context.Context, image string) (types.ImageInspect, []byte, error) {
+    return d.client.ImageInspectWithRaw(ctx, image)
+}
+
+func (d *defaultDockerClient) ImagePull(ctx context.Context, logger *zap.Logger, ref string, options image.PullOptions) error {
+    _, _, err := d.client.ImageInspectWithRaw(ctx, ref)
+    if err == nil {
+        // the image is already present locally, nothing to pull
+        return nil
+    }
+
+    logger.Info("pulling image", zap.String("image", ref))
+    resp, err := d.client.ImagePull(ctx, ref, options)
+    if err != nil {
+        return fmt.Errorf("failed to pull docker image: %w", err)
+    }
+    defer resp.Close()
+
+    // drain the pull progress stream; the pull is only guaranteed to be
+    // complete once the response body has been read to EOF
+    if _, err := io.Copy(io.Discard, resp); err != nil {
+        return fmt.Errorf("failed to read image pull response: %w", err)
+    }
+    return nil
+}
+
+func (d *defaultDockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *specs.Platform, containerName string) (container.CreateResponse, error) {
+    return d.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, containerName)
+}
+
+func (d *defaultDockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) {
+    return d.client.ContainerList(ctx, options)
+}
+
+func (d *defaultDockerClient) ContainerStart(ctx context.Context, containerID string, options container.StartOptions) error {
+    return d.client.ContainerStart(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerStop(ctx context.Context, containerID string, options container.StopOptions) error {
+    return d.client.ContainerStop(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerInspect(ctx context.Context, containerID string) (types.ContainerJSON, error) {
+    return d.client.ContainerInspect(ctx, containerID)
+}
+
+func (d *defaultDockerClient) ContainerExecCreate(ctx context.Context, container string, config container.ExecOptions) (types.IDResponse, error) {
+    return d.client.ContainerExecCreate(ctx, container, config)
+}
+
+func (d *defaultDockerClient) ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error) {
+    return d.client.ContainerExecAttach(ctx, execID, config)
+}
+
+func (d *defaultDockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) {
+    return d.client.ContainerExecInspect(ctx, execID)
+}
+
+func (d *defaultDockerClient) ContainerRemove(ctx context.Context, containerID string, options container.RemoveOptions) error {
+    return d.client.ContainerRemove(ctx, containerID, options)
+}
+
+func (d *defaultDockerClient) ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.WaitResponse, <-chan error) {
+    return d.client.ContainerWait(ctx, containerID, condition)
+}
+
+func (d *defaultDockerClient) ContainerLogs(ctx context.Context, container string, options container.LogsOptions) (io.ReadCloser, error) {
+    return d.client.ContainerLogs(ctx, container, options)
+}
+
+func (d *defaultDockerClient) CopyToContainer(ctx context.Context, container, path string, content io.Reader, options container.CopyToContainerOptions) error {
+    return d.client.CopyToContainer(ctx, container, path, content, options)
+}
+
+func (d *defaultDockerClient) CopyFromContainer(ctx context.Context, container, srcPath string) (io.ReadCloser, container.PathStat, error) {
+    return d.client.CopyFromContainer(ctx, container, srcPath)
+}
+
+func (d *defaultDockerClient) VolumeCreate(ctx context.Context, options volume.CreateOptions) (volume.Volume, error) {
+    return d.client.VolumeCreate(ctx, options)
+}
+
+func (d *defaultDockerClient) VolumeInspect(ctx context.Context, volumeID string) (volume.Volume, error) {
+    return d.client.VolumeInspect(ctx, volumeID)
+}
+
+func (d *defaultDockerClient) VolumeList(ctx context.Context, options volume.ListOptions) (volume.ListResponse, error) {
+    return d.client.VolumeList(ctx, options)
+}
+
+func (d *defaultDockerClient) VolumeRemove(ctx context.Context, volumeID string, force bool) error {
+    return d.client.VolumeRemove(ctx, volumeID, force)
+}
+
+func (d *defaultDockerClient) NetworkCreate(ctx context.Context, name string, options network.CreateOptions) (network.CreateResponse, error) {
+    return d.client.NetworkCreate(ctx, name, options)
+}
+
+func (d *defaultDockerClient) NetworkInspect(ctx context.Context, networkID string, options network.InspectOptions) (network.Inspect, error) {
+    return d.client.NetworkInspect(ctx, networkID, options)
+}
+
+func (d *defaultDockerClient) NetworkRemove(ctx context.Context, networkID string) error {
+    return d.client.NetworkRemove(ctx, networkID)
+}
+
+func (d *defaultDockerClient) Close() error {
+    return d.client.Close()
+}
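Reviewer note: a minimal sketch of how the new clients.DockerClient is meant to be wired up against a droplet's Docker daemon. The address below is a placeholder (in the provider it comes from droplet.PublicIPv4()); this is illustrative only, not part of the change:

package main

import (
    "context"
    "log"

    "github.com/skip-mev/petri/core/v3/provider/clients"
)

func main() {
    ctx := context.Background()

    // Placeholder address; assumes the droplet's daemon listens on the
    // standard unencrypted port 2375.
    client, err := clients.NewDockerClient("tcp://203.0.113.10:2375")
    if err != nil {
        log.Fatal(err)
    }
    defer client.Close()

    // Ping confirms the remote daemon is reachable before any container work.
    if _, err := client.Ping(ctx); err != nil {
        log.Fatal(err)
    }
}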
diff --git a/core/provider/digitalocean/client.go b/core/provider/digitalocean/client.go
new file mode 100644
index 0000000..9375dc3
--- /dev/null
+++ b/core/provider/digitalocean/client.go
@@ -0,0 +1,128 @@
+package digitalocean
+
+import (
+    "context"
+    "fmt"
+
+    "github.com/digitalocean/godo"
+)
+
+// DoClient defines the interface for the DigitalOcean API operations used by the provider.
+type DoClient interface {
+    // Droplet operations
+    CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, error)
+    GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, error)
+    DeleteDropletByTag(ctx context.Context, tag string) error
+    DeleteDropletByID(ctx context.Context, id int) error
+
+    // Firewall operations
+    CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, error)
+    DeleteFirewall(ctx context.Context, firewallID string) error
+
+    // SSH Key operations
+    CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, error)
+    DeleteKeyByFingerprint(ctx context.Context, fingerprint string) error
+    GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, error)
+
+    // Tag operations
+    CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, error)
+    DeleteTag(ctx context.Context, tag string) error
+}
+
+// godoClient implements the DoClient interface using the actual godo.Client
+type godoClient struct {
+    *godo.Client
+}
+
+func NewGodoClient(token string) DoClient {
+    return &godoClient{Client: godo.NewFromToken(token)}
+}
+
+func checkResponse(res *godo.Response, err error) error {
+    if err != nil {
+        return err
+    }
+
+    if res.StatusCode > 299 || res.StatusCode < 200 {
+        return fmt.Errorf("unexpected status code: %d", res.StatusCode)
+    }
+
+    return nil
+}
+
+// Droplet operations
+func (c *godoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, error) {
+    droplet, res, err := c.Droplets.Create(ctx, req)
+    if err := checkResponse(res, err); err != nil {
+        return nil, err
+    }
+    return droplet, nil
+}
+
+func (c *godoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, error) {
+    droplet, res, err := c.Droplets.Get(ctx, dropletID)
+    if err := checkResponse(res, err); err != nil {
+        return nil, err
+    }
+    return droplet, nil
+}
+
+func (c *godoClient) DeleteDropletByTag(ctx context.Context, tag string) error {
+    res, err := c.Droplets.DeleteByTag(ctx, tag)
+    return checkResponse(res, err)
+}
+
+func (c *godoClient) DeleteDropletByID(ctx context.Context, id int) error {
+    res, err := c.Droplets.Delete(ctx, id)
+    return checkResponse(res, err)
+}
+
+// Firewall operations
+func (c *godoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, error) {
+    firewall, res, err := c.Firewalls.Create(ctx, req)
+    if err := checkResponse(res, err); err != nil {
+        return nil, err
+    }
+    return firewall, nil
+}
+
+func (c *godoClient) DeleteFirewall(ctx context.Context, firewallID string) error {
+    res, err := c.Firewalls.Delete(ctx, firewallID)
+    return checkResponse(res, err)
+}
+
+// SSH Key operations
+func (c *godoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, error) {
+    key, res, err := c.Keys.Create(ctx, req)
+    if err := checkResponse(res, err); err != nil {
+        return nil, err
+    }
+    return key, nil
+}
+
+func (c *godoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) error {
+    res, err := c.Keys.DeleteByFingerprint(ctx, fingerprint)
+    return checkResponse(res, err)
+}
+
+func (c *godoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, error) {
+    key, res, err := c.Keys.GetByFingerprint(ctx, fingerprint)
+    if err := checkResponse(res, err); err != nil {
+        return nil, err
+    }
+    return key, nil
+}
+
+// Tag operations
+func (c *godoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, error) {
+    tag, res, err := c.Tags.Create(ctx, req)
+    if err := checkResponse(res, err); err != nil {
+        return nil, err
+    }
+    return tag, nil
+}
+
+func (c *godoClient) DeleteTag(ctx context.Context, tag string) error {
+    res, err := c.Tags.Delete(ctx, tag)
+    return checkResponse(res, err)
+}
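Reviewer note: the DoClient wrapper folds godo's (result, *godo.Response, error) triples into a single error via checkResponse, so call sites only handle one failure path. A hypothetical usage sketch, assuming a valid token in the DO_TOKEN environment variable:

package main

import (
    "context"
    "log"
    "os"

    "github.com/digitalocean/godo"

    "github.com/skip-mev/petri/core/v3/provider/digitalocean"
)

func main() {
    ctx := context.Background()
    doClient := digitalocean.NewGodoClient(os.Getenv("DO_TOKEN"))

    // Tags are cheap to create and delete, which makes them a reasonable smoke test.
    tag, err := doClient.CreateTag(ctx, &godo.TagCreateRequest{Name: "petri-smoke-test"})
    if err != nil {
        log.Fatal(err)
    }
    defer doClient.DeleteTag(ctx, tag.Name)
}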
diff --git a/core/provider/digitalocean/digitalocean_provider.go b/core/provider/digitalocean/digitalocean_provider.go
deleted file mode 100644
index 887244c..0000000
--- a/core/provider/digitalocean/digitalocean_provider.go
+++ /dev/null
@@ -1,165 +0,0 @@
-package digitalocean
-
-import (
-    "context"
-    "fmt"
-    "strings"
-
-    "github.com/digitalocean/godo"
-    "go.uber.org/zap"
-
-    xsync "github.com/puzpuzpuz/xsync/v3"
-
-    "github.com/skip-mev/petri/core/v3/provider"
-    "github.com/skip-mev/petri/core/v3/util"
-    "golang.org/x/crypto/ssh"
-)
-
-var _ provider.Provider = (*Provider)(nil)
-
-const (
-    providerLabelName = "petri-provider"
-)
-
-type Provider struct {
-    logger   *zap.Logger
-    name     string
-    doClient *godo.Client
-    petriTag string
-
-    userIPs []string
-
-    sshKeyPair *SSHKeyPair
-
-    sshClients *xsync.MapOf[string, *ssh.Client]
-
-    firewallID string
-}
-
-// NewDigitalOceanProvider creates a provider that implements the Provider interface for DigitalOcean.
-// Token is the DigitalOcean API token
-func NewDigitalOceanProvider(ctx context.Context, logger *zap.Logger, providerName string, token string, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) {
-    doClient := godo.NewFromToken(token)
-
-    if sshKeyPair == nil {
-        newSshKeyPair, err := MakeSSHKeyPair()
-        if err != nil {
-            return nil, err
-        }
-        sshKeyPair = newSshKeyPair
-    }
-
-    userIPs, err := getUserIPs(ctx)
-    if err != nil {
-        return nil, err
-    }
-
-    userIPs = append(userIPs, additionalUserIPS...)
-
-    digitalOceanProvider := &Provider{
-        logger:   logger.Named("digitalocean_provider"),
-        name:     providerName,
-        doClient: doClient,
-        petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)),
-
-        userIPs: userIPs,
-
-        sshClients: xsync.NewMapOf[string, *ssh.Client](),
-        sshKeyPair: sshKeyPair,
-    }
-
-    _, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag)
-    if err != nil {
-        return nil, err
-    }
-
-    firewall, err := digitalOceanProvider.createFirewall(ctx, userIPs)
-    if err != nil {
-        return nil, fmt.Errorf("failed to create firewall: %w", err)
-    }
-
-    digitalOceanProvider.firewallID = firewall.ID
-
-    //TODO(Zygimantass): TOCTOU issue
-    if key, _, err := doClient.Keys.GetByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
-        _, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey)
-        if err != nil {
-            if !strings.Contains(err.Error(), "422") {
-                return nil, err
-            }
-        }
-    }
-
-    return digitalOceanProvider, nil
-}
-
-func (p *Provider) Teardown(ctx context.Context) error {
-    p.logger.Info("tearing down DigitalOcean provider")
-
-    if err := p.teardownTasks(ctx); err != nil {
-        return err
-    }
-    if err := p.teardownFirewall(ctx); err != nil {
-        return err
-    }
-    if err := p.teardownSSHKey(ctx); err != nil {
-        return err
-    }
-    if err := p.teardownTag(ctx); err != nil {
-        return err
-    }
-
-    return nil
-}
-
-func (p *Provider) teardownTasks(ctx context.Context) error {
-    res, err := p.doClient.Droplets.DeleteByTag(ctx, p.petriTag)
-    if err != nil {
-        return err
-    }
-
-    if res.StatusCode > 299 || res.StatusCode < 200 {
-        return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-    }
-
-    return nil
-}
-
-func (p *Provider) teardownFirewall(ctx context.Context) error {
-    res, err := p.doClient.Firewalls.Delete(ctx, p.firewallID)
-    if err != nil {
-        return err
-    }
-
-    if res.StatusCode > 299 || res.StatusCode < 200 {
-        return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-    }
-
-    return nil
-}
-
-func (p *Provider) teardownSSHKey(ctx context.Context) error {
-    res, err := p.doClient.Keys.DeleteByFingerprint(ctx, p.sshKeyPair.Fingerprint)
-    if err != nil {
-        return err
-    }
-
-    if res.StatusCode > 299 || res.StatusCode < 200 {
-        return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-    }
-
-    return nil
-}
-
-func (p *Provider) teardownTag(ctx context.Context) error {
-    res, err := p.doClient.Tags.Delete(ctx, p.petriTag)
-    if err != nil {
-        return err
-    }
-
-    if res.StatusCode > 299 || res.StatusCode < 200 {
-        return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-    }
-
-    return nil
-}
diff --git a/core/provider/digitalocean/droplet.go b/core/provider/digitalocean/droplet.go
index 8166baa..d27934b 100644
--- a/core/provider/digitalocean/droplet.go
+++ b/core/provider/digitalocean/droplet.go
@@ -5,10 +5,12 @@ import (
 	"fmt"
+	"strconv"
 	"time"
 
+	"github.com/skip-mev/petri/core/v3/provider/clients"
+
 	"github.com/pkg/errors"
 
 	"github.com/digitalocean/godo"
-	dockerclient "github.com/docker/docker/client"
 	"go.uber.org/zap"
 
 	"golang.org/x/crypto/ssh"
@@ -20,11 +21,6 @@
-// nolint
-//
-//go:embed files/docker-cloud-init.yaml
-var dockerCloudInit string
-
 func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDefinition) (*godo.Droplet, error) {
 	if err := definition.ValidateBasic(); err != nil {
 		return nil, fmt.Errorf("failed to validate task definition: %w", err)
 	}
@@ -42,8 +38,9 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
 		return nil, fmt.Errorf("failed to parse image ID: %w", err)
 	}
 
+	state := p.GetState()
 	req := &godo.DropletCreateRequest{
-		Name:   fmt.Sprintf("%s-%s", p.petriTag, definition.Name),
+		Name:   fmt.Sprintf("%s-%s", state.PetriTag, definition.Name),
 		Region: doConfig["region"],
 		Size:   doConfig["size"],
 		Image: godo.DropletCreateImage{
@@ -51,25 +48,21 @@
 		},
 		SSHKeys: []godo.DropletCreateSSHKey{
 			{
-				Fingerprint: p.sshKeyPair.Fingerprint,
+				Fingerprint: state.SSHKeyPair.Fingerprint,
 			},
 		},
-		Tags: []string{p.petriTag},
+		Tags: []string{state.PetriTag},
 	}
 
-	droplet, res, err := p.doClient.Droplets.Create(ctx, req)
+	droplet, err := p.doClient.CreateDroplet(ctx, req)
 	if err != nil {
 		return nil, err
 	}
 
-	if res.StatusCode > 299 || res.StatusCode < 200 {
-		return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
 	start := time.Now()
 
 	err = util.WaitForCondition(ctx, time.Second*600, time.Millisecond*300, func() (bool, error) {
-		d, _, err := p.doClient.Droplets.Get(ctx, droplet.ID)
+		d, err := p.doClient.GetDroplet(ctx, droplet.ID)
 		if err != nil {
 			return false, err
 		}
@@ -80,16 +73,19 @@
 		ip, err := d.PublicIPv4()
 		if err != nil {
-			return false, nil
+			return false, err
 		}
 
-		dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(fmt.Sprintf("tcp://%s:2375", ip)))
-		if err != nil {
-			p.logger.Error("failed to create docker client", zap.Error(err))
-			return false, err
+		if p.dockerClients[ip] == nil {
+			dockerClient, err := clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
+			if err != nil {
+				p.logger.Error("failed to create docker client", zap.Error(err))
+				return false, err
+			}
+			p.dockerClients[ip] = dockerClient
 		}
 
-		_, err = dockerClient.Ping(ctx)
+		_, err = p.dockerClients[ip].Ping(ctx)
 		if err != nil {
 			return false, nil
 		}
@@ -104,83 +100,47 @@
 	end := time.Now()
 
-	p.logger.Info("droplet %s is ready after %s", zap.String("name", droplet.Name), zap.Duration("took", end.Sub(start)))
+	p.logger.Info("droplet is ready", zap.String("droplet_name", droplet.Name), zap.Duration("startup_time", end.Sub(start)))
 
 	return droplet, nil
 }
 
-func (p *Provider) deleteDroplet(ctx context.Context, name string) error {
-	droplet, err := p.getDroplet(ctx, name)
-
-	if err != nil {
-		return err
-	}
-
-	res, err := p.doClient.Droplets.Delete(ctx, droplet.ID)
+func (t *Task) deleteDroplet(ctx context.Context) error {
+	droplet, err := t.getDroplet(ctx)
 	if err != nil {
 		return err
 	}
 
-	if res.StatusCode > 299 || res.StatusCode < 200 {
-		return fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
-	return nil
-}
-
-func (p *Provider) getDroplet(ctx context.Context, name string) (*godo.Droplet, error) {
-	// TODO(Zygimantass): this change assumes that all Petri droplets are unique by name
-	// which should be technically true, but there might be edge cases where it's not.
-	droplets, res, err := p.doClient.Droplets.ListByName(ctx, name, nil)
-
-	if err != nil {
-		return nil, err
-	}
-
-	if res.StatusCode < 200 || res.StatusCode > 299 {
-		return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
-	if len(droplets) == 0 {
-		return nil, fmt.Errorf("could not find droplet")
-	}
-
-	return &droplets[0], nil
+	return t.doClient.DeleteDropletByID(ctx, droplet.ID)
 }
 
-func (p *Provider) getDropletDockerClient(ctx context.Context, taskName string) (*dockerclient.Client, error) {
-	ip, err := p.GetIP(ctx, taskName)
-	if err != nil {
-		return nil, err
-	}
-
-	dockerClient, err := dockerclient.NewClientWithOpts(dockerclient.WithHost(fmt.Sprintf("tcp://%s:2375", ip)))
+func (t *Task) getDroplet(ctx context.Context) (*godo.Droplet, error) {
+	dropletId, err := strconv.Atoi(t.GetState().ID)
 	if err != nil {
 		return nil, err
 	}
-
-	return dockerClient, nil
+	return t.doClient.GetDroplet(ctx, dropletId)
 }
 
-func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) {
-	if _, err := p.getDroplet(ctx, taskName); err != nil {
+func (t *Task) getDropletSSHClient(ctx context.Context, taskName string) (*ssh.Client, error) {
+	if _, err := t.getDroplet(ctx); err != nil {
 		return nil, fmt.Errorf("droplet %s does not exist", taskName)
 	}
 
-	if client, ok := p.sshClients.Load(taskName); ok {
-		status, _, err := client.SendRequest("ping", true, []byte("ping"))
+	if t.sshClient != nil {
+		status, _, err := t.sshClient.SendRequest("ping", true, []byte("ping"))
 		if err == nil && status {
-			return client, nil
+			return t.sshClient, nil
 		}
 	}
 
-	ip, err := p.GetIP(ctx, taskName)
+	ip, err := t.GetIP(ctx)
 	if err != nil {
 		return nil, err
 	}
 
-	parsedSSHKey, err := ssh.ParsePrivateKey([]byte(p.sshKeyPair.PrivateKey))
+	parsedSSHKey, err := ssh.ParsePrivateKey([]byte(t.GetState().SSHKeyPair.PrivateKey))
 	if err != nil {
 		return nil, err
 	}
@@ -202,7 +162,7 @@ func (p *Provider) getDropletSSHClient(ctx context.Context, taskName string) (*s
 		return nil, err
 	}
 
-	p.sshClients.Store(taskName, client)
-
+	t.sshClient = client
+
 	return client, nil
 }
diff --git a/core/provider/digitalocean/files/docker-cloud-init.yaml b/core/provider/digitalocean/files/docker-cloud-init.yaml
deleted file mode 100644
index eaf85b0..0000000
--- a/core/provider/digitalocean/files/docker-cloud-init.yaml
+++ /dev/null
@@ -1,37 +0,0 @@
-#cloud-config
-
-package_update: true
-package_upgrade: true
-
-# create the docker group
-groups:
-  - docker
-
-# Setup Docker daemon to listen on tcp and unix socket
-write_files:
-  - path: /etc/sysctl.d/enabled_ipv4_forwarding.conf
-    content: |
-      net.ipv4.conf.all.forwarding=1
-  - path: /etc/docker/daemon.json
-    content: |
-      {
-        "hosts": ["unix:///var/run/docker.sock", "tcp://0.0.0.0:2375"]
-      }
-    owner: root:root
-    permissions: '0644'
-  - path: /etc/systemd/system/docker.service.d/override.conf
-    content: |
-      [Service]
-      ExecStart=
-      ExecStart=/usr/bin/dockerd --containerd=/run/containerd/containerd.sock --tls=false
-    owner: root:root
-
-
-# Create a directory for Docker volumes
-runcmd:
-  - curl -fsSL https://get.docker.com | sh
-  - mkdir /docker_volumes
-  - chmod 755 /docker_volumes
-  - chown root:docker /docker_volumes
-  - systemctl daemon-reload
-  - systemctl restart docker
\ No newline at end of file
diff --git a/core/provider/digitalocean/firewall.go b/core/provider/digitalocean/firewall.go
index 82917b5..ec819ad 100644
--- a/core/provider/digitalocean/firewall.go
+++ b/core/provider/digitalocean/firewall.go
@@ -8,9 +8,10 @@ )
 
 func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*godo.Firewall, error) {
+	state := p.GetState()
 	req := &godo.FirewallRequest{
-		Name: fmt.Sprintf("%s-firewall", p.petriTag),
-		Tags: []string{p.petriTag},
+		Name: fmt.Sprintf("%s-firewall", state.PetriTag),
+		Tags: []string{state.PetriTag},
 		OutboundRules: []godo.OutboundRule{
 			{
 				Protocol: "tcp",
@@ -39,7 +40,7 @@
 				Protocol:  "tcp",
 				PortRange: "0",
 				Sources: &godo.Sources{
-					Tags:      []string{p.petriTag},
+					Tags:      []string{state.PetriTag},
 					Addresses: allowedIPs,
 				},
 			},
@@ -47,21 +48,12 @@
 				Protocol:  "udp",
 				PortRange: "0",
 				Sources: &godo.Sources{
-					Tags:      []string{p.petriTag},
+					Tags:      []string{state.PetriTag},
 					Addresses: allowedIPs,
 				},
 			},
 		},
 	}
 
-	firewall, res, err := p.doClient.Firewalls.Create(ctx, req)
-	if err != nil {
-		return nil, err
-	}
-
-	if res.StatusCode > 299 || res.StatusCode < 200 {
-		return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode)
-	}
-
-	return firewall, nil
+	return p.doClient.CreateFirewall(ctx, req)
 }
diff --git a/core/provider/digitalocean/mocks/do_client_mock.go b/core/provider/digitalocean/mocks/do_client_mock.go
new file mode 100644
index 0000000..302823c
--- /dev/null
+++ b/core/provider/digitalocean/mocks/do_client_mock.go
@@ -0,0 +1,300 @@
+// Code generated by mockery v2.47.0. DO NOT EDIT.
+
+package mocks
+
+import (
+    context "context"
+
+    godo "github.com/digitalocean/godo"
+
+    mock "github.com/stretchr/testify/mock"
+)
+
+// DoClient is an autogenerated mock type for the DoClient type
+type DoClient struct {
+    mock.Mock
+}
+
+// CreateDroplet provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateDroplet(ctx context.Context, req *godo.DropletCreateRequest) (*godo.Droplet, error) {
+    ret := _m.Called(ctx, req)
+
+    if len(ret) == 0 {
+        panic("no return value specified for CreateDroplet")
+    }
+
+    var r0 *godo.Droplet
+    var r1 error
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) (*godo.Droplet, error)); ok {
+        return rf(ctx, req)
+    }
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.DropletCreateRequest) *godo.Droplet); ok {
+        r0 = rf(ctx, req)
+    } else {
+        if ret.Get(0) != nil {
+            r0 = ret.Get(0).(*godo.Droplet)
+        }
+    }
+
+    if rf, ok := ret.Get(1).(func(context.Context, *godo.DropletCreateRequest) error); ok {
+        r1 = rf(ctx, req)
+    } else {
+        r1 = ret.Error(1)
+    }
+
+    return r0, r1
+}
+
+// CreateFirewall provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateFirewall(ctx context.Context, req *godo.FirewallRequest) (*godo.Firewall, error) {
+    ret := _m.Called(ctx, req)
+
+    if len(ret) == 0 {
+        panic("no return value specified for CreateFirewall")
+    }
+
+    var r0 *godo.Firewall
+    var r1 error
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) (*godo.Firewall, error)); ok {
+        return rf(ctx, req)
+    }
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.FirewallRequest) *godo.Firewall); ok {
+        r0 = rf(ctx, req)
+    } else {
+        if ret.Get(0) != nil {
+            r0 = ret.Get(0).(*godo.Firewall)
+        }
+    }
+
+    if rf, ok := ret.Get(1).(func(context.Context, *godo.FirewallRequest) error); ok {
+        r1 = rf(ctx, req)
+    } else {
+        r1 = ret.Error(1)
+    }
+
+    return r0, r1
+}
+
+// CreateKey provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateKey(ctx context.Context, req *godo.KeyCreateRequest) (*godo.Key, error) {
+    ret := _m.Called(ctx, req)
+
+    if len(ret) == 0 {
+        panic("no return value specified for CreateKey")
+    }
+
+    var r0 *godo.Key
+    var r1 error
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) (*godo.Key, error)); ok {
+        return rf(ctx, req)
+    }
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.KeyCreateRequest) *godo.Key); ok {
+        r0 = rf(ctx, req)
+    } else {
+        if ret.Get(0) != nil {
+            r0 = ret.Get(0).(*godo.Key)
+        }
+    }
+
+    if rf, ok := ret.Get(1).(func(context.Context, *godo.KeyCreateRequest) error); ok {
+        r1 = rf(ctx, req)
+    } else {
+        r1 = ret.Error(1)
+    }
+
+    return r0, r1
+}
+
+// CreateTag provides a mock function with given fields: ctx, req
+func (_m *DoClient) CreateTag(ctx context.Context, req *godo.TagCreateRequest) (*godo.Tag, error) {
+    ret := _m.Called(ctx, req)
+
+    if len(ret) == 0 {
+        panic("no return value specified for CreateTag")
+    }
+
+    var r0 *godo.Tag
+    var r1 error
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) (*godo.Tag, error)); ok {
+        return rf(ctx, req)
+    }
+    if rf, ok := ret.Get(0).(func(context.Context, *godo.TagCreateRequest) *godo.Tag); ok {
+        r0 = rf(ctx, req)
+    } else {
+        if ret.Get(0) != nil {
+            r0 = ret.Get(0).(*godo.Tag)
+        }
+    }
+
+    if rf, ok := ret.Get(1).(func(context.Context, *godo.TagCreateRequest) error); ok {
+        r1 = rf(ctx, req)
+    } else {
+        r1 = ret.Error(1)
+    }
+
+    return r0, r1
+}
+
+// DeleteDropletByID provides a mock function with given fields: ctx, id
+func (_m *DoClient) DeleteDropletByID(ctx context.Context, id int) error {
+    ret := _m.Called(ctx, id)
+
+    if len(ret) == 0 {
+        panic("no return value specified for DeleteDropletByID")
+    }
+
+    var r0 error
+    if rf, ok := ret.Get(0).(func(context.Context, int) error); ok {
+        r0 = rf(ctx, id)
+    } else {
+        r0 = ret.Error(0)
+    }
+
+    return r0
+}
+
+// DeleteDropletByTag provides a mock function with given fields: ctx, tag
+func (_m *DoClient) DeleteDropletByTag(ctx context.Context, tag string) error {
+    ret := _m.Called(ctx, tag)
+
+    if len(ret) == 0 {
+        panic("no return value specified for DeleteDropletByTag")
+    }
+
+    var r0 error
+    if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
+        r0 = rf(ctx, tag)
+    } else {
+        r0 = ret.Error(0)
+    }
+
+    return r0
+}
+
+// DeleteFirewall provides a mock function with given fields: ctx, firewallID
+func (_m *DoClient) DeleteFirewall(ctx context.Context, firewallID string) error {
+    ret := _m.Called(ctx, firewallID)
+
+    if len(ret) == 0 {
+        panic("no return value specified for DeleteFirewall")
+    }
+
+    var r0 error
+    if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
+        r0 = rf(ctx, firewallID)
+    } else {
+        r0 = ret.Error(0)
+    }
+
+    return r0
+}
+
+// DeleteKeyByFingerprint provides a mock function with given fields: ctx, fingerprint
+func (_m *DoClient) DeleteKeyByFingerprint(ctx context.Context, fingerprint string) error {
+    ret := _m.Called(ctx, fingerprint)
+
+    if len(ret) == 0 {
+        panic("no return value specified for DeleteKeyByFingerprint")
+    }
+
+    var r0 error
+    if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
+        r0 = rf(ctx, fingerprint)
+    } else {
+        r0 = ret.Error(0)
+    }
+
+    return r0
+}
+
+// DeleteTag provides a mock function with given fields: ctx, tag
+func (_m *DoClient) DeleteTag(ctx context.Context, tag string) error {
+    ret := _m.Called(ctx, tag)
+
+    if len(ret) == 0 {
+        panic("no return value specified for DeleteTag")
+    }
+
+    var r0 error
+    if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
+        r0 = rf(ctx, tag)
+    } else {
+        r0 = ret.Error(0)
+    }
+
+    return r0
+}
+
+// GetDroplet provides a mock function with given fields: ctx, dropletID
+func (_m *DoClient) GetDroplet(ctx context.Context, dropletID int) (*godo.Droplet, error) {
+    ret := _m.Called(ctx, dropletID)
+
+    if len(ret) == 0 {
+        panic("no return value specified for GetDroplet")
+    }
+
+    var r0 *godo.Droplet
+    var r1 error
+    if rf, ok := ret.Get(0).(func(context.Context, int) (*godo.Droplet, error)); ok {
+        return rf(ctx, dropletID)
+    }
+    if rf, ok := ret.Get(0).(func(context.Context, int) *godo.Droplet); ok {
+        r0 = rf(ctx, dropletID)
+    } else {
+        if ret.Get(0) != nil {
+            r0 = ret.Get(0).(*godo.Droplet)
+        }
+    }
+
+    if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
+        r1 = rf(ctx, dropletID)
+    } else {
+        r1 = ret.Error(1)
+    }
+
+    return r0, r1
+}
+
+// GetKeyByFingerprint provides a mock function with given fields: ctx, fingerprint
+func (_m *DoClient) GetKeyByFingerprint(ctx context.Context, fingerprint string) (*godo.Key, error) {
+    ret := _m.Called(ctx, fingerprint)
+
+    if len(ret) == 0 {
+        panic("no return value specified for GetKeyByFingerprint")
+    }
+
+    var r0 *godo.Key
+    var r1 error
+    if rf, ok := ret.Get(0).(func(context.Context, string) (*godo.Key, error)); ok {
+        return rf(ctx, fingerprint)
+    }
+    if rf, ok := ret.Get(0).(func(context.Context, string) *godo.Key); ok {
+        r0 = rf(ctx, fingerprint)
+    } else {
+        if ret.Get(0) != nil {
+            r0 = ret.Get(0).(*godo.Key)
+        }
+    }
+
+    if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+        r1 = rf(ctx, fingerprint)
+    } else {
+        r1 = ret.Error(1)
+    }
+
+    return r0, r1
+}
+
+// NewDoClient creates a new instance of DoClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewDoClient(t interface {
+    mock.TestingT
+    Cleanup(func())
+}) *DoClient {
+    mock := &DoClient{}
+    mock.Mock.Test(t)
+
+    t.Cleanup(func() { mock.AssertExpectations(t) })
+
+    return mock
+}
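Reviewer note: the generated mock satisfies DoClient, so provider logic can be exercised without touching the DigitalOcean API. A minimal hypothetical test sketch (the provider tests below do the same at full scale):

package digitalocean_test

import (
    "context"
    "testing"

    "github.com/digitalocean/godo"
    "github.com/stretchr/testify/require"

    "github.com/skip-mev/petri/core/v3/provider/digitalocean/mocks"
)

func TestGetDroplet_Mocked(t *testing.T) {
    ctx := context.Background()
    mockDO := mocks.NewDoClient(t)

    // Stub a single droplet lookup; expectations are asserted via t.Cleanup.
    mockDO.On("GetDroplet", ctx, 123).Return(&godo.Droplet{ID: 123}, nil)

    droplet, err := mockDO.GetDroplet(ctx, 123)
    require.NoError(t, err)
    require.Equal(t, 123, droplet.ID)
}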
diff --git a/core/provider/digitalocean/provider.go b/core/provider/digitalocean/provider.go
new file mode 100644
index 0000000..2e0b696
--- /dev/null
+++ b/core/provider/digitalocean/provider.go
@@ -0,0 +1,413 @@
+package digitalocean
+
+import (
+    "context"
+    "encoding/json"
+    "errors"
+    "fmt"
+    "strconv"
+    "strings"
+    "sync"
+    "time"
+
+    "github.com/docker/docker/api/types/image"
+    "github.com/skip-mev/petri/core/v3/provider/clients"
+
+    "github.com/digitalocean/godo"
+    "github.com/docker/docker/api/types/container"
+    "github.com/docker/docker/api/types/mount"
+    "github.com/docker/docker/client"
+
+    "go.uber.org/zap"
+
+    "github.com/skip-mev/petri/core/v3/provider"
+    "github.com/skip-mev/petri/core/v3/util"
+)
+
+var _ provider.ProviderI = (*Provider)(nil)
+
+const (
+    providerLabelName = "petri-provider"
+    dockerPort        = "2375"
+)
+
+type ProviderState struct {
+    TaskStates map[string]*TaskState `json:"task_states"` // map of task ids to the corresponding task state
+    Name       string                `json:"name"`
+    PetriTag   string                `json:"petri_tag"`
+    UserIPs    []string              `json:"user_ips"`
+    SSHKeyPair *SSHKeyPair           `json:"ssh_key_pair"`
+    FirewallID string                `json:"firewall_id"`
+}
+
+type Provider struct {
+    state   *ProviderState
+    stateMu sync.Mutex
+
+    logger        *zap.Logger
+    doClient      DoClient
+    dockerClients map[string]clients.DockerClient // map of droplet ip address to docker clients
+}
+
+// NewProvider creates a provider that implements the Provider interface for DigitalOcean.
+// Token is the DigitalOcean API token.
+func NewProvider(ctx context.Context, logger *zap.Logger, providerName string, token string, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) {
+    doClient := NewGodoClient(token)
+    return NewProviderWithClient(ctx, logger, providerName, doClient, nil, additionalUserIPS, sshKeyPair)
+}
+
+// NewProviderWithClient creates a provider with custom digitalocean/docker client implementations.
+// This is primarily used for testing.
+func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName string, doClient DoClient, dockerClients map[string]clients.DockerClient, additionalUserIPS []string, sshKeyPair *SSHKeyPair) (*Provider, error) {
+    var err error
+    if sshKeyPair == nil {
+        sshKeyPair, err = MakeSSHKeyPair()
+        if err != nil {
+            return nil, err
+        }
+    }
+
+    userIPs, err := getUserIPs(ctx)
+    if err != nil {
+        return nil, err
+    }
+
+    userIPs = append(userIPs, additionalUserIPS...)
+
+    if dockerClients == nil {
+        dockerClients = make(map[string]clients.DockerClient)
+    }
+
+    petriTag := fmt.Sprintf("petri-droplet-%s", util.RandomString(5))
+    digitalOceanProvider := &Provider{
+        logger:        logger.Named("digitalocean_provider"),
+        doClient:      doClient,
+        dockerClients: dockerClients,
+        state: &ProviderState{
+            TaskStates: make(map[string]*TaskState),
+            UserIPs:    userIPs,
+            Name:       providerName,
+            SSHKeyPair: sshKeyPair,
+            PetriTag:   petriTag,
+        },
+    }
+
+    _, err = digitalOceanProvider.createTag(ctx, petriTag)
+    if err != nil {
+        return nil, err
+    }
+
+    firewall, err := digitalOceanProvider.createFirewall(ctx, userIPs)
+    if err != nil {
+        return nil, fmt.Errorf("failed to create firewall: %w", err)
+    }
+
+    digitalOceanProvider.state.FirewallID = firewall.ID
+
+    // TODO(Zygimantass): TOCTOU issue
+    if key, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
+        _, err = digitalOceanProvider.createSSHKey(ctx, sshKeyPair.PublicKey)
+        if err != nil {
+            if !strings.Contains(err.Error(), "422") {
+                return nil, err
+            }
+        }
+    }
+
+    return digitalOceanProvider, nil
+}
+
+func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefinition) (provider.TaskI, error) {
+    if err := definition.ValidateBasic(); err != nil {
+        return nil, fmt.Errorf("failed to validate task definition: %w", err)
+    }
+
+    if definition.ProviderSpecificConfig == nil {
+        return nil, fmt.Errorf("digitalocean specific config is nil for %s", definition.Name)
+    }
+
+    doConfig, ok := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig)
+    if !ok {
+        return nil, fmt.Errorf("invalid provider specific config type for %s", definition.Name)
+    }
+
+    if err := doConfig.ValidateBasic(); err != nil {
+        return nil, fmt.Errorf("invalid digitalocean specific config: %w", err)
+    }
+
+    p.logger.Info("creating droplet", zap.String("name", definition.Name))
+
+    droplet, err := p.CreateDroplet(ctx, definition)
+    if err != nil {
+        return nil, err
+    }
+
+    ip, err := droplet.PublicIPv4()
+    if err != nil {
+        return nil, err
+    }
+
+    p.logger.Info("droplet created", zap.String("name", droplet.Name), zap.String("ip", ip))
+
+    dockerClient := p.dockerClients[ip]
+    if dockerClient == nil {
+        dockerClient, err = clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
+        if err != nil {
+            return nil, err
+        }
+        p.dockerClients[ip] = dockerClient
+    }
+
+    _, _, err = dockerClient.ImageInspectWithRaw(ctx, definition.Image.Image)
+    if err != nil {
+        p.logger.Info("image not found, pulling", zap.String("image", definition.Image.Image))
+        if err = dockerClient.ImagePull(ctx, p.logger, definition.Image.Image, image.PullOptions{}); err != nil {
+            return nil, err
+        }
+    }
+
+    state := p.GetState()
+
+    err = util.WaitForCondition(ctx, 30*time.Second, 1*time.Second, func() (bool, error) {
+        _, err := dockerClient.ContainerCreate(ctx, &container.Config{
+            Image:      definition.Image.Image,
+            Entrypoint: definition.Entrypoint,
+            Cmd:        definition.Command,
+            Tty:        false,
+            Hostname:   definition.Name,
+            Labels: map[string]string{
+                providerLabelName: state.Name,
+            },
+            Env: convertEnvMapToList(definition.Environment),
+        }, &container.HostConfig{
+            Mounts: []mount.Mount{
+                {
+                    Type:   mount.TypeBind,
+                    Source: "/docker_volumes",
+                    Target: definition.DataDir,
+                },
+            },
+            NetworkMode: container.NetworkMode("host"),
+        }, nil, nil, definition.ContainerName)
+
+        if err != nil {
+            if client.IsErrConnectionFailed(err) {
+                p.logger.Warn("connection failed while creating container, will retry", zap.Error(err))
+                return false, nil
+            }
+            return false, err
+        }
+
+        return true, nil
+    })
+
+    if err != nil {
+        return nil, fmt.Errorf("failed to create container after retries: %w", err)
+    }
+
+    taskState := &TaskState{
+        ID:           strconv.Itoa(droplet.ID),
+        Name:         definition.Name,
+        Definition:   definition,
+        Status:       provider.TASK_STOPPED,
+        ProviderName: state.Name,
+        SSHKeyPair:   state.SSHKeyPair,
+    }
+
+    p.stateMu.Lock()
+    defer p.stateMu.Unlock()
+
+    p.state.TaskStates[taskState.ID] = taskState
+
+    return &Task{
+        state:        taskState,
+        removeTask:   p.removeTask,
+        logger:       p.logger.With(zap.String("task", definition.Name)),
+        doClient:     p.doClient,
+        dockerClient: dockerClient,
+    }, nil
+}
+
+func (p *Provider) SerializeProvider(context.Context) ([]byte, error) {
+    p.stateMu.Lock()
+    defer p.stateMu.Unlock()
+
+    return json.Marshal(p.state)
+}
+
+func RestoreProvider(ctx context.Context, token string, state []byte, doClient DoClient, dockerClients map[string]clients.DockerClient) (*Provider, error) {
+    if doClient == nil && token == "" {
+        return nil, errors.New("a valid token or digital ocean client must be passed when restoring the provider")
+    }
+
+    var providerState ProviderState
+
+    err := json.Unmarshal(state, &providerState)
+    if err != nil {
+        return nil, err
+    }
+
+    if dockerClients == nil {
+        dockerClients = make(map[string]clients.DockerClient)
+    }
+
+    digitalOceanProvider := &Provider{
+        state:         &providerState,
+        dockerClients: dockerClients,
+        logger:        zap.L().Named("digitalocean_provider"),
+    }
+
+    if doClient != nil {
+        digitalOceanProvider.doClient = doClient
+    } else {
+        digitalOceanProvider.doClient = NewGodoClient(token)
+    }
+
+    for _, taskState := range providerState.TaskStates {
+        id, err := strconv.Atoi(taskState.ID)
+        if err != nil {
+            return nil, err
+        }
+
+        droplet, err := digitalOceanProvider.doClient.GetDroplet(ctx, id)
+        if err != nil {
+            return nil, fmt.Errorf("failed to get droplet for task state: %w", err)
+        }
+
+        ip, err := droplet.PublicIPv4()
+        if err != nil {
+            return nil, fmt.Errorf("failed to get droplet IP: %w", err)
+        }
+
+        if digitalOceanProvider.dockerClients[ip] == nil {
+            dockerClient, err := clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort))
+            if err != nil {
+                return nil, fmt.Errorf("failed to create docker client: %w", err)
+            }
+            digitalOceanProvider.dockerClients[ip] = dockerClient
+        }
+    }
+
+    return digitalOceanProvider, nil
+}
+
+func (p *Provider) SerializeTask(ctx context.Context, task provider.TaskI) ([]byte, error) {
+    doTask, ok := task.(*Task)
return nil, fmt.Errorf("task is not a Docker task") + } + + doTask := task.(*Task) + + bz, err := json.Marshal(doTask.state) + + if err != nil { + return nil, err + } + + return bz, nil +} + +func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.TaskI, error) { + var taskState TaskState + + err := json.Unmarshal(bz, &taskState) + if err != nil { + return nil, err + } + + task := &Task{ + state: &taskState, + removeTask: p.removeTask, + } + + if err := p.initializeDeserializedTask(ctx, task); err != nil { + return nil, err + } + + return task, nil +} + +func (p *Provider) initializeDeserializedTask(ctx context.Context, task *Task) error { + taskState := task.GetState() + task.logger = p.logger.With(zap.String("task", taskState.Name)) + task.doClient = p.doClient + + droplet, err := task.getDroplet(ctx) + if err != nil { + return fmt.Errorf("failed to get droplet for task initialization: %w", err) + } + + ip, err := droplet.PublicIPv4() + if err != nil { + return fmt.Errorf("failed to get droplet IP: %w", err) + } + + if p.dockerClients[ip] == nil { + dockerClient, err := clients.NewDockerClient(fmt.Sprintf("tcp://%s:%s", ip, dockerPort)) + if err != nil { + return fmt.Errorf("failed to create docker client: %w", err) + } + p.dockerClients[ip] = dockerClient + } + + task.dockerClient = p.dockerClients[ip] + return nil +} + +func (p *Provider) Teardown(ctx context.Context) error { + p.logger.Info("tearing down DigitalOcean provider") + + if err := p.teardownTasks(ctx); err != nil { + return err + } + if err := p.teardownFirewall(ctx); err != nil { + return err + } + if err := p.teardownSSHKey(ctx); err != nil { + return err + } + if err := p.teardownTag(ctx); err != nil { + return err + } + return nil +} + +func (p *Provider) teardownTasks(ctx context.Context) error { + return p.doClient.DeleteDropletByTag(ctx, p.GetState().PetriTag) +} + +func (p *Provider) teardownFirewall(ctx context.Context) error { + return p.doClient.DeleteFirewall(ctx, p.GetState().FirewallID) +} + +func (p *Provider) teardownSSHKey(ctx context.Context) error { + return p.doClient.DeleteKeyByFingerprint(ctx, p.GetState().SSHKeyPair.Fingerprint) +} + +func (p *Provider) teardownTag(ctx context.Context) error { + return p.doClient.DeleteTag(ctx, p.GetState().PetriTag) +} + +func (p *Provider) removeTask(_ context.Context, taskID string) error { + p.stateMu.Lock() + defer p.stateMu.Unlock() + + delete(p.state.TaskStates, taskID) + + return nil +} + +func (p *Provider) createTag(ctx context.Context, tagName string) (*godo.Tag, error) { + req := &godo.TagCreateRequest{ + Name: tagName, + } + + return p.doClient.CreateTag(ctx, req) +} + +func (p *Provider) GetState() ProviderState { + p.stateMu.Lock() + defer p.stateMu.Unlock() + return *p.state +} diff --git a/core/provider/digitalocean/provider_test.go b/core/provider/digitalocean/provider_test.go new file mode 100644 index 0000000..baa8815 --- /dev/null +++ b/core/provider/digitalocean/provider_test.go @@ -0,0 +1,564 @@ +package digitalocean + +import ( + "context" + "fmt" + "sync" + "testing" + "time" + + "github.com/skip-mev/petri/core/v3/provider/clients" + + "github.com/skip-mev/petri/core/v3/provider/digitalocean/mocks" + + "github.com/digitalocean/godo" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/image" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + specs "github.com/opencontainers/image-spec/specs-go/v1" + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" + + "github.com/skip-mev/petri/core/v3/provider" + dockerMocks "github.com/skip-mev/petri/core/v3/provider/mocks" + "github.com/skip-mev/petri/core/v3/util" +) + +func setupTestProvider(t *testing.T, ctx context.Context) (*Provider, *mocks.DoClient, *dockerMocks.DockerClient) { + logger := zap.NewExample() + mockDO := mocks.NewDoClient(t) + mockDocker := dockerMocks.NewDockerClient(t) + + mockDocker.On("Ping", ctx).Return(types.Ping{}, nil) + mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) + mockDocker.On("ImagePull", ctx, mock.AnythingOfType("*zap.Logger"), "ubuntu:latest", image.PullOptions{}).Return(nil) + mockDocker.On("ContainerCreate", ctx, &container.Config{ + Image: "ubuntu:latest", + Entrypoint: []string{"/bin/bash"}, + Cmd: []string{"-c", "echo hello"}, + Env: []string{"TEST=value"}, + Hostname: "test-task", + Labels: map[string]string{ + providerLabelName: "test-provider", + }, + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: "/data", + }, + }, + NetworkMode: container.NetworkMode("host"), + }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), "test-container").Return(container.CreateResponse{ID: "test-container"}, nil) + mockDocker.On("Close").Return(nil) + + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) + mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) + + mockDockerClients := map[string]clients.DockerClient{ + "10.0.0.1": mockDocker, + } + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + droplet := &godo.Droplet{ + ID: 123, + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + Status: "active", + } + + var callCount int + mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(func(ctx context.Context, id int) *godo.Droplet { + if callCount == 0 { + callCount++ + return &godo.Droplet{ + ID: id, + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + Status: "new", + } + } + return droplet + }, func(ctx context.Context, id int) error { + return nil + }).Maybe() + + mockDO.On("DeleteDropletByID", ctx, droplet.ID).Return(nil).Maybe() + + return p, mockDO, mockDocker +} + +func TestCreateTask_ValidTask(t *testing.T) { + ctx := context.Background() + p, _, _ := setupTestProvider(t, ctx) + + taskDef := provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"}, + Entrypoint: []string{"/bin/bash"}, + Command: []string{"-c", "echo hello"}, + Environment: map[string]string{"TEST": "value"}, + DataDir: "/data", + ContainerName: "test-container", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", + }, + } + + task, err := p.CreateTask(ctx, taskDef) + assert.NoError(t, err) + assert.Equal(t, task.GetDefinition(), taskDef) + assert.NotNil(t, task) + + err 
+    assert.NoError(t, err)
+}
+
+func setupValidationTestProvider(t *testing.T, ctx context.Context) *Provider {
+    logger := zap.NewExample()
+    mockDO := mocks.NewDoClient(t)
+    mockDocker := dockerMocks.NewDockerClient(t)
+
+    mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil)
+    mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil)
+    mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil)
+    mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil)
+
+    mockDockerClients := map[string]clients.DockerClient{
+        "10.0.0.1": mockDocker,
+    }
+
+    p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil)
+    require.NoError(t, err)
+
+    return p
+}
+
+func TestCreateTask_MissingProviderConfig(t *testing.T) {
+    ctx := context.Background()
+    p := setupValidationTestProvider(t, ctx)
+
+    taskDef := provider.TaskDefinition{
+        Name:                   "test-task",
+        Image:                  provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"},
+        ProviderSpecificConfig: nil,
+    }
+
+    task, err := p.CreateTask(ctx, taskDef)
+    assert.Error(t, err)
+    assert.Nil(t, task)
+}
+
+func TestCreateTask_MissingRegion(t *testing.T) {
+    ctx := context.Background()
+    p := setupValidationTestProvider(t, ctx)
+
+    taskDef := provider.TaskDefinition{
+        Name:  "test-task",
+        Image: provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"},
+        ProviderSpecificConfig: DigitalOceanTaskConfig{
+            "size":     "s-1vcpu-1gb",
+            "image_id": "123456",
+        },
+    }
+
+    task, err := p.CreateTask(ctx, taskDef)
+    assert.Error(t, err)
+    assert.Nil(t, task)
+}
+
+func TestSerializeAndRestoreTask(t *testing.T) {
+    ctx := context.Background()
+    p, mockDO, mockDocker := setupTestProvider(t, ctx)
+
+    taskDef := provider.TaskDefinition{
+        Name:          "test-task",
+        Image:         provider.ImageDefinition{Image: "ubuntu:latest", UID: "1000", GID: "1000"},
+        Entrypoint:    []string{"/bin/bash"},
+        Command:       []string{"-c", "echo hello"},
+        Environment:   map[string]string{"TEST": "value"},
+        DataDir:       "/data",
+        ContainerName: "test-container",
+        ProviderSpecificConfig: DigitalOceanTaskConfig{
+            "size":     "s-1vcpu-1gb",
+            "region":   "nyc1",
+            "image_id": "123456",
+        },
+    }
+
+    task, err := p.CreateTask(ctx, taskDef)
+    require.NoError(t, err)
+
+    taskData, err := p.SerializeTask(ctx, task)
+    assert.NoError(t, err)
+    assert.NotNil(t, taskData)
+
+    mockDO.On("GetDroplet", ctx, 123).Return(&godo.Droplet{
+        ID:     123,
+        Name:   "test-droplet",
+        Status: "active",
+        Networks: &godo.Networks{
+            V4: []godo.NetworkV4{
+                {
+                    IPAddress: "10.0.0.1",
+                    Type:      "public",
+                },
+            },
+        },
+    }, nil)
+
+    deserializedTask, err := p.DeserializeTask(ctx, taskData)
+    assert.NoError(t, err)
+    assert.NotNil(t, deserializedTask)
+
+    t1 := task.(*Task)
+    t2 := deserializedTask.(*Task)
+    t1State := t1.GetState()
+    t2State := t2.GetState()
+
+    if configMap, ok := t2State.Definition.ProviderSpecificConfig.(map[string]interface{}); ok {
+        doConfig := make(DigitalOceanTaskConfig)
+        for k, v := range configMap {
+            doConfig[k] = v.(string)
+        }
+        t2State.Definition.ProviderSpecificConfig = doConfig
+    }
+
+    assert.Equal(t, t1State, t2State)
+    assert.NotNil(t, t2.logger)
+    assert.NotNil(t, t2State.SSHKeyPair)
+    assert.NotNil(t, t2.doClient)
+    assert.NotNil(t, t2.dockerClient)
+
+    err = t2.Destroy(ctx)
+    assert.NoError(t, err)
+
+    mockDO.AssertExpectations(t)
+    mockDocker.AssertExpectations(t)
+}
+
+func TestConcurrentTaskCreationAndCleanup(t *testing.T) {
+    ctx := context.Background()
+    ctx, cancel := context.WithTimeout(ctx, 5*time.Minute)
+    defer cancel()
+
+    logger, _ := zap.NewDevelopment()
+    mockDockerClients := make(map[string]clients.DockerClient)
+    mockDO := mocks.NewDoClient(t)
+
+    for i := 0; i < 10; i++ {
+        ip := fmt.Sprintf("10.0.0.%d", i+1)
+        mockDocker := dockerMocks.NewDockerClient(t)
+        mockDockerClients[ip] = mockDocker
+
+        mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Once()
+        mockDocker.On("ImageInspectWithRaw", ctx, "nginx:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")).Once()
+        mockDocker.On("ImagePull", ctx, mock.AnythingOfType("*zap.Logger"), "nginx:latest", image.PullOptions{}).Return(nil).Once()
+        mockDocker.On("ContainerCreate", ctx, mock.MatchedBy(func(config *container.Config) bool {
+            return config.Image == "nginx:latest"
+        }), mock.Anything, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.AnythingOfType("string")).Return(container.CreateResponse{ID: fmt.Sprintf("container-%d", i)}, nil).Once()
+        mockDocker.On("ContainerStart", ctx, fmt.Sprintf("container-%d", i), container.StartOptions{}).Return(nil).Once()
+        mockDocker.On("ContainerList", ctx, container.ListOptions{
+            Limit: 1,
+        }).Return([]types.Container{
+            {
+                ID:    fmt.Sprintf("container-%d", i),
+                State: "running",
+            },
+        }, nil).Times(3)
+        mockDocker.On("ContainerInspect", ctx, fmt.Sprintf("container-%d", i)).Return(types.ContainerJSON{
+            ContainerJSONBase: &types.ContainerJSONBase{
+                State: &types.ContainerState{
+                    Status: "running",
+                },
+            },
+        }, nil).Maybe()
+        mockDocker.On("Close").Return(nil).Once()
+    }
+
+    mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "test-tag"}, nil)
+    mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil)
+    mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil)
+ Return(nil, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) + + p, err := NewProviderWithClient(ctx, logger, "test-provider", mockDO, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + numTasks := 10 + var wg sync.WaitGroup + errors := make(chan error, numTasks) + tasks := make(chan *Task, numTasks) + taskMutex := sync.Mutex{} + dropletIDs := make(map[string]bool) + ipAddresses := make(map[string]bool) + + for i := 0; i < numTasks; i++ { + dropletID := 1000 + i + droplet := &godo.Droplet{ + ID: dropletID, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: fmt.Sprintf("10.0.0.%d", i+1), + }, + }, + }, + } + + mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, nil).Once() + // we can't predict exactly how many times GetDroplet will be called, since the provider polls while waiting for droplet creation + mockDO.On("GetDroplet", ctx, dropletID).Return(droplet, nil).Maybe() + mockDO.On("DeleteDropletByID", ctx, dropletID).Return(nil).Once() + } + + mockDO.On("DeleteDropletByTag", ctx, mock.AnythingOfType("string")).Return(nil).Once() + mockDO.On("DeleteFirewall", ctx, mock.AnythingOfType("string")).Return(nil).Once() + mockDO.On("DeleteKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil).Once() + mockDO.On("DeleteTag", ctx, mock.AnythingOfType("string")).Return(nil).Once() + + for i := 0; i < numTasks; i++ { + wg.Add(1) + go func(index int) { + defer wg.Done() + + task, err := p.CreateTask(ctx, provider.TaskDefinition{ + Name: fmt.Sprintf("test-task-%d", index), + ContainerName: fmt.Sprintf("test-container-%d", index), + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + Ports: []string{"80"}, + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456789", + }, + }) + + if err != nil { + errors <- fmt.Errorf("task creation error: %v", err) + return + } + + if err := task.Start(ctx); err != nil { + errors <- fmt.Errorf("task start error: %v", err) + return + } + + taskMutex.Lock() + doTask := task.(*Task) + state := doTask.GetState() + + if dropletIDs[state.ID] { + errors <- fmt.Errorf("duplicate droplet ID found: %s", state.ID) + } + dropletIDs[state.ID] = true + + ip, err := task.GetIP(ctx) + if err == nil { + if ipAddresses[ip] { + errors <- fmt.Errorf("duplicate IP address found: %s", ip) + } + ipAddresses[ip] = true + } + + tasks <- doTask + taskMutex.Unlock() + }(i) + } + + wg.Wait() + close(errors) + + for err := range errors { + require.NoError(t, err) + } + + require.Equal(t, numTasks, len(p.GetState().TaskStates), "Provider state should contain all tasks") + + var tasksToCleanup []*Task + close(tasks) + for task := range tasks { + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status, "All tasks should be in running state") + + state := task.GetState() + require.NotEmpty(t, state.ID, "Task should have a droplet ID") + require.NotEmpty(t, state.Name, "Task should have a name") + require.NotEmpty(t, state.Definition.ContainerName, "Task should have a container name") + tasksToCleanup = append(tasksToCleanup, task) + } + + var cleanupWg sync.WaitGroup + cleanupErrors := make(chan error, numTasks) + + for _, task := range tasksToCleanup { + cleanupWg.Add(1) + go func(t *Task) { + defer cleanupWg.Done() + if err := t.Destroy(ctx); err != nil { + cleanupErrors <- fmt.Errorf("cleanup error: %v", err) + return + }
+ }(task) + } + + cleanupWg.Wait() + close(cleanupErrors) + + for err := range cleanupErrors { + require.NoError(t, err) + } + + err = util.WaitForCondition(ctx, 30*time.Second, 100*time.Millisecond, func() (bool, error) { + return len(p.GetState().TaskStates) == 0, nil + }) + require.NoError(t, err, "Provider state should be empty after cleanup") + + err = p.Teardown(ctx) + require.NoError(t, err) + + mockDO.AssertExpectations(t) + for _, client := range mockDockerClients { + client.(*dockerMocks.DockerClient).AssertExpectations(t) + } +} + +func TestProviderSerialization(t *testing.T) { + ctx := context.Background() + mockDO := mocks.NewDoClient(t) + mockDocker := dockerMocks.NewDockerClient(t) + + mockDO.On("CreateTag", ctx, mock.Anything).Return(&godo.Tag{Name: "petri-droplet-test"}, nil) + mockDO.On("CreateFirewall", ctx, mock.Anything).Return(&godo.Firewall{ID: "test-firewall"}, nil) + mockDO.On("GetKeyByFingerprint", ctx, mock.AnythingOfType("string")).Return(nil, nil) + mockDO.On("CreateKey", ctx, mock.Anything).Return(&godo.Key{}, nil) + + mockDockerClients := map[string]clients.DockerClient{ + "10.0.0.1": mockDocker, + } + + p1, err := NewProviderWithClient(ctx, zap.NewExample(), "test-provider", mockDO, mockDockerClients, []string{}, nil) + require.NoError(t, err) + + droplet := &godo.Droplet{ + ID: 123, + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + IPAddress: "10.0.0.1", + Type: "public", + }, + }, + }, + Status: "active", + } + + mockDO.On("CreateDroplet", ctx, mock.Anything).Return(droplet, nil) + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil).Maybe() + + mockDocker.On("Ping", ctx).Return(types.Ping{}, nil).Once() + mockDocker.On("ImageInspectWithRaw", ctx, "ubuntu:latest").Return(types.ImageInspect{}, []byte{}, fmt.Errorf("image not found")) + mockDocker.On("ImagePull", ctx, mock.AnythingOfType("*zap.Logger"), "ubuntu:latest", image.PullOptions{}).Return(nil) + mockDocker.On("ContainerCreate", ctx, mock.MatchedBy(func(config *container.Config) bool { + return config.Image == "ubuntu:latest" && + config.Hostname == "test-task" && + len(config.Labels) > 0 && + config.Labels[providerLabelName] == "test-provider" + }), mock.MatchedBy(func(hostConfig *container.HostConfig) bool { + return len(hostConfig.Mounts) == 1 && + hostConfig.Mounts[0].Target == "/data" && + hostConfig.NetworkMode == "host" + }), mock.Anything, mock.Anything, mock.Anything).Return(container.CreateResponse{ID: "test-container"}, nil) + + _, err = p1.CreateTask(ctx, provider.TaskDefinition{ + Name: "test-task", + ContainerName: "test-container", + Image: provider.ImageDefinition{ + Image: "ubuntu:latest", + UID: "1000", + GID: "1000", + }, + DataDir: "/data", + ProviderSpecificConfig: DigitalOceanTaskConfig{ + "size": "s-1vcpu-1gb", + "region": "nyc1", + "image_id": "123456", + }, + }) + require.NoError(t, err) + + state1 := p1.GetState() + serialized, err := p1.SerializeProvider(ctx) + require.NoError(t, err) + + mockDO2 := mocks.NewDoClient(t) + mockDO2.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil).Maybe() + + mockDocker2 := dockerMocks.NewDockerClient(t) + mockDocker2.On("Ping", ctx).Return(types.Ping{}, nil).Maybe() + + mockDockerClients2 := map[string]clients.DockerClient{ + "10.0.0.1": mockDocker2, + } + + p2, err := RestoreProvider(ctx, "test-token", serialized, mockDO2, mockDockerClients2) + require.NoError(t, err) + + state2 := p2.GetState() + assert.Equal(t, state1.Name, state2.Name) + 
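// the restored provider should carry identical state across the serialize/restore round trip +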
assert.Equal(t, state1.PetriTag, state2.PetriTag) + assert.Equal(t, state1.FirewallID, state2.FirewallID) + assert.Equal(t, len(state1.TaskStates), len(state2.TaskStates)) + + for id, task1 := range state1.TaskStates { + task2, exists := state2.TaskStates[id] + assert.True(t, exists) + assert.Equal(t, task1.Name, task2.Name) + assert.Equal(t, task1.Status, task2.Status) + + if configMap, ok := task2.Definition.ProviderSpecificConfig.(map[string]interface{}); ok { + doConfig := make(DigitalOceanTaskConfig) + for k, v := range configMap { + doConfig[k] = v.(string) + } + task2.Definition.ProviderSpecificConfig = doConfig + } + assert.Equal(t, task1.Definition, task2.Definition) + } +} diff --git a/core/provider/digitalocean/ssh.go b/core/provider/digitalocean/ssh.go index dbcdc0a..57af65f 100644 --- a/core/provider/digitalocean/ssh.go +++ b/core/provider/digitalocean/ssh.go @@ -96,16 +96,6 @@ func getUserIPs(ctx context.Context) (ips []string, err error) { } func (p *Provider) createSSHKey(ctx context.Context, pubKey string) (*godo.Key, error) { - req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.petriTag)} - - key, res, err := p.doClient.Keys.Create(ctx, req) - if err != nil { - return nil, err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return key, nil + req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.GetState().PetriTag)} + return p.doClient.CreateKey(ctx, req) } diff --git a/core/provider/digitalocean/tag.go b/core/provider/digitalocean/tag.go deleted file mode 100644 index db1e9fa..0000000 --- a/core/provider/digitalocean/tag.go +++ /dev/null @@ -1,25 +0,0 @@ -package digitalocean - -import ( - "context" - "fmt" - - "github.com/digitalocean/godo" -) - -func (p *Provider) createTag(ctx context.Context, tagName string) (*godo.Tag, error) { - req := &godo.TagCreateRequest{ - Name: tagName, - } - - tag, res, err := p.doClient.Tags.Create(ctx, req) - if err != nil { - return nil, err - } - - if res.StatusCode > 299 || res.StatusCode < 200 { - return nil, fmt.Errorf("unexpected status code: %d", res.StatusCode) - } - - return tag, nil -} diff --git a/core/provider/digitalocean/task.go b/core/provider/digitalocean/task.go index ac7eb39..b3edeb6 100644 --- a/core/provider/digitalocean/task.go +++ b/core/provider/digitalocean/task.go @@ -4,13 +4,16 @@ import ( "bytes" "context" "fmt" - "io" "net" "path" + "sync" "time" + "github.com/skip-mev/petri/core/v3/provider/clients" + + "golang.org/x/crypto/ssh" + "github.com/docker/docker/api/types/container" - "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/mount" dockerclient "github.com/docker/docker/client" "github.com/docker/docker/pkg/stdcopy" @@ -23,88 +26,30 @@ import ( "github.com/skip-mev/petri/core/v3/util" ) -func (p *Provider) CreateTask(ctx context.Context, logger *zap.Logger, definition provider.TaskDefinition) (string, error) { - if err := definition.ValidateBasic(); err != nil { - return "", fmt.Errorf("failed to validate task definition: %w", err) - } - - if definition.ProviderSpecificConfig == nil { - return "", fmt.Errorf("digitalocean specific config is nil for %s", definition.Name) - } - - doConfig := definition.ProviderSpecificConfig.(DigitalOceanTaskConfig) - - if err := doConfig.ValidateBasic(); err != nil { - return "", fmt.Errorf("could not cast digitalocean specific config: %w", err) - } - - logger = logger.Named("digitalocean_provider") - - 
logger.Info("creating droplet", zap.String("name", definition.Name)) - - droplet, err := p.CreateDroplet(ctx, definition) - if err != nil { - return "", err - } - - ip, err := p.GetIP(ctx, droplet.Name) - if err != nil { - return "", err - } - - logger.Info("droplet created", zap.String("name", droplet.Name), zap.String("ip", ip)) - - dockerClient, err := p.getDropletDockerClient(ctx, droplet.Name) - defer dockerClient.Close() // nolint - - if err != nil { - return "", err - } - - _, _, err = dockerClient.ImageInspectWithRaw(ctx, definition.Image.Image) - if err != nil { - logger.Info("image not found, pulling", zap.String("image", definition.Image.Image)) - if err := p.pullImage(ctx, dockerClient, definition.Image.Image); err != nil { - return "", err - } - } +type TaskState struct { + ID string `json:"id"` + Name string `json:"name"` + Definition provider.TaskDefinition `json:"definition"` + Status provider.TaskStatus `json:"status"` + ProviderName string `json:"provider_name"` + SSHKeyPair *SSHKeyPair `json:"ssh_key_pair"` +} - _, err = dockerClient.ContainerCreate(ctx, &container.Config{ - Image: definition.Image.Image, - Entrypoint: definition.Entrypoint, - Cmd: definition.Command, - Tty: false, - Hostname: definition.Name, - Labels: map[string]string{ - providerLabelName: p.name, - }, - Env: convertEnvMapToList(definition.Environment), - }, &container.HostConfig{ - Mounts: []mount.Mount{ - { - Type: mount.TypeBind, - Source: "/docker_volumes", - Target: definition.DataDir, - }, - }, - NetworkMode: container.NetworkMode("host"), - }, nil, nil, definition.ContainerName) - if err != nil { - return "", err - } +type Task struct { + state *TaskState + stateMu sync.Mutex - return droplet.Name, nil + removeTask provider.RemoveTaskFunc + logger *zap.Logger + sshClient *ssh.Client + doClient DoClient + dockerClient clients.DockerClient } -func (p *Provider) StartTask(ctx context.Context, taskName string) error { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return err - } - - defer dockerClient.Close() // nolint +var _ provider.TaskI = (*Task)(nil) - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) Start(ctx context.Context) error { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -113,40 +58,39 @@ func (p *Provider) StartTask(ctx context.Context, taskName string) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", taskName) + return fmt.Errorf("could not find container for %s", t.GetState().Name) } containerID := containers[0].ID - err = dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}) + err = t.dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}) if err != nil { return err } err = util.WaitForCondition(ctx, time.Second*300, time.Millisecond*100, func() (bool, error) { - status, err := p.GetTaskStatus(ctx, taskName) + status, err := t.GetStatus(ctx) if err != nil { return false, err } - if status == provider.TASK_RUNNING { - return true, nil + if status != provider.TASK_RUNNING { + return false, nil } - return false, nil + t.stateMu.Lock() + defer t.stateMu.Unlock() + + t.state.Status = provider.TASK_RUNNING + return true, nil }) + t.logger.Info("final task status after start", zap.Any("status", t.GetState().Status)) return err } -func (p *Provider) StopTask(ctx context.Context, taskName string) error { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return 
err - } - - defer dockerClient.Close() // nolint - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ +func (t *Task) Stop(ctx context.Context) error { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -155,30 +99,52 @@ func (p *Provider) StopTask(ctx context.Context, taskName string) error { } if len(containers) != 1 { - return fmt.Errorf("could not find container for %s", taskName) + return fmt.Errorf("could not find container for %s", t.GetState().Name) } - return dockerClient.ContainerStop(ctx, containers[0].ID, container.StopOptions{}) + t.stateMu.Lock() + defer t.stateMu.Unlock() + + t.state.Status = provider.TASK_STOPPED + return t.dockerClient.ContainerStop(ctx, containers[0].ID, container.StopOptions{}) } -func (p *Provider) ModifyTask(ctx context.Context, taskName string, definition provider.TaskDefinition) error { - return nil +func (t *Task) Initialize(ctx context.Context) error { + panic("implement me") +} + +func (t *Task) Modify(ctx context.Context, definition provider.TaskDefinition) error { + panic("implement me") } -func (p *Provider) DestroyTask(ctx context.Context, taskName string) error { - logger := p.logger.With(zap.String("task", taskName)) +func (t *Task) Destroy(ctx context.Context) error { + logger := t.logger.With(zap.String("task", t.GetState().Name)) logger.Info("deleting task") + defer t.dockerClient.Close() - err := p.deleteDroplet(ctx, taskName) + err := t.deleteDroplet(ctx) if err != nil { return err } + if err := t.removeTask(ctx, t.GetState().ID); err != nil { + return err + } return nil } -func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider.TaskStatus, error) { - droplet, err := p.getDroplet(ctx, taskName) +func (t *Task) GetState() TaskState { + t.stateMu.Lock() + defer t.stateMu.Unlock() + return *t.state +} + +func (t *Task) GetDefinition() provider.TaskDefinition { + return t.GetState().Definition +} + +func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { + droplet, err := t.getDroplet(ctx) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } @@ -187,14 +153,7 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider return provider.TASK_STOPPED, nil } - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - return provider.TASK_STATUS_UNDEFINED, err - } - - defer dockerClient.Close() - - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -203,25 +162,25 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider } if len(containers) != 1 { - return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", taskName) + return provider.TASK_STATUS_UNDEFINED, fmt.Errorf("could not find container for %s", t.GetState().Name) } - container, err := dockerClient.ContainerInspect(ctx, containers[0].ID) + c, err := t.dockerClient.ContainerInspect(ctx, containers[0].ID) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } - switch state := container.State.Status; state { + switch state := c.State.Status; state { case "created": return provider.TASK_STOPPED, nil case "running": return provider.TASK_RUNNING, nil case "paused": return provider.TASK_PAUSED, nil - case "restarting": - return provider.TASK_RUNNING, nil // todo(zygimantass): is this sane? 
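+ // docker reports "removing" while a container is being deleted; treat it as stopped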
case "removing": return provider.TASK_STOPPED, nil + case "restarting": + return provider.TASK_RESTARTING, nil case "exited": return provider.TASK_STOPPED, nil case "dead": @@ -231,10 +190,10 @@ func (p *Provider) GetTaskStatus(ctx context.Context, taskName string) (provider return provider.TASK_STATUS_UNDEFINED, nil } -func (p *Provider) WriteFile(ctx context.Context, taskName string, relPath string, content []byte) error { +func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) error { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := p.getDropletSSHClient(ctx, taskName) + sshClient, err := t.getDropletSSHClient(ctx, t.GetState().Name) if err != nil { return err } @@ -266,10 +225,10 @@ func (p *Provider) WriteFile(ctx context.Context, taskName string, relPath strin return nil } -func (p *Provider) ReadFile(ctx context.Context, taskName string, relPath string) ([]byte, error) { +func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { absPath := path.Join("/docker_volumes", relPath) - sshClient, err := p.getDropletSSHClient(ctx, taskName) + sshClient, err := t.getDropletSSHClient(ctx, t.GetState().Name) if err != nil { return nil, err } @@ -293,12 +252,12 @@ func (p *Provider) ReadFile(ctx context.Context, taskName string, relPath string return content, nil } -func (p *Provider) DownloadDir(ctx context.Context, s string, s2 string, s3 string) error { +func (t *Task) DownloadDir(ctx context.Context, s string, s2 string) error { panic("implement me") } -func (p *Provider) GetIP(ctx context.Context, taskName string) (string, error) { - droplet, err := p.getDroplet(ctx, taskName) +func (t *Task) GetIP(ctx context.Context) (string, error) { + droplet, err := t.getDroplet(ctx) if err != nil { return "", err @@ -307,8 +266,8 @@ func (p *Provider) GetIP(ctx context.Context, taskName string) (string, error) { return droplet.PublicIPv4() } -func (p *Provider) GetExternalAddress(ctx context.Context, taskName string, port string) (string, error) { - ip, err := p.GetIP(ctx, taskName) +func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) { + ip, err := t.GetIP(ctx) if err != nil { return "", err } @@ -316,14 +275,49 @@ func (p *Provider) GetExternalAddress(ctx context.Context, taskName string, port return net.JoinHostPort(ip, port), nil } -func (p *Provider) RunCommand(ctx context.Context, taskName string, command []string) (string, string, int, error) { - dockerClient, err := p.getDropletDockerClient(ctx, taskName) +func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, int, error) { + status, err := t.GetStatus(ctx) if err != nil { return "", "", 0, err } - defer dockerClient.Close() - containers, err := dockerClient.ContainerList(ctx, container.ListOptions{ + if status != provider.TASK_RUNNING { + return t.runCommandWhileStopped(ctx, cmd) + } + + return t.runCommand(ctx, cmd) +} + +func waitForExec(ctx context.Context, dockerClient clients.DockerClient, execID string) (int, error) { + lastExitCode := 0 + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + +loop: + for { + select { + case <-ctx.Done(): + return lastExitCode, ctx.Err() + case <-ticker.C: + execInspect, err := dockerClient.ContainerExecInspect(ctx, execID) + if err != nil { + return lastExitCode, err + } + + if execInspect.Running { + continue + } + + lastExitCode = execInspect.ExitCode + break loop + } + } + + return lastExitCode, nil +} + +func (t *Task) runCommand(ctx context.Context, command 
[]string) (string, string, int, error) { + containers, err := t.dockerClient.ContainerList(ctx, container.ListOptions{ Limit: 1, }) @@ -332,14 +326,14 @@ func (p *Provider) RunCommand(ctx context.Context, taskName string, command []st } if len(containers) != 1 { - return "", "", 0, fmt.Errorf("could not find container for %s", taskName) + return "", "", 0, fmt.Errorf("could not find container for %s", t.GetState().Name) } id := containers[0].ID - p.logger.Debug("running command", zap.String("id", id), zap.Strings("command", command)) + t.logger.Debug("running command", zap.String("id", id), zap.Strings("command", command)) - exec, err := dockerClient.ContainerExecCreate(ctx, id, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, id, container.ExecOptions{ AttachStdout: true, AttachStderr: true, Cmd: command, @@ -348,159 +342,113 @@ func (p *Provider) RunCommand(ctx context.Context, taskName string, command []st return "", "", 0, err } - resp, err := dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) + resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { return "", "", 0, err } defer resp.Close() - lastExitCode := 0 - - ticker := time.NewTicker(100 * time.Millisecond) - -loop: - for { - select { - case <-ctx.Done(): - return "", "", lastExitCode, ctx.Err() - case <-ticker.C: - execInspect, err := dockerClient.ContainerExecInspect(ctx, exec.ID) - if err != nil { - return "", "", lastExitCode, err - } - - if execInspect.Running { - continue - } - - lastExitCode = execInspect.ExitCode - - break loop - } - } - var stdout, stderr bytes.Buffer - _, err = stdcopy.StdCopy(&stdout, &stderr, resp.Reader) if err != nil { - return "", "", lastExitCode, err + return "", "", 0, err + } + + exitCode, err := waitForExec(ctx, t.dockerClient, exec.ID) + if err != nil { + return stdout.String(), stderr.String(), exitCode, err } - return stdout.String(), stderr.String(), lastExitCode, nil + return stdout.String(), stderr.String(), exitCode, nil } -func (p *Provider) RunCommandWhileStopped(ctx context.Context, taskName string, definition provider.TaskDefinition, command []string) (string, string, int, error) { - if err := definition.ValidateBasic(); err != nil { +func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { + state := t.GetState() + if err := state.Definition.ValidateBasic(); err != nil { return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } - dockerClient, err := p.getDropletDockerClient(ctx, taskName) - if err != nil { - p.logger.Error("failed to get docker client", zap.Error(err), zap.String("taskName", taskName)) - return "", "", 0, err - } - - definition.Entrypoint = []string{"sh", "-c"} - definition.Command = []string{"sleep 36000"} - definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", definition.Name, util.RandomString(5), time.Now().Unix()) - definition.Ports = []string{} - - createdContainer, err := dockerClient.ContainerCreate(ctx, &container.Config{ - Image: definition.Image.Image, - Entrypoint: definition.Entrypoint, - Cmd: definition.Command, + containerName := fmt.Sprintf("%s-executor-%s-%d", state.Definition.Name, util.RandomString(5), time.Now().Unix()) + createdContainer, err := t.dockerClient.ContainerCreate(ctx, &container.Config{ + Image: state.Definition.Image.Image, + Entrypoint: []string{"sh", "-c"}, + Cmd: []string{"sleep 36000"}, Tty: false, - Hostname: definition.Name, + Hostname: 
state.Definition.Name, Labels: map[string]string{ - providerLabelName: p.name, + providerLabelName: state.ProviderName, }, - Env: convertEnvMapToList(definition.Environment), + Env: convertEnvMapToList(state.Definition.Environment), }, &container.HostConfig{ Mounts: []mount.Mount{ { Type: mount.TypeBind, Source: "/docker_volumes", - Target: definition.DataDir, + Target: state.Definition.DataDir, }, }, NetworkMode: container.NetworkMode("host"), - }, nil, nil, definition.ContainerName) + }, nil, nil, containerName) if err != nil { - p.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", taskName)) + t.logger.Error("failed to create container", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } + t.logger.Debug("container created successfully", zap.String("id", createdContainer.ID), zap.String("taskName", state.Name)) + defer func() { - if _, err := dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { - // auto-removed, but not detected as autoremoved + if _, err := t.dockerClient.ContainerInspect(ctx, createdContainer.ID); err != nil && dockerclient.IsErrNotFound(err) { + // container was auto-removed, no need to remove it again return } - if err := dockerClient.ContainerRemove(ctx, createdContainer.ID, container.RemoveOptions{Force: true}); err != nil { - p.logger.Error("failed to remove container", zap.Error(err), zap.String("taskName", taskName), zap.String("id", createdContainer.ID)) + if err := t.dockerClient.ContainerRemove(ctx, createdContainer.ID, container.RemoveOptions{Force: true}); err != nil { + t.logger.Error("failed to remove container", zap.Error(err), zap.String("taskName", state.Name), zap.String("id", createdContainer.ID)) } }() - if err := startContainerWithBlock(ctx, dockerClient, createdContainer.ID); err != nil { - p.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", taskName)) + if err := startContainerWithBlock(ctx, t.dockerClient, createdContainer.ID); err != nil { + t.logger.Error("failed to start container", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } + t.logger.Debug("container started successfully", zap.String("id", createdContainer.ID), zap.String("taskName", state.Name)) + // wait for container start - exec, err := dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, createdContainer.ID, container.ExecOptions{ AttachStdout: true, AttachStderr: true, - Cmd: command, + Cmd: cmd, }) if err != nil { - p.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", taskName)) + t.logger.Error("failed to create exec", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } - resp, err := dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) + resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { - p.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", taskName)) + t.logger.Error("failed to attach to exec", zap.Error(err), zap.String("taskName", state.Name)) return "", "", 0, err } defer resp.Close() - lastExitCode := 0 - - ticker := time.NewTicker(100 * time.Millisecond) - -loop: - for { - select { - case <-ctx.Done(): - return "", "", lastExitCode, ctx.Err() - case <-ticker.C: - execInspect, err := dockerClient.ContainerExecInspect(ctx, exec.ID) - if err != nil { - return "", 
"", lastExitCode, err - } - - if execInspect.Running { - continue - } - - lastExitCode = execInspect.ExitCode - - break loop - } - } - var stdout, stderr bytes.Buffer _, err = stdcopy.StdCopy(&stdout, &stderr, resp.Reader) if err != nil { return "", "", 0, err } - return stdout.String(), stderr.String(), lastExitCode, err + exitCode, err := waitForExec(ctx, t.dockerClient, exec.ID) + if err != nil { + return stdout.String(), stderr.String(), exitCode, err + } + + return stdout.String(), stderr.String(), exitCode, nil } -func startContainerWithBlock(ctx context.Context, dockerClient *dockerclient.Client, containerID string) error { +func startContainerWithBlock(ctx context.Context, dockerClient clients.DockerClient, containerID string) error { // start container if err := dockerClient.ContainerStart(ctx, containerID, container.StartOptions{}); err != nil { return err @@ -531,19 +479,3 @@ func startContainerWithBlock(ctx context.Context, dockerClient *dockerclient.Cli } } } - -func (p *Provider) pullImage(ctx context.Context, dockerClient *dockerclient.Client, img string) error { - p.logger.Info("pulling image", zap.String("image", img)) - resp, err := dockerClient.ImagePull(ctx, img, image.PullOptions{}) - if err != nil { - return err - } - - defer resp.Close() - // throw away the image pull stdout response - _, err = io.Copy(io.Discard, resp) - if err != nil { - return err - } - return nil -} diff --git a/core/provider/digitalocean/task_test.go b/core/provider/digitalocean/task_test.go new file mode 100644 index 0000000..07bc58d --- /dev/null +++ b/core/provider/digitalocean/task_test.go @@ -0,0 +1,975 @@ +package digitalocean + +import ( + "bufio" + "bytes" + "context" + "fmt" + "net" + "net/http" + "strconv" + "testing" + "time" + + "github.com/digitalocean/godo" + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/api/types/network" + dockerMocks "github.com/skip-mev/petri/core/v3/provider/mocks" + + specs "github.com/opencontainers/image-spec/specs-go/v1" + "github.com/skip-mev/petri/core/v3/provider" + "github.com/skip-mev/petri/core/v3/provider/digitalocean/mocks" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +// mockConn implements net.Conn interface for testing +type mockConn struct { + *bytes.Buffer +} + +// mockNotFoundError implements errdefs.ErrNotFound interface for testing +type mockNotFoundError struct { + error +} + +func (e mockNotFoundError) NotFound() {} + +func (m mockConn) Close() error { return nil } +func (m mockConn) LocalAddr() net.Addr { return nil } +func (m mockConn) RemoteAddr() net.Addr { return nil } +func (m mockConn) SetDeadline(t time.Time) error { return nil } +func (m mockConn) SetReadDeadline(t time.Time) error { return nil } +func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } + +const ( + testContainerID = "test-container-id" +) + +var ( + testContainer = types.Container{ID: testContainerID} + testDroplet = &godo.Droplet{ID: 123, Status: "active"} +) + +func TestTaskLifecycle(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: "1.2.3.4", + }, + }, + }, + } + + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, 
nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + mockDocker.On("ContainerStop", ctx, testContainerID, container.StopOptions{}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(droplet.ID), + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + }, + Status: provider.TASK_STOPPED, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + err := task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + err = task.Stop(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_STOPPED, task.GetState().Status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestTaskRunCommand(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + execID := "test-exec-id" + execCreateResp := types.IDResponse{ID: execID} + mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ + AttachStdout: true, + AttachStderr: true, + Cmd: []string{"echo", "hello"}, + }).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", ctx, execID, container.ExecAttachOptions{}).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", ctx, execID).Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(1), + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + // Start command assertions + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + err = 
task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestTaskRunCommandWhileStopped(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + createResp := container.CreateResponse{ID: testContainerID} + mockDocker.On("ContainerCreate", ctx, &container.Config{ + Image: "nginx:latest", + Entrypoint: []string{"sh", "-c"}, + Cmd: []string{"sleep 36000"}, + Tty: false, + Hostname: "test-task", + Labels: map[string]string{ + providerLabelName: "test-provider", + }, + Env: []string{}, + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: "", + }, + }, + NetworkMode: container.NetworkMode("host"), + }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Twice() + + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ + AttachStdout: true, + AttachStderr: true, + Cmd: []string{"echo", "hello"}, + }).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", ctx, "test-exec-id", container.ExecAttachOptions{}).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", ctx, "test-exec-id").Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: false, + }, + }, + }, nil).Once() + + mockDocker.On("ContainerRemove", ctx, testContainerID, container.RemoveOptions{Force: true}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(1), + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + // Start command assertions + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + err = task.Start(ctx) + require.NoError(t, err) + require.Equal(t, 
provider.TASK_RUNNING, task.GetState().Status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestTaskGetIP(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + expectedIP := "1.2.3.4" + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: expectedIP, + }, + }, + }, + } + + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(droplet.ID), + Name: "test-task", + ProviderName: "test-provider", + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + ip, err := task.GetIP(ctx) + require.NoError(t, err) + require.Equal(t, expectedIP, ip) + + externalAddr, err := task.GetExternalAddress(ctx, "80") + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("%s:80", expectedIP), externalAddr) + + mockDO.AssertExpectations(t) +} + +func TestTaskDestroy(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + mockDO.On("GetDroplet", ctx, testDroplet.ID).Return(testDroplet, nil) + mockDO.On("DeleteDropletByID", ctx, testDroplet.ID).Return(nil) + mockDocker.On("Close").Return(nil) + + provider := &Provider{ + state: &ProviderState{ + TaskStates: make(map[string]*TaskState), + }, + } + providerState := provider.GetState() + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(testDroplet.ID), + Name: "test-task", + ProviderName: "test-provider", + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + removeTask: func(ctx context.Context, taskID string) error { + delete(provider.state.TaskStates, taskID) + return nil + }, + } + + providerState.TaskStates[task.GetState().ID] = task.state + + err := task.Destroy(ctx) + require.NoError(t, err) + require.Empty(t, providerState.TaskStates) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestRunCommandWhileStoppedContainerCleanup(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{ + {ID: testContainerID}, + }, nil).Once() + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "exited", + }, + }, + }, nil).Once() + + createResp := container.CreateResponse{ID: testContainerID} + mockDocker.On("ContainerCreate", ctx, &container.Config{ + Image: "nginx:latest", + Entrypoint: []string{"sh", "-c"}, + Cmd: []string{"sleep 36000"}, + Tty: false, + Hostname: "test-task", + Labels: map[string]string{ + providerLabelName: "test-provider", + }, + Env: []string{}, + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: "", + }, + }, + NetworkMode: container.NetworkMode("host"), + }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + + // first ContainerInspect for startup check + 
mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: "test-exec-id"} + mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ + AttachStdout: true, + AttachStderr: true, + Cmd: []string{"echo", "hello"}, + }).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", ctx, "test-exec-id", container.ExecAttachOptions{}).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", ctx, "test-exec-id").Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + // second ContainerInspect for cleanup check - container exists + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{}, nil).Once() + + // container should be removed since it exists + mockDocker.On("ContainerRemove", ctx, testContainerID, container.RemoveOptions{Force: true}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(1), + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +// this tests the case where the docker container is auto removed before cleanup doesnt return an error +func TestRunCommandWhileStoppedContainerAutoRemoved(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + mockDO.On("GetDroplet", ctx, 1).Return(testDroplet, nil) + + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{ + {ID: testContainerID}, + }, nil).Once() + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "exited", + }, + }, + }, nil).Once() + + createResp := container.CreateResponse{ID: testContainerID} + mockDocker.On("ContainerCreate", ctx, &container.Config{ + Image: "nginx:latest", + Entrypoint: []string{"sh", "-c"}, + Cmd: []string{"sleep 36000"}, + Tty: false, + Hostname: "test-task", + Labels: map[string]string{ + providerLabelName: "test-provider", + }, + Env: []string{}, + }, &container.HostConfig{ + Mounts: []mount.Mount{ + { + Type: mount.TypeBind, + Source: "/docker_volumes", + Target: "", + }, + }, + NetworkMode: container.NetworkMode("host"), + }, (*network.NetworkingConfig)(nil), (*specs.Platform)(nil), mock.Anything).Return(createResp, nil) + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + + // first ContainerInspect for startup check + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Running: true, + }, + }, + }, nil).Once() + + execCreateResp := types.IDResponse{ID: 
"test-exec-id"} + mockDocker.On("ContainerExecCreate", ctx, testContainerID, container.ExecOptions{ + AttachStdout: true, + AttachStderr: true, + Cmd: []string{"echo", "hello"}, + }).Return(execCreateResp, nil) + + conn := &mockConn{Buffer: bytes.NewBuffer([]byte{})} + mockDocker.On("ContainerExecAttach", ctx, "test-exec-id", container.ExecAttachOptions{}).Return(types.HijackedResponse{ + Conn: conn, + Reader: bufio.NewReader(conn), + }, nil) + mockDocker.On("ContainerExecInspect", ctx, "test-exec-id").Return(container.ExecInspect{ + ExitCode: 0, + Running: false, + }, nil) + + // second ContainerInspect for cleanup check - container not found, so ContainerRemove should not be called + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{}, mockNotFoundError{fmt.Errorf("Error: No such container: test-container-id")}).Once() + mockDocker.AssertNotCalled(t, "ContainerRemove", ctx, testContainerID, mock.Anything) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(1), + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + ContainerName: "test-task-container", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + _, stderr, exitCode, err := task.RunCommand(ctx, []string{"echo", "hello"}) + require.NoError(t, err) + require.Equal(t, 0, exitCode) + require.Empty(t, stderr) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestTaskExposingPort(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + droplet := &godo.Droplet{ + ID: 123, + Status: "active", + Networks: &godo.Networks{ + V4: []godo.NetworkV4{ + { + Type: "public", + IPAddress: "1.2.3.4", + }, + }, + }, + } + + mockDO.On("GetDroplet", ctx, droplet.ID).Return(droplet, nil) + + testContainer := types.Container{ + ID: testContainerID, + Ports: []types.Port{ + { + PrivatePort: 80, + PublicPort: 80, + Type: "tcp", + }, + }, + } + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + + mockDocker.On("ContainerStart", ctx, testContainerID, container.StartOptions{}).Return(nil) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(droplet.ID), + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + Image: provider.ImageDefinition{ + Image: "nginx:latest", + UID: "1000", + GID: "1000", + }, + Ports: []string{"80"}, + }, + Status: provider.TASK_STOPPED, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + err := task.Start(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, task.GetState().Status) + + externalAddr, err := task.GetExternalAddress(ctx, "80") + require.NoError(t, err) + require.Equal(t, "1.2.3.4:80", externalAddr) + + status, err := task.GetStatus(ctx) + require.NoError(t, err) + require.Equal(t, provider.TASK_RUNNING, status) + + req, err := http.NewRequest("GET", fmt.Sprintf("http://%s", externalAddr), nil) + require.NoError(t, err) + require.NotEmpty(t, req) + + 
mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) +} + +func TestGetStatus(t *testing.T) { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + testDropletActive := &godo.Droplet{ + ID: 123, + Status: "active", + } + testDropletOff := &godo.Droplet{ + ID: 123, + Status: "off", + } + + tests := []struct { + name string + dropletStatus string + containerState string + setupMocks func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) + expectedStatus provider.TaskStatus + expectError bool + }{ + { + name: "droplet not active", + dropletStatus: "off", + containerState: "", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletOff.ID).Return(testDropletOff, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container running", + dropletStatus: "active", + containerState: "running", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "running", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_RUNNING, + expectError: false, + }, + { + name: "container paused", + dropletStatus: "active", + containerState: "paused", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "paused", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_PAUSED, + expectError: false, + }, + { + name: "container stopped state", + dropletStatus: "active", + containerState: "exited", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "exited", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container removing", + dropletStatus: "active", + containerState: "removing", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "removing", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container dead", + dropletStatus: "active", + 
containerState: "dead", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "dead", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "container created", + dropletStatus: "active", + containerState: "created", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "created", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STOPPED, + expectError: false, + }, + { + name: "unknown container status", + dropletStatus: "active", + containerState: "unknown_status", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{ + ContainerJSONBase: &types.ContainerJSONBase{ + State: &types.ContainerState{ + Status: "unknown_status", + }, + }, + }, nil) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: false, + }, + { + name: "no containers found", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{}, nil) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "container inspect error", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return([]types.Container{testContainer}, nil) + mockDocker.On("ContainerInspect", ctx, testContainerID).Return(types.ContainerJSON{}, fmt.Errorf("inspect error")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "getDroplet error", + dropletStatus: "", + containerState: "", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, 123).Return(nil, fmt.Errorf("failed to get droplet")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + { + name: "containerList error", + dropletStatus: "active", + containerState: "", + setupMocks: func(mockDocker *dockerMocks.DockerClient, mockDO *mocks.DoClient) { + mockDO.On("GetDroplet", ctx, testDropletActive.ID).Return(testDropletActive, nil) + 
mockDocker.On("ContainerList", ctx, container.ListOptions{ + Limit: 1, + }).Return(nil, fmt.Errorf("failed to list containers")) + }, + expectedStatus: provider.TASK_STATUS_UNDEFINED, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockDocker := dockerMocks.NewDockerClient(t) + mockDO := mocks.NewDoClient(t) + + tt.setupMocks(mockDocker, mockDO) + + task := &Task{ + state: &TaskState{ + ID: strconv.Itoa(123), + Name: "test-task", + ProviderName: "test-provider", + Definition: provider.TaskDefinition{ + Name: "test-task", + }, + }, + logger: logger, + dockerClient: mockDocker, + doClient: mockDO, + } + + status, err := task.GetStatus(ctx) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.expectedStatus, status) + + mockDocker.AssertExpectations(t) + mockDO.AssertExpectations(t) + }) + } +} diff --git a/core/provider/docker/network.go b/core/provider/docker/network.go index f18d42c..2616e3b 100644 --- a/core/provider/docker/network.go +++ b/core/provider/docker/network.go @@ -15,11 +15,13 @@ import ( const providerLabelName = "petri-provider" func (p *Provider) initNetwork(ctx context.Context) (network.Inspect, error) { - p.logger.Info("creating network", zap.String("name", p.state.NetworkName)) + state := p.GetState() + + p.logger.Info("creating network", zap.String("name", state.NetworkName)) subnet1 := rand.Intn(255) subnet2 := rand.Intn(255) - networkResponse, err := p.dockerClient.NetworkCreate(ctx, p.state.NetworkName, network.CreateOptions{ + networkResponse, err := p.dockerClient.NetworkCreate(ctx, state.NetworkName, network.CreateOptions{ Scope: "local", Driver: "bridge", Options: map[string]string{ // https://docs.docker.com/engine/reference/commandline/network_create/#bridge-driver-options @@ -32,7 +34,7 @@ func (p *Provider) initNetwork(ctx context.Context) (network.Inspect, error) { Attachable: false, Ingress: false, Labels: map[string]string{ - providerLabelName: p.state.Name, + providerLabelName: state.Name, }, IPAM: &network.IPAM{ Driver: "default", @@ -59,33 +61,34 @@ func (p *Provider) initNetwork(ctx context.Context) (network.Inspect, error) { // ensureNetwork checks if the network exists and has the expected configuration. 
func (p *Provider) ensureNetwork(ctx context.Context) error { - network, err := p.dockerClient.NetworkInspect(ctx, p.state.NetworkID, network.InspectOptions{}) + state := p.GetState() + network, err := p.dockerClient.NetworkInspect(ctx, state.NetworkID, network.InspectOptions{}) if err != nil { return err } - if network.ID != p.state.NetworkID { - return fmt.Errorf("network ID mismatch: %s != %s", network.ID, p.state.NetworkID) + if network.ID != state.NetworkID { + return fmt.Errorf("network ID mismatch: %s != %s", network.ID, state.NetworkID) } - if network.Name != p.state.NetworkName { - return fmt.Errorf("network name mismatch: %s != %s", network.Name, p.state.NetworkName) + if network.Name != state.NetworkName { + return fmt.Errorf("network name mismatch: %s != %s", network.Name, state.NetworkName) } if len(network.IPAM.Config) != 1 { return fmt.Errorf("unexpected number of IPAM configs: %d", len(network.IPAM.Config)) } - if network.IPAM.Config[0].Subnet != p.state.NetworkCIDR { - return fmt.Errorf("network CIDR mismatch: %s != %s", network.IPAM.Config[0].Subnet, p.state.NetworkCIDR) + if network.IPAM.Config[0].Subnet != state.NetworkCIDR { + return fmt.Errorf("network CIDR mismatch: %s != %s", network.IPAM.Config[0].Subnet, state.NetworkCIDR) } return nil } func (p *Provider) destroyNetwork(ctx context.Context) error { - return p.dockerClient.NetworkRemove(ctx, p.state.NetworkID) + return p.dockerClient.NetworkRemove(ctx, p.GetState().NetworkID) } // openListenerOnFreePort opens the next free port diff --git a/core/provider/docker/provider.go b/core/provider/docker/provider.go index f063630..ffe69cd 100644 --- a/core/provider/docker/provider.go +++ b/core/provider/docker/provider.go @@ -4,12 +4,12 @@ import ( "context" "encoding/json" "fmt" - "io" "net" "sync" "github.com/docker/docker/api/types/image" "github.com/skip-mev/petri/core/v3/provider" + "github.com/skip-mev/petri/core/v3/provider/clients" "github.com/cilium/ipam/service/ipallocator" "github.com/docker/docker/api/types/network" @@ -17,8 +17,6 @@ import ( "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/mount" - - "github.com/docker/docker/client" ) type ProviderState struct { @@ -38,7 +36,7 @@ type Provider struct { state *ProviderState stateMu sync.Mutex - dockerClient *client.Client + dockerClient clients.DockerClient dockerNetworkAllocator *ipallocator.Range networkMu sync.Mutex logger *zap.Logger @@ -47,7 +45,7 @@ type Provider struct { var _ provider.ProviderI = (*Provider)(nil) func CreateProvider(ctx context.Context, logger *zap.Logger, providerName string) (*Provider, error) { - dockerClient, err := client.NewClientWithOpts() + dockerClient, err := clients.NewDockerClient("") if err != nil { return nil, err } @@ -120,7 +118,7 @@ func RestoreProvider(ctx context.Context, logger *zap.Logger, state []byte) (*Pr logger: logger, } - dockerClient, err := client.NewClientWithOpts() + dockerClient, err := clients.NewDockerClient("") if err != nil { return nil, err } @@ -165,15 +163,17 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin if err := definition.ValidateBasic(); err != nil { return &Task{}, fmt.Errorf("failed to validate task definition: %w", err) } + state := p.GetState() taskState := &TaskState{ - Name: definition.Name, - Definition: definition, + Name: definition.Name, + Definition: definition, + BuilderImageName: state.BuilderImageName, } logger := p.logger.Named("docker_provider") - if err := p.pullImage(ctx, definition.Image.Image); err != 
nil { + if err := p.dockerClient.ImagePull(ctx, p.logger, definition.Image.Image, image.PullOptions{}); err != nil { return nil, err } @@ -236,17 +236,17 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin Tty: false, Hostname: definition.Name, Labels: map[string]string{ - providerLabelName: p.state.Name, + providerLabelName: state.Name, }, Env: convertEnvMapToList(definition.Environment), ExposedPorts: portSet, }, &container.HostConfig{ Mounts: mounts, PortBindings: portBindings, - NetworkMode: container.NetworkMode(p.state.NetworkName), + NetworkMode: container.NetworkMode(state.NetworkName), }, &network.NetworkingConfig{ EndpointsConfig: map[string]*network.EndpointSettings{ - p.state.NetworkName: { + state.NetworkName: { IPAMConfig: &network.EndpointIPAMConfig{ IPv4Address: ip, }, @@ -260,6 +260,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin taskState.Id = createdContainer.ID taskState.Status = provider.TASK_STOPPED + taskState.NetworkName = state.NetworkName taskState.IpAddress = ip p.stateMu.Lock() @@ -268,8 +269,10 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin p.state.TaskStates[taskState.Id] = taskState return &Task{ - state: taskState, - provider: p, + state: taskState, + logger: p.logger.With(zap.String("task", definition.Name)), + dockerClient: p.dockerClient, + removeTask: p.removeTask, }, nil } @@ -307,8 +310,10 @@ func (p *Provider) DeserializeTask(ctx context.Context, bz []byte) (provider.Tas } task := &Task{ - provider: p, - state: &taskState, + state: &taskState, + logger: p.logger.With(zap.String("task", taskState.Name)), + dockerClient: p.dockerClient, + removeTask: p.removeTask, } if err := task.ensureTask(ctx); err != nil { @@ -327,27 +332,10 @@ func (p *Provider) removeTask(_ context.Context, taskID string) error { return nil } -func (p *Provider) pullImage(ctx context.Context, imageName string) error { - _, _, err := p.dockerClient.ImageInspectWithRaw(ctx, imageName) - if err != nil { - p.logger.Info("image not found, pulling", zap.String("image", imageName)) - resp, err := p.dockerClient.ImagePull(ctx, imageName, image.PullOptions{}) - if err != nil { - return err - } - defer resp.Close() - - // throw away the image pull stdout response - _, err = io.Copy(io.Discard, resp) - return err - } - return nil -} - func (p *Provider) Teardown(ctx context.Context) error { p.logger.Info("tearing down Docker provider") - for _, task := range p.state.TaskStates { + for _, task := range p.GetState().TaskStates { if err := p.dockerClient.ContainerRemove(ctx, task.Id, container.RemoveOptions{ Force: true, }); err != nil { diff --git a/core/provider/docker/provider_test.go b/core/provider/docker/provider_test.go index e381a57..01b2562 100644 --- a/core/provider/docker/provider_test.go +++ b/core/provider/docker/provider_test.go @@ -7,6 +7,8 @@ import ( "testing" "time" + "github.com/skip-mev/petri/core/v3/provider/clients" + "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/network" "github.com/docker/docker/client" @@ -231,12 +233,11 @@ func TestConcurrentTaskCreation(t *testing.T) { for task := range tasks { taskState := task.GetState() - client, _ := client.NewClientWithOpts() - containerJSON, err := client.ContainerInspect(ctx, taskState.Id) + dockerClient, _ := clients.NewDockerClient("") + containerJSON, err := dockerClient.ContainerInspect(ctx, taskState.Id) require.NoError(t, err) ip :=
containerJSON.NetworkSettings.Networks[providerState.NetworkName].IPAddress - fmt.Println(ip) assert.False(t, ips[ip], "Duplicate IP found: %s", ip) ips[ip] = true } diff --git a/core/provider/docker/task.go b/core/provider/docker/task.go index ea569f8..6ee04da 100644 --- a/core/provider/docker/task.go +++ b/core/provider/docker/task.go @@ -7,7 +7,10 @@ import ( "sync" "time" + "github.com/skip-mev/petri/core/v3/provider/clients" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/mount" "github.com/docker/docker/pkg/stdcopy" "github.com/docker/go-connections/nat" "github.com/skip-mev/petri/core/v3/provider" @@ -16,12 +19,14 @@ import ( ) type TaskState struct { - Id string `json:"id"` - Name string `json:"name"` - Volume *VolumeState `json:"volumes"` - Definition provider.TaskDefinition `json:"definition"` - Status provider.TaskStatus `json:"status"` - IpAddress string `json:"ip_address"` + Id string `json:"id"` + Name string `json:"name"` + Volume *VolumeState `json:"volumes"` + Definition provider.TaskDefinition `json:"definition"` + Status provider.TaskStatus `json:"status"` + IpAddress string `json:"ip_address"` + BuilderImageName string `json:"builder_image_name"` + NetworkName string `json:"network_name"` } type VolumeState struct { @@ -30,18 +35,20 @@ type VolumeState struct { } type Task struct { - state *TaskState - stateMu sync.Mutex - provider *Provider + state *TaskState + stateMu sync.Mutex + logger *zap.Logger + dockerClient clients.DockerClient + removeTask provider.RemoveTaskFunc } var _ provider.TaskI = (*Task)(nil) func (t *Task) Start(ctx context.Context) error { - t.provider.logger.Info("starting task", zap.String("id", t.state.Id)) - - err := t.provider.dockerClient.ContainerStart(ctx, t.state.Id, container.StartOptions{}) + state := t.GetState() + t.logger.Info("starting task", zap.String("id", state.Id)) + err := t.dockerClient.ContainerStart(ctx, state.Id, container.StartOptions{}) if err != nil { return err } @@ -59,10 +66,10 @@ func (t *Task) Start(ctx context.Context) error { } func (t *Task) Stop(ctx context.Context) error { - t.provider.logger.Info("stopping task", zap.String("id", t.state.Id)) - - err := t.provider.dockerClient.ContainerStop(ctx, t.state.Id, container.StopOptions{}) + state := t.GetState() + t.logger.Info("stopping task", zap.String("id", state.Id)) + err := t.dockerClient.ContainerStop(ctx, state.Id, container.StopOptions{}) if err != nil { return err } @@ -80,9 +87,10 @@ func (t *Task) Stop(ctx context.Context) error { } func (t *Task) Destroy(ctx context.Context) error { - t.provider.logger.Info("destroying task", zap.String("id", t.state.Id)) + state := t.GetState() + t.logger.Info("destroying task", zap.String("id", state.Id)) - err := t.provider.dockerClient.ContainerRemove(ctx, t.state.Id, container.RemoveOptions{ + err := t.dockerClient.ContainerRemove(ctx, state.Id, container.RemoveOptions{ Force: true, RemoveVolumes: true, }) @@ -91,7 +99,7 @@ func (t *Task) Destroy(ctx context.Context) error { return err } - if err := t.provider.removeTask(ctx, t.state.Id); err != nil { + if err := t.removeTask(ctx, state.Id); err != nil { return err } @@ -99,16 +107,15 @@ func (t *Task) Destroy(ctx context.Context) error { } func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, error) { - t.provider.logger.Debug("getting external address", zap.String("id", t.state.Id)) - - dockerContainer, err := t.provider.dockerClient.ContainerInspect(ctx, t.state.Id) + state := t.GetState() + 
t.logger.Debug("getting external address", zap.String("id", state.Id)) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, state.Id) if err != nil { return "", fmt.Errorf("failed to inspect container: %w", err) } portBindings, ok := dockerContainer.NetworkSettings.Ports[nat.Port(fmt.Sprintf("%s/tcp", port))] - if !ok || len(portBindings) == 0 { return "", fmt.Errorf("port %s not found", port) } @@ -117,15 +124,15 @@ func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, err } func (t *Task) GetIP(ctx context.Context) (string, error) { - t.provider.logger.Debug("getting IP", zap.String("id", t.state.Id)) + state := t.GetState() + t.logger.Debug("getting IP", zap.String("id", state.Id)) - dockerContainer, err := t.provider.dockerClient.ContainerInspect(ctx, t.state.Id) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, state.Id) if err != nil { return "", err } - ip := dockerContainer.NetworkSettings.Networks[t.provider.state.NetworkName].IPAMConfig.IPv4Address - + ip := dockerContainer.NetworkSettings.Networks[state.NetworkName].IPAMConfig.IPv4Address return ip, nil } @@ -149,7 +156,7 @@ func (t *Task) WaitForStatus(ctx context.Context, interval time.Duration, desire } func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { - containerJSON, err := t.provider.dockerClient.ContainerInspect(ctx, t.state.Id) + containerJSON, err := t.dockerClient.ContainerInspect(ctx, t.GetState().Id) if err != nil { return provider.TASK_STATUS_UNDEFINED, err } @@ -162,7 +169,7 @@ func (t *Task) GetStatus(ctx context.Context) (provider.TaskStatus, error) { case "paused": return provider.TASK_PAUSED, nil case "restarting": - return provider.TASK_RUNNING, nil // todo(zygimantass): is this sane? + return provider.TASK_RESTARTING, nil case "removing": return provider.TASK_STOPPED, nil case "exited": @@ -192,15 +199,16 @@ func (t *Task) RunCommand(ctx context.Context, cmd []string) (string, string, in } func (t *Task) runCommand(ctx context.Context, cmd []string) (string, string, int, error) { - t.provider.logger.Debug("running command", zap.String("id", t.state.Id), zap.Strings("command", cmd)) + state := t.GetState() + t.logger.Debug("running command", zap.String("id", state.Id), zap.Strings("command", cmd)) - exec, err := t.provider.dockerClient.ContainerExecCreate(ctx, t.state.Id, container.ExecOptions{ + exec, err := t.dockerClient.ContainerExecCreate(ctx, state.Id, container.ExecOptions{ AttachStdout: true, AttachStderr: true, Cmd: cmd, }) if err != nil { - if buf, err := t.provider.dockerClient.ContainerLogs(ctx, t.state.Id, container.LogsOptions{ + if buf, err := t.dockerClient.ContainerLogs(ctx, state.Id, container.LogsOptions{ ShowStdout: true, ShowStderr: true, }); err == nil { @@ -213,7 +221,7 @@ func (t *Task) runCommand(ctx context.Context, cmd []string) (string, string, in return "", "", 0, err } - resp, err := t.provider.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) + resp, err := t.dockerClient.ContainerExecAttach(ctx, exec.ID, container.ExecAttachOptions{}) if err != nil { return "", "", 0, err } @@ -232,7 +240,7 @@ loop: case <-ctx.Done(): return "", "", lastExitCode, ctx.Err() case <-ticker.C: - execInspect, err := t.provider.dockerClient.ContainerExecInspect(ctx, exec.ID) + execInspect, err := t.dockerClient.ContainerExecInspect(ctx, exec.ID) if err != nil { return "", "", lastExitCode, err } @@ -247,7 +255,7 @@ loop: } if err != nil { - t.provider.logger.Error("failed to wait for exec", 
zap.Error(err), zap.String("id", t.state.Id)) + t.logger.Error("failed to wait for exec", zap.Error(err), zap.String("id", t.GetState().Id)) return "", "", lastExitCode, err } @@ -261,12 +269,13 @@ loop: } func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string, string, int, error) { + state := t.GetState() definition := t.GetState().Definition if err := definition.ValidateBasic(); err != nil { return "", "", 0, fmt.Errorf("failed to validate task definition: %w", err) } - t.provider.logger.Debug("running command while stopped", zap.String("id", t.state.Id), zap.Strings("command", cmd)) + t.logger.Debug("running command while stopped", zap.String("id", state.Id), zap.Strings("command", cmd)) status, err := t.GetStatus(ctx) if err != nil { @@ -282,24 +291,61 @@ func (t *Task) runCommandWhileStopped(ctx context.Context, cmd []string) (string definition.ContainerName = fmt.Sprintf("%s-executor-%s-%d", definition.Name, util.RandomString(5), time.Now().Unix()) definition.Ports = []string{} - task, err := t.provider.CreateTask(ctx, definition) - if err != nil { - return "", "", 0, err + containerConfig := &container.Config{ + Image: definition.Image.Image, + Entrypoint: definition.Entrypoint, + Cmd: definition.Command, + Tty: false, + Hostname: definition.Name, + Env: convertEnvMapToList(definition.Environment), + } + + var mounts []mount.Mount + if state.Volume != nil { + mounts = []mount.Mount{ + { + Type: mount.TypeVolume, + Source: state.Volume.Name, + Target: definition.DataDir, + }, + } } - err = task.Start(ctx) - defer task.Destroy(ctx) // nolint:errcheck + hostConfig := &container.HostConfig{ + NetworkMode: container.NetworkMode(state.NetworkName), + Mounts: mounts, + } + resp, err := t.dockerClient.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, definition.ContainerName) if err != nil { return "", "", 0, err } - stdout, stderr, exitCode, err := task.RunCommand(ctx, cmd) + tempTask := &Task{ + state: &TaskState{ + Id: resp.ID, + Name: definition.Name, + Definition: definition, + Status: provider.TASK_STOPPED, + NetworkName: state.NetworkName, + }, + logger: t.logger.With(zap.String("temp_task", definition.Name)), + dockerClient: t.dockerClient, + removeTask: t.removeTask, + } + + err = tempTask.Start(ctx) if err != nil { return "", "", 0, err } - return stdout, stderr, exitCode, nil + defer func() { + if err := tempTask.Destroy(ctx); err != nil { + t.logger.Error("failed to destroy temporary task", zap.Error(err)) + } + }() + + return tempTask.RunCommand(ctx, cmd) } func (t *Task) GetState() TaskState { @@ -311,7 +357,7 @@ func (t *Task) GetState() TaskState { func (t *Task) ensureTask(ctx context.Context) error { state := t.GetState() - dockerContainer, err := t.provider.dockerClient.ContainerInspect(ctx, state.Id) + dockerContainer, err := t.dockerClient.ContainerInspect(ctx, state.Id) if err != nil { return fmt.Errorf("failed to inspect container: %w", err) } @@ -339,17 +385,18 @@ func (t *Task) ensureTask(ctx context.Context) error { } func (t *Task) ensureVolume(ctx context.Context) error { - if t.state.Volume == nil { + state := t.GetState() + if state.Volume == nil { return nil } - volume, err := t.provider.dockerClient.VolumeInspect(ctx, t.state.Volume.Name) + volume, err := t.dockerClient.VolumeInspect(ctx, state.Volume.Name) if err != nil { return fmt.Errorf("failed to inspect volume: %w", err) } - if volume.Name != t.state.Volume.Name { - return fmt.Errorf("volume name mismatch, expected: %s, got: %s", t.state.Volume.Name, volume.Name) + if 
volume.Name != state.Volume.Name { + return fmt.Errorf("volume name mismatch, expected: %s, got: %s", state.Volume.Name, volume.Name) } return nil diff --git a/core/provider/docker/task_test.go b/core/provider/docker/task_test.go index a222baf..237d00e 100644 --- a/core/provider/docker/task_test.go +++ b/core/provider/docker/task_test.go @@ -54,8 +54,10 @@ func TestTaskLifecycle(t *testing.T) { err = task.Stop(ctx) require.NoError(t, err) + require.Equal(t, 1, len(p.GetState().TaskStates)) err = task.Destroy(ctx) require.NoError(t, err) + require.Equal(t, 0, len(p.GetState().TaskStates)) dockerTask, ok := task.(*docker.Task) require.True(t, ok) diff --git a/core/provider/docker/volume.go b/core/provider/docker/volume.go index 1cd0333..68b13f9 100644 --- a/core/provider/docker/volume.go +++ b/core/provider/docker/volume.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "fmt" + "github.com/docker/docker/api/types/image" "io" "os" "path" @@ -56,7 +57,6 @@ func (p *Provider) DestroyVolume(ctx context.Context, id string) error { func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string) error { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return fmt.Errorf("no volumes found for container %s", state.Id) @@ -64,7 +64,7 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", state.Id), zap.String("path", relPath)) + logger := t.logger.With(zap.String("volume", state.Id), zap.String("path", relPath)) logger.Debug("writing file") @@ -72,16 +72,16 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string containerName := fmt.Sprintf("petri-writefile-%d", time.Now().UnixNano()) - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return err } logger.Debug("creating writefile container") - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, + Image: state.BuilderImageName, Entrypoint: []string{"sh", "-c"}, Cmd: []string{ @@ -93,10 +93,6 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string mountPath, }, - Labels: map[string]string{ - providerLabelName: providerState.Name, - }, - // Use root user to avoid permission issues when reading files from the volume. 
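// The volume's contents may be owned by whatever UID/GID the task image runs as, so uid 0 / gid 0 is the only identity guaranteed access; ownership can be reset afterwards via SetVolumeOwner.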
User: "0:0", }, @@ -121,12 +117,12 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string return } - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { // auto-removed, but not detected as autoremoved return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed to remove writefile container", zap.String("id", cc.ID), zap.Error(err)) @@ -143,7 +139,7 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string defer file.Close() - if err := t.provider.dockerClient.CopyToContainer( + if err := t.dockerClient.CopyToContainer( ctx, cc.ID, mountPath, @@ -154,11 +150,11 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string } logger.Debug("starting writefile container") - if err := t.provider.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { + if err := t.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { return fmt.Errorf("starting write-file container: %w", err) } - waitCh, errCh := t.provider.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) + waitCh, errCh := t.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) select { case <-ctx.Done(): return ctx.Err() @@ -182,7 +178,6 @@ func (t *Task) WriteTar(ctx context.Context, relPath string, localTarPath string // taken from strangelove-ventures/interchain-test func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) error { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return fmt.Errorf("no volumes found for container %s", state.Id) @@ -190,7 +185,7 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) + logger := t.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) logger.Debug("writing file") @@ -198,16 +193,16 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er containerName := fmt.Sprintf("petri-writefile-%d", time.Now().UnixNano()) - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return err } logger.Debug("creating writefile container") - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, + Image: state.BuilderImageName, Entrypoint: []string{"sh", "-c"}, Cmd: []string{ @@ -219,10 +214,6 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er mountPath, }, - Labels: map[string]string{ - providerLabelName: providerState.Name, - }, - // Use root user to avoid permission issues when reading files from the volume. 
User: "0:0", }, @@ -247,12 +238,12 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er return } - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { // auto-removed, but not detected as autoremoved return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed to remove writefile container", zap.String("id", cc.ID), zap.Error(err)) @@ -282,7 +273,7 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er logger.Debug("copying file to container") - if err := t.provider.dockerClient.CopyToContainer( + if err := t.dockerClient.CopyToContainer( ctx, cc.ID, mountPath, @@ -293,11 +284,11 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er } logger.Debug("starting writefile container") - if err := t.provider.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { + if err := t.dockerClient.ContainerStart(ctx, cc.ID, container.StartOptions{}); err != nil { return fmt.Errorf("starting write-file container: %w", err) } - waitCh, errCh := t.provider.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) + waitCh, errCh := t.dockerClient.ContainerWait(ctx, cc.ID, container.WaitConditionNotRunning) select { case <-ctx.Done(): return ctx.Err() @@ -320,7 +311,6 @@ func (t *Task) WriteFile(ctx context.Context, relPath string, content []byte) er func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return nil, fmt.Errorf("no volumes found for container %s", state.Id) @@ -328,26 +318,22 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) + logger := t.logger.With(zap.String("volume", volumeName), zap.String("path", relPath)) const mountPath = "/mnt/dockervolume" containerName := fmt.Sprintf("petri-getfile-%d", time.Now().UnixNano()) - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return nil, err } logger.Debug("creating getfile container") - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, - - Labels: map[string]string{ - providerLabelName: providerState.Name, - }, + Image: state.BuilderImageName, // Use root user to avoid permission issues when reading files from the volume. 
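// A short-lived helper container is used because the Docker API has no direct volume-read endpoint; CopyFromContainer below streams the requested path out of the mounted volume as a tar archive.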
User: "0", @@ -367,12 +353,12 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { logger.Debug("created getfile container", zap.String("id", cc.ID)) defer func() { - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { // auto-removed, but not detected as autoremoved return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed cleaning up the getfile container", zap.Error(err)) @@ -380,7 +366,7 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { }() logger.Debug("copying from container") - rc, _, err := t.provider.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) + rc, _, err := t.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) if err != nil { return nil, fmt.Errorf("copying from container: %w", err) } @@ -410,7 +396,6 @@ func (t *Task) ReadFile(ctx context.Context, relPath string) ([]byte, error) { func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error { state := t.GetState() - providerState := t.provider.GetState() if state.Volume == nil { return fmt.Errorf("no volumes found for container %s", state.Id) @@ -418,7 +403,7 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error volumeName := state.Volume.Name - logger := t.provider.logger.With(zap.String("volume", volumeName), zap.String("path", relPath), zap.String("localPath", localPath)) + logger := t.logger.With(zap.String("volume", volumeName), zap.String("path", relPath), zap.String("localPath", localPath)) const mountPath = "/mnt/dockervolume" @@ -426,19 +411,14 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error logger.Debug("creating getdir container") - if err := t.provider.pullImage(ctx, providerState.BuilderImageName); err != nil { + if err := t.dockerClient.ImagePull(ctx, t.logger, state.BuilderImageName, image.PullOptions{}); err != nil { return err } - cc, err := t.provider.dockerClient.ContainerCreate( + cc, err := t.dockerClient.ContainerCreate( ctx, &container.Config{ - Image: providerState.BuilderImageName, - - Labels: map[string]string{ - providerLabelName: providerState.Name, - }, - + Image: state.BuilderImageName, // Use root user to avoid permission issues when reading files from the volume. 
User: "0", }, @@ -455,11 +435,11 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error } defer func() { - if _, err := t.provider.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { + if _, err := t.dockerClient.ContainerInspect(ctx, cc.ID); err != nil && client.IsErrNotFound(err) { return } - if err := t.provider.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ + if err := t.dockerClient.ContainerRemove(ctx, cc.ID, container.RemoveOptions{ Force: true, }); err != nil { logger.Error("failed cleaning up the getdir container", zap.Error(err)) @@ -467,7 +447,7 @@ func (t *Task) DownloadDir(ctx context.Context, relPath, localPath string) error }() logger.Debug("copying from container") - reader, _, err := t.provider.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) + reader, _, err := t.dockerClient.CopyFromContainer(ctx, cc.ID, path.Join(mountPath, relPath)) if err != nil { return err } @@ -513,7 +493,7 @@ func (p *Provider) SetVolumeOwner(ctx context.Context, volumeName, uid, gid stri containerName := fmt.Sprintf("petri-setowner-%d", time.Now().UnixNano()) - if err := p.pullImage(ctx, p.GetState().BuilderImageName); err != nil { + if err := p.dockerClient.ImagePull(ctx, p.logger, p.GetState().BuilderImageName, image.PullOptions{}); err != nil { return err } diff --git a/core/provider/mocks/docker_client_mock.go b/core/provider/mocks/docker_client_mock.go new file mode 100644 index 0000000..f4e3fe7 --- /dev/null +++ b/core/provider/mocks/docker_client_mock.go @@ -0,0 +1,662 @@ +// Code generated by mockery v2.47.0. DO NOT EDIT. + +package mocks + +import ( + context "context" + + container "github.com/docker/docker/api/types/container" + + image "github.com/docker/docker/api/types/image" + + io "io" + + mock "github.com/stretchr/testify/mock" + + network "github.com/docker/docker/api/types/network" + + types "github.com/docker/docker/api/types" + + v1 "github.com/opencontainers/image-spec/specs-go/v1" + + volume "github.com/docker/docker/api/types/volume" + + zap "go.uber.org/zap" +) + +// DockerClient is an autogenerated mock type for the DockerClient type +type DockerClient struct { + mock.Mock +} + +// Close provides a mock function with given fields: +func (_m *DockerClient) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerCreate provides a mock function with given fields: ctx, config, hostConfig, networkingConfig, platform, containerName +func (_m *DockerClient) ContainerCreate(ctx context.Context, config *container.Config, hostConfig *container.HostConfig, networkingConfig *network.NetworkingConfig, platform *v1.Platform, containerName string) (container.CreateResponse, error) { + ret := _m.Called(ctx, config, hostConfig, networkingConfig, platform, containerName) + + if len(ret) == 0 { + panic("no return value specified for ContainerCreate") + } + + var r0 container.CreateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) (container.CreateResponse, error)); ok { + return rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } + if rf, ok := ret.Get(0).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, 
*v1.Platform, string) container.CreateResponse); ok { + r0 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } else { + r0 = ret.Get(0).(container.CreateResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, *container.Config, *container.HostConfig, *network.NetworkingConfig, *v1.Platform, string) error); ok { + r1 = rf(ctx, config, hostConfig, networkingConfig, platform, containerName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecAttach provides a mock function with given fields: ctx, execID, config +func (_m *DockerClient) ContainerExecAttach(ctx context.Context, execID string, config container.ExecStartOptions) (types.HijackedResponse, error) { + ret := _m.Called(ctx, execID, config) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecAttach") + } + + var r0 types.HijackedResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecStartOptions) (types.HijackedResponse, error)); ok { + return rf(ctx, execID, config) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecStartOptions) types.HijackedResponse); ok { + r0 = rf(ctx, execID, config) + } else { + r0 = ret.Get(0).(types.HijackedResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.ExecStartOptions) error); ok { + r1 = rf(ctx, execID, config) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecCreate provides a mock function with given fields: ctx, _a1, config +func (_m *DockerClient) ContainerExecCreate(ctx context.Context, _a1 string, config container.ExecOptions) (types.IDResponse, error) { + ret := _m.Called(ctx, _a1, config) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecCreate") + } + + var r0 types.IDResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecOptions) (types.IDResponse, error)); ok { + return rf(ctx, _a1, config) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.ExecOptions) types.IDResponse); ok { + r0 = rf(ctx, _a1, config) + } else { + r0 = ret.Get(0).(types.IDResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.ExecOptions) error); ok { + r1 = rf(ctx, _a1, config) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerExecInspect provides a mock function with given fields: ctx, execID +func (_m *DockerClient) ContainerExecInspect(ctx context.Context, execID string) (container.ExecInspect, error) { + ret := _m.Called(ctx, execID) + + if len(ret) == 0 { + panic("no return value specified for ContainerExecInspect") + } + + var r0 container.ExecInspect + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (container.ExecInspect, error)); ok { + return rf(ctx, execID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) container.ExecInspect); ok { + r0 = rf(ctx, execID) + } else { + r0 = ret.Get(0).(container.ExecInspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, execID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerInspect provides a mock function with given fields: ctx, _a1 +func (_m *DockerClient) ContainerInspect(ctx context.Context, _a1 string) (types.ContainerJSON, error) { + ret := _m.Called(ctx, _a1) + + if len(ret) == 0 { + panic("no return value specified for ContainerInspect") + } + + var r0 types.ContainerJSON + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, 
string) (types.ContainerJSON, error)); ok { + return rf(ctx, _a1) + } + if rf, ok := ret.Get(0).(func(context.Context, string) types.ContainerJSON); ok { + r0 = rf(ctx, _a1) + } else { + r0 = ret.Get(0).(types.ContainerJSON) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, _a1) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerList provides a mock function with given fields: ctx, options +func (_m *DockerClient) ContainerList(ctx context.Context, options container.ListOptions) ([]types.Container, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerList") + } + + var r0 []types.Container + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) ([]types.Container, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, container.ListOptions) []types.Container); ok { + r0 = rf(ctx, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]types.Container) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, container.ListOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerLogs provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerLogs(ctx context.Context, _a1 string, options container.LogsOptions) (io.ReadCloser, error) { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerLogs") + } + + var r0 io.ReadCloser + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.LogsOptions) (io.ReadCloser, error)); ok { + return rf(ctx, _a1, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.LogsOptions) io.ReadCloser); ok { + r0 = rf(ctx, _a1, options) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.LogsOptions) error); ok { + r1 = rf(ctx, _a1, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ContainerRemove provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerRemove(ctx context.Context, _a1 string, options container.RemoveOptions) error { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.RemoveOptions) error); ok { + r0 = rf(ctx, _a1, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerStart provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerStart(ctx context.Context, _a1 string, options container.StartOptions) error { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerStart") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, container.StartOptions) error); ok { + r0 = rf(ctx, _a1, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerStop provides a mock function with given fields: ctx, _a1, options +func (_m *DockerClient) ContainerStop(ctx context.Context, _a1 string, options container.StopOptions) error { + ret := _m.Called(ctx, _a1, options) + + if len(ret) == 0 { + panic("no return value specified for ContainerStop") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, 
container.StopOptions) error); ok { + r0 = rf(ctx, _a1, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ContainerWait provides a mock function with given fields: ctx, containerID, condition +func (_m *DockerClient) ContainerWait(ctx context.Context, containerID string, condition container.WaitCondition) (<-chan container.WaitResponse, <-chan error) { + ret := _m.Called(ctx, containerID, condition) + + if len(ret) == 0 { + panic("no return value specified for ContainerWait") + } + + var r0 <-chan container.WaitResponse + var r1 <-chan error + if rf, ok := ret.Get(0).(func(context.Context, string, container.WaitCondition) (<-chan container.WaitResponse, <-chan error)); ok { + return rf(ctx, containerID, condition) + } + if rf, ok := ret.Get(0).(func(context.Context, string, container.WaitCondition) <-chan container.WaitResponse); ok { + r0 = rf(ctx, containerID, condition) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(<-chan container.WaitResponse) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, container.WaitCondition) <-chan error); ok { + r1 = rf(ctx, containerID, condition) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(<-chan error) + } + } + + return r0, r1 +} + +// CopyFromContainer provides a mock function with given fields: ctx, _a1, srcPath +func (_m *DockerClient) CopyFromContainer(ctx context.Context, _a1 string, srcPath string) (io.ReadCloser, container.PathStat, error) { + ret := _m.Called(ctx, _a1, srcPath) + + if len(ret) == 0 { + panic("no return value specified for CopyFromContainer") + } + + var r0 io.ReadCloser + var r1 container.PathStat + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (io.ReadCloser, container.PathStat, error)); ok { + return rf(ctx, _a1, srcPath) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string) io.ReadCloser); ok { + r0 = rf(ctx, _a1, srcPath) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(io.ReadCloser) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string) container.PathStat); ok { + r1 = rf(ctx, _a1, srcPath) + } else { + r1 = ret.Get(1).(container.PathStat) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, string) error); ok { + r2 = rf(ctx, _a1, srcPath) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// CopyToContainer provides a mock function with given fields: ctx, _a1, path, content, options +func (_m *DockerClient) CopyToContainer(ctx context.Context, _a1 string, path string, content io.Reader, options container.CopyToContainerOptions) error { + ret := _m.Called(ctx, _a1, path, content, options) + + if len(ret) == 0 { + panic("no return value specified for CopyToContainer") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, io.Reader, container.CopyToContainerOptions) error); ok { + r0 = rf(ctx, _a1, path, content, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// ImageInspectWithRaw provides a mock function with given fields: ctx, imageID +func (_m *DockerClient) ImageInspectWithRaw(ctx context.Context, imageID string) (types.ImageInspect, []byte, error) { + ret := _m.Called(ctx, imageID) + + if len(ret) == 0 { + panic("no return value specified for ImageInspectWithRaw") + } + + var r0 types.ImageInspect + var r1 []byte + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) (types.ImageInspect, []byte, error)); ok { + return rf(ctx, imageID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) 
types.ImageInspect); ok { + r0 = rf(ctx, imageID) + } else { + r0 = ret.Get(0).(types.ImageInspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) []byte); ok { + r1 = rf(ctx, imageID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([]byte) + } + } + + if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { + r2 = rf(ctx, imageID) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// ImagePull provides a mock function with given fields: ctx, logger, refStr, options +func (_m *DockerClient) ImagePull(ctx context.Context, logger *zap.Logger, refStr string, options image.PullOptions) error { + ret := _m.Called(ctx, logger, refStr, options) + + if len(ret) == 0 { + panic("no return value specified for ImagePull") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *zap.Logger, string, image.PullOptions) error); ok { + r0 = rf(ctx, logger, refStr, options) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NetworkCreate provides a mock function with given fields: ctx, name, options +func (_m *DockerClient) NetworkCreate(ctx context.Context, name string, options network.CreateOptions) (network.CreateResponse, error) { + ret := _m.Called(ctx, name, options) + + if len(ret) == 0 { + panic("no return value specified for NetworkCreate") + } + + var r0 network.CreateResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, network.CreateOptions) (network.CreateResponse, error)); ok { + return rf(ctx, name, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, network.CreateOptions) network.CreateResponse); ok { + r0 = rf(ctx, name, options) + } else { + r0 = ret.Get(0).(network.CreateResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, network.CreateOptions) error); ok { + r1 = rf(ctx, name, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NetworkInspect provides a mock function with given fields: ctx, networkID, options +func (_m *DockerClient) NetworkInspect(ctx context.Context, networkID string, options network.InspectOptions) (network.Inspect, error) { + ret := _m.Called(ctx, networkID, options) + + if len(ret) == 0 { + panic("no return value specified for NetworkInspect") + } + + var r0 network.Inspect + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, network.InspectOptions) (network.Inspect, error)); ok { + return rf(ctx, networkID, options) + } + if rf, ok := ret.Get(0).(func(context.Context, string, network.InspectOptions) network.Inspect); ok { + r0 = rf(ctx, networkID, options) + } else { + r0 = ret.Get(0).(network.Inspect) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, network.InspectOptions) error); ok { + r1 = rf(ctx, networkID, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// NetworkRemove provides a mock function with given fields: ctx, networkID +func (_m *DockerClient) NetworkRemove(ctx context.Context, networkID string) error { + ret := _m.Called(ctx, networkID) + + if len(ret) == 0 { + panic("no return value specified for NetworkRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, networkID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Ping provides a mock function with given fields: ctx +func (_m *DockerClient) Ping(ctx context.Context) (types.Ping, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Ping") + } + + var r0 types.Ping + var r1 error + if rf, 
ok := ret.Get(0).(func(context.Context) (types.Ping, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) types.Ping); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(types.Ping) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeCreate provides a mock function with given fields: ctx, options +func (_m *DockerClient) VolumeCreate(ctx context.Context, options volume.CreateOptions) (volume.Volume, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for VolumeCreate") + } + + var r0 volume.Volume + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, volume.CreateOptions) (volume.Volume, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, volume.CreateOptions) volume.Volume); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Get(0).(volume.Volume) + } + + if rf, ok := ret.Get(1).(func(context.Context, volume.CreateOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeInspect provides a mock function with given fields: ctx, volumeID +func (_m *DockerClient) VolumeInspect(ctx context.Context, volumeID string) (volume.Volume, error) { + ret := _m.Called(ctx, volumeID) + + if len(ret) == 0 { + panic("no return value specified for VolumeInspect") + } + + var r0 volume.Volume + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (volume.Volume, error)); ok { + return rf(ctx, volumeID) + } + if rf, ok := ret.Get(0).(func(context.Context, string) volume.Volume); ok { + r0 = rf(ctx, volumeID) + } else { + r0 = ret.Get(0).(volume.Volume) + } + + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, volumeID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeList provides a mock function with given fields: ctx, options +func (_m *DockerClient) VolumeList(ctx context.Context, options volume.ListOptions) (volume.ListResponse, error) { + ret := _m.Called(ctx, options) + + if len(ret) == 0 { + panic("no return value specified for VolumeList") + } + + var r0 volume.ListResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, volume.ListOptions) (volume.ListResponse, error)); ok { + return rf(ctx, options) + } + if rf, ok := ret.Get(0).(func(context.Context, volume.ListOptions) volume.ListResponse); ok { + r0 = rf(ctx, options) + } else { + r0 = ret.Get(0).(volume.ListResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, volume.ListOptions) error); ok { + r1 = rf(ctx, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// VolumeRemove provides a mock function with given fields: ctx, volumeID, force +func (_m *DockerClient) VolumeRemove(ctx context.Context, volumeID string, force bool) error { + ret := _m.Called(ctx, volumeID, force) + + if len(ret) == 0 { + panic("no return value specified for VolumeRemove") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) error); ok { + r0 = rf(ctx, volumeID, force) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewDockerClient creates a new instance of DockerClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
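+// Illustrative usage in a test (a sketch: assumes a ctx in scope; the
+// expectation API is testify/mock's, as used by the generated methods above):
+//
+//	mockDocker := mocks.NewDockerClient(t)
+//	mockDocker.On("Ping", ctx).Return(types.Ping{}, nil)
+//	// expectations are asserted automatically by the Cleanup hook registered below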
+func NewDockerClient(t interface { + mock.TestingT + Cleanup(func()) +}) *DockerClient { + mock := &DockerClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/core/provider/provider.go b/core/provider/provider.go index 64213df..3ac73c2 100644 --- a/core/provider/provider.go +++ b/core/provider/provider.go @@ -10,11 +10,15 @@ import ( // TaskStatus defines the status of a task's underlying workload type TaskStatus int +// RemoveTaskFunc is a callback function type for removing a task from its provider +type RemoveTaskFunc func(ctx context.Context, taskID string) error + const ( TASK_STATUS_UNDEFINED TaskStatus = iota TASK_RUNNING TASK_STOPPED TASK_PAUSED + TASK_RESTARTING ) // Task is a stateful object that holds the underlying workload's details and tracks the workload's lifecycle diff --git a/core/types/chain.go b/core/types/chain.go index 7d2afed..490ad09 100644 --- a/core/types/chain.go +++ b/core/types/chain.go @@ -100,7 +100,7 @@ func (c ChainConfig) GetGenesisDelegation() *big.Int { return c.GenesisDelegation } -func (c *ChainConfig) ValidateBasic() error { +func (c ChainConfig) ValidateBasic() error { if c.Denom == "" { return fmt.Errorf("denom cannot be empty") } diff --git a/cosmos/chain/chain.go b/cosmos/chain/chain.go index 4f00e5c..3e917b8 100644 --- a/cosmos/chain/chain.go +++ b/cosmos/chain/chain.go @@ -150,7 +150,8 @@ func RestoreChain(ctx context.Context, logger *zap.Logger, infraProvider provide } chain := Chain{ - State: packagedState.State, + State: packagedState.State, + logger: logger, } for _, vs := range packagedState.ValidatorStates { diff --git a/cosmos/examples/digitalocean_simapp.go b/cosmos/examples/digitalocean_simapp.go new file mode 100644 index 0000000..6d3d778 --- /dev/null +++ b/cosmos/examples/digitalocean_simapp.go @@ -0,0 +1,128 @@ +package examples + +import ( + "context" + "github.com/skip-mev/petri/core/v3/provider/digitalocean" + "io" + "net/http" + "os" + "strings" + + "github.com/cosmos/cosmos-sdk/crypto/hd" + + "github.com/skip-mev/petri/cosmos/v3/chain" + "github.com/skip-mev/petri/cosmos/v3/node" + + "github.com/skip-mev/petri/core/v3/provider" + petritypes "github.com/skip-mev/petri/core/v3/types" + "go.uber.org/zap" +) + +func main() { + ctx := context.Background() + logger, _ := zap.NewDevelopment() + + doToken := os.Getenv("DO_API_TOKEN") + if doToken == "" { + logger.Fatal("DO_API_TOKEN environment variable not set") + } + + imageID := os.Getenv("DO_IMAGE_ID") + if imageID == "" { + logger.Fatal("DO_IMAGE_ID environment variable not set") + } + + sshKeyPair, err := digitalocean.MakeSSHKeyPair() + if err != nil { + logger.Fatal("failed to create SSH key pair", zap.Error(err)) + } + + externalIP, err := getExternalIP() + if err != nil { + logger.Fatal("failed to get external IP", zap.Error(err)) + } + logger.Info("External IP", zap.String("address", externalIP)) + + doProvider, err := digitalocean.NewProvider(ctx, logger, "cosmos-hub", doToken, []string{externalIP}, sshKeyPair) + if err != nil { + logger.Fatal("failed to create DigitalOcean provider", zap.Error(err)) + } + + chainConfig := petritypes.ChainConfig{ + Denom: "stake", + Decimals: 6, + NumValidators: 1, + NumNodes: 1, + BinaryName: "/usr/bin/simd", + Image: provider.ImageDefinition{ + Image: "interchainio/simapp:latest", + UID: "1000", + GID: "1000", + }, + GasPrices: "0.0005stake", + Bech32Prefix: "cosmos", + HomeDir: "/gaia", + CoinType: "118", + ChainId: "stake-1", + } + + chainOptions := petritypes.ChainOptions{ + NodeCreator: node.CreateNode, + NodeOptions: petritypes.NodeOptions{
NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig petritypes.NodeConfig) provider.TaskDefinition { + doConfig := digitalocean.DigitalOceanTaskConfig{ + "size": "s-2vcpu-4gb", + "region": "ams3", + "image_id": imageID, + } + def.ProviderSpecificConfig = doConfig + return def + }, + }, + WalletConfig: petritypes.WalletConfig{ + SigningAlgorithm: string(hd.Secp256k1.Name()), + Bech32Prefix: "cosmos", + HDPath: hd.CreateHDPath(118, 0, 0), + DerivationFn: hd.Secp256k1.Derive(), + GenerationFn: hd.Secp256k1.Generate(), + }, + } + + logger.Info("Creating chain") + cosmosChain, err := chain.CreateChain(ctx, logger, doProvider, chainConfig, chainOptions) + if err != nil { + logger.Fatal("failed to create chain", zap.Error(err)) + } + + logger.Info("Initializing chain") + err = cosmosChain.Init(ctx, chainOptions) + if err != nil { + logger.Fatal("failed to initialize chain", zap.Error(err)) + } + + logger.Info("Chain is successfully running! Waiting for chain to produce blocks") + err = cosmosChain.WaitForBlocks(ctx, 1) + if err != nil { + logger.Fatal("failed waiting for blocks", zap.Error(err)) + } + + // Comment out section below if you want to persist your Digital Ocean resources + logger.Info("Chain has successfully produced required number of blocks. Tearing down Digital Ocean resources.") + err = doProvider.Teardown(ctx) + if err != nil { + logger.Fatal("failed to teardown provider", zap.Error(err)) + } + + logger.Info("All Digital Ocean resources created have been successfully deleted!") +} + +func getExternalIP() (string, error) { + resp, err := http.Get("https://ifconfig.me") + if err != nil { + return "", err + } + defer resp.Body.Close() + + ip, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return strings.TrimSpace(string(ip)), nil +} diff --git a/cosmos/go.mod b/cosmos/go.mod index bf0bca2..81aa656 100644 --- a/cosmos/go.mod +++ b/cosmos/go.mod @@ -15,6 +15,7 @@ require ( github.com/cometbft/cometbft v0.38.12 github.com/cosmos/cosmos-sdk v0.50.10 github.com/cosmos/go-bip39 v1.0.0 + github.com/docker/docker v27.1.1+incompatible github.com/golangci/golangci-lint v1.56.2 github.com/icza/dyno v0.0.0-20230330125955-09f820a8d9c0 github.com/matoous/go-nanoid/v2 v2.1.0 @@ -97,8 +98,8 @@ require ( github.com/dgraph-io/badger/v2 v2.2007.4 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 // indirect + github.com/digitalocean/godo v1.108.0 // indirect github.com/distribution/reference v0.5.0 // indirect - github.com/docker/docker v27.1.1+incompatible // indirect github.com/docker/go-connections v0.4.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -146,6 +147,7 @@ require ( github.com/golangci/unconvert v0.0.0-20180507085042-28b1c447d1f4 // indirect github.com/google/btree v1.1.2 // indirect github.com/google/go-cmp v0.6.0 // indirect + github.com/google/go-querystring v1.1.0 // indirect github.com/gordonklaus/ineffassign v0.1.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/gostaticanalysis/analysisutil v0.7.1 // indirect @@ -154,8 +156,10 @@ require ( github.com/gostaticanalysis/nilerr v0.1.1 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c // indirect + github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-immutable-radix v1.3.1 // indirect github.com/hashicorp/go-metrics v0.5.3 // 
+	github.com/hashicorp/go-retryablehttp v0.7.4 // indirect
 	github.com/hashicorp/go-version v1.6.0 // indirect
 	github.com/hashicorp/golang-lru v1.0.2 // indirect
 	github.com/hashicorp/hcl v1.0.0 // indirect
@@ -173,6 +177,7 @@ require (
 	github.com/kisielk/gotool v1.0.0 // indirect
 	github.com/kkHAIKE/contextcheck v1.1.4 // indirect
 	github.com/klauspost/compress v1.17.9 // indirect
+	github.com/kr/fs v0.1.0 // indirect
 	github.com/kr/pretty v0.3.1 // indirect
 	github.com/kr/text v0.2.0 // indirect
 	github.com/kulti/thelper v0.6.3 // indirect
@@ -209,6 +214,7 @@ require (
 	github.com/opencontainers/image-spec v1.1.0-rc2 // indirect
 	github.com/petermattis/goid v0.0.0-20231207134359-e60b3f734c67 // indirect
 	github.com/pkg/errors v0.9.1 // indirect
+	github.com/pkg/sftp v1.13.6 // indirect
 	github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
 	github.com/polyfloyd/go-errorlint v1.4.8 // indirect
 	github.com/prometheus/client_golang v1.20.1 // indirect
@@ -284,9 +290,11 @@ require (
 	golang.org/x/exp/typeparams v0.0.0-20231219180239-dc181d75b848 // indirect
 	golang.org/x/mod v0.18.0 // indirect
 	golang.org/x/net v0.31.0 // indirect
+	golang.org/x/oauth2 v0.23.0 // indirect
 	golang.org/x/sys v0.28.0 // indirect
 	golang.org/x/term v0.26.0 // indirect
 	golang.org/x/text v0.20.0 // indirect
+	golang.org/x/time v0.5.0 // indirect
 	golang.org/x/tools v0.22.0 // indirect
 	google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
 	google.golang.org/genproto/googleapis/api v0.0.0-20241118233622-e639e219e697 // indirect
diff --git a/cosmos/go.sum b/cosmos/go.sum
index aade3e3..bcc9861 100644
--- a/cosmos/go.sum
+++ b/cosmos/go.sum
@@ -242,6 +242,8 @@ github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkz
 github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13 h1:fAjc9m62+UWV/WAFKLNi6ZS0675eEUC9y3AlwSbQu1Y=
 github.com/dgryski/go-farm v0.0.0-20200201041132-a6ae2369ad13/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
+github.com/digitalocean/godo v1.108.0 h1:fWyMENvtxpCpva1UbKzOFnyAS04N1FNuBWWfPeTGquQ=
+github.com/digitalocean/godo v1.108.0/go.mod h1:R6EmmWI8CT1+fCtjWY9UCB+L5uufuZH13wk3YhxycCs=
 github.com/distribution/reference v0.5.0 h1:/FUIFXtfc/x2gpa5/VGfiGLuOIdYa1t65IKK2OFGvA0=
 github.com/distribution/reference v0.5.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
 github.com/docker/docker v27.1.1+incompatible h1:hO/M4MtV36kzKldqnA37IWhebRA+LnqqcqDja6kVaKY=
@@ -425,6 +427,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
 github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
 github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
+github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
 github.com/google/gofuzz v0.0.0-20170612174753-24818f796faf/go.mod h1:HP5RmnzzSNb993RKQDq4+1A4ia9nllfqcQFTQJedwGI=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
@@ -478,6 +482,9 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.24.0/go.mod h1:qztMSjm835F2bXf+5HKA
 github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c h1:6rhixN/i8ZofjG1Y75iExal34USq5p+wiN1tpie8IrU=
 github.com/gsterjov/go-libsecret v0.0.0-20161001094733-a6f4afe4910c/go.mod h1:NMPJylDgVpX0MLRlPy15sqSwOFv/U1GZ2m21JhFfek0=
 github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80=
+github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
+github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
+github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ=
 github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c=
 github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
 github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60=
@@ -488,6 +495,8 @@ github.com/hashicorp/go-metrics v0.5.3/go.mod h1:KEjodfebIOuBYSAe/bHTm+HChmKSxAO
 github.com/hashicorp/go-plugin v1.5.2 h1:aWv8eimFqWlsEiMrYZdPYl+FdHaBJSN4AWwGWfT1G2Y=
 github.com/hashicorp/go-plugin v1.5.2/go.mod h1:w1sAEES3g3PuV/RzUrgow20W2uErMly84hhD3um1WL4=
 github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs=
+github.com/hashicorp/go-retryablehttp v0.7.4 h1:ZQgVdpTdAL7WpMIwLzCfbalOcSUdkDZnpUv3/+BxzFA=
+github.com/hashicorp/go-retryablehttp v0.7.4/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8=
 github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
 github.com/hashicorp/go-uuid v1.0.2 h1:cfejS+Tpcp13yd5nYHWDI6qVCny6wyX2Mt5SGur2IGE=
 github.com/hashicorp/go-uuid v1.0.2/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
@@ -556,6 +565,8 @@ github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2
 github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
 github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
 github.com/konsorten/go-windows-terminal-sequences v1.0.3/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
+github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8=
+github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg=
 github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
 github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
 github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -696,6 +707,8 @@ github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE
 github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
 github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
+github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo=
+github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U=
 github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@@ -1037,6 +1050,8 @@ golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4Iltr
 golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
+golang.org/x/oauth2 v0.23.0 h1:PbgcYx2W7i4LvjJWEbf0ngHV6qJYr86PkAV3bXdLEbs=
+golang.org/x/oauth2 v0.23.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
diff --git a/cosmos/node/init.go b/cosmos/node/init.go
index cd5d997..92675e2 100644
--- a/cosmos/node/init.go
+++ b/cosmos/node/init.go
@@ -12,6 +12,7 @@ func (n *Node) InitHome(ctx context.Context) error {
 	n.logger.Info("initializing home", zap.String("name", n.GetDefinition().Name))
 
 	stdout, stderr, exitCode, err := n.RunCommand(ctx, n.BinCommand([]string{"init", n.GetDefinition().Name, "--chain-id", n.GetChainConfig().ChainId}...))
+	n.logger.Debug("init home", zap.String("stdout", stdout), zap.String("stderr", stderr), zap.Int("exitCode", exitCode))
 	if err != nil {
diff --git a/cosmos/tests/e2e/digitalocean/do_test.go b/cosmos/tests/e2e/digitalocean/do_test.go
new file mode 100644
index 0000000..b2a2761
--- /dev/null
+++ b/cosmos/tests/e2e/digitalocean/do_test.go
@@ -0,0 +1,192 @@
+package e2e
+
+import (
+	"context"
+	"flag"
+	"os"
+	"testing"
+	"time"
+
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/provider/digitalocean"
+	"github.com/skip-mev/petri/core/v3/types"
+	cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain"
+	"github.com/skip-mev/petri/cosmos/v3/node"
+	"github.com/skip-mev/petri/cosmos/v3/tests/e2e"
+
+	"github.com/cosmos/cosmos-sdk/crypto/hd"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
+)
+
+var (
+	defaultChainConfig = types.ChainConfig{
+		Denom:         "stake",
+		Decimals:      6,
+		NumValidators: 1,
+		NumNodes:      1,
+		BinaryName:    "/usr/bin/simd",
+		Image: provider.ImageDefinition{
+			Image: "interchainio/simapp:latest",
+			UID:   "1000",
+			GID:   "1000",
+		},
+		GasPrices:    "0.0005stake",
+		Bech32Prefix: "cosmos",
+		HomeDir:      "/gaia",
+		CoinType:     "118",
+		ChainId:      "stake-1",
+	}
+
+	defaultChainOptions = types.ChainOptions{
+		NodeCreator: node.CreateNode,
+		WalletConfig: types.WalletConfig{
+			SigningAlgorithm: string(hd.Secp256k1.Name()),
+			Bech32Prefix:     "cosmos",
+			HDPath:           hd.CreateHDPath(118, 0, 0),
+			DerivationFn:     hd.Secp256k1.Derive(),
+			GenerationFn:     hd.Secp256k1.Generate(),
+		},
+		NodeOptions: types.NodeOptions{
+			NodeDefinitionModifier: func(def provider.TaskDefinition, nodeConfig types.NodeConfig) provider.TaskDefinition {
+				doConfig := digitalocean.DigitalOceanTaskConfig{
+					"size":     "s-2vcpu-4gb",
+					"region":   "ams3",
+					"image_id": os.Getenv("DO_IMAGE_ID"),
+				}
+				def.ProviderSpecificConfig = doConfig
+				return def
+			},
+		},
+	}
+
+	numTestChains = flag.Int("num-chains", 3, "number of chains to create for concurrent testing")
+	numNodes      = flag.Int("num-nodes", 1, "number of nodes per chain")
+	numValidators = flag.Int("num-validators", 1, "number of validators per chain")
+)
+
+func TestDOE2E(t *testing.T) {
+	if !flag.Parsed() {
+		flag.Parse()
+	}
+
+	ctx := context.Background()
+	logger, _ := zap.NewDevelopment()
+
+	doToken := os.Getenv("DO_API_TOKEN")
+	if doToken == "" {
+		t.Fatal("DO_API_TOKEN environment variable not set")
+	}
+
+	imageID := os.Getenv("DO_IMAGE_ID")
+	if imageID == "" {
+		t.Fatal("DO_IMAGE_ID environment variable not set")
+	}
+
+	externalIP, err := e2e.GetExternalIP()
+	require.NoError(t, err)
+	logger.Info("External IP", zap.String("address", externalIP))
+
+	p, err := digitalocean.NewProvider(ctx, logger, "digitalocean_provider", doToken, []string{externalIP}, nil)
+	require.NoError(t, err)
+
+	chains := make([]*cosmoschain.Chain, *numTestChains)
+
+	// Create first half of chains
+	defaultChainConfig.NumNodes = *numNodes
+	defaultChainConfig.NumValidators = *numValidators
+	e2e.CreateChainsConcurrently(ctx, t, logger, p, 0, *numTestChains/2, chains, defaultChainConfig, defaultChainOptions)
+
+	// Restore provider before creating second half of chains
+	serializedProvider, err := p.SerializeProvider(ctx)
+	require.NoError(t, err)
+	restoredProvider, err := digitalocean.RestoreProvider(ctx, doToken, serializedProvider, nil, nil)
+	require.NoError(t, err)
+
+	// Restore the existing chains with the restored provider
+	restoredChains := make([]*cosmoschain.Chain, *numTestChains)
+	for i := 0; i < *numTestChains/2; i++ {
+		chainState, err := chains[i].Serialize(ctx, restoredProvider)
+		require.NoError(t, err)
+
+		restoredChain, err := cosmoschain.RestoreChain(ctx, logger, restoredProvider, chainState, node.RestoreNode)
+		require.NoError(t, err)
+
+		require.Equal(t, chains[i].GetConfig(), restoredChain.GetConfig())
+		require.Equal(t, len(chains[i].GetValidators()), len(restoredChain.GetValidators()))
+
+		restoredChains[i] = restoredChain
+	}
+
+	// Create second half of chains with restored provider
+	e2e.CreateChainsConcurrently(ctx, t, logger, restoredProvider, *numTestChains/2, *numTestChains, restoredChains, defaultChainConfig, defaultChainOptions)
+
+	// Test and teardown half the chains individually
+	for i := 0; i < *numTestChains/2; i++ {
+		originalChain := restoredChains[i]
+		validators := originalChain.GetValidators()
+		nodes := originalChain.GetNodes()
+
+		for _, validator := range validators {
+			e2e.AssertNodeRunning(t, ctx, validator)
+		}
+
+		for _, node := range nodes {
+			e2e.AssertNodeRunning(t, ctx, node)
+		}
+
+		err = originalChain.WaitForBlocks(ctx, 2)
+		require.NoError(t, err)
+
+		// Test individual chain teardown
+		err = originalChain.Teardown(ctx)
+		require.NoError(t, err)
+
+		// wait for status to update on DO client side
+		time.Sleep(15 * time.Second)
+
+		for _, validator := range validators {
+			e2e.AssertNodeShutdown(t, ctx, validator)
+		}
+
+		for _, node := range nodes {
+			e2e.AssertNodeShutdown(t, ctx, node)
+		}
+	}
+
+	// Test the remaining chains but let the provider teardown handle their cleanup
+	remainingChains := make([]*cosmoschain.Chain, 0)
+	for i := *numTestChains / 2; i < *numTestChains; i++ {
+		originalChain := restoredChains[i]
+		remainingChains = append(remainingChains, originalChain)
+		validators := originalChain.GetValidators()
+		nodes := originalChain.GetNodes()
+		for _, validator := range validators {
+			e2e.AssertNodeRunning(t, ctx, validator)
+		}
+		for _, node := range nodes {
+			e2e.AssertNodeRunning(t, ctx, node)
+		}
+
+		err = originalChain.WaitForBlocks(ctx, 2)
+		require.NoError(t, err)
+	}
+
+	require.NoError(t, restoredProvider.Teardown(ctx))
+	// wait for status to update on DO client side
+	time.Sleep(15 * time.Second)
+
+	// Verify all remaining chains are properly torn down
+	for _, chain := range remainingChains {
+		validators := chain.GetValidators()
+		nodes := chain.GetNodes()
+
+		for _, validator := range validators {
+			e2e.AssertNodeShutdown(t, ctx, validator)
+		}
+
+		for _, node := range nodes {
+			e2e.AssertNodeShutdown(t, ctx, node)
+		}
+	}
+}
diff --git a/cosmos/tests/e2e/docker/docker_test.go b/cosmos/tests/e2e/docker/docker_test.go
new file mode 100644
index 0000000..4e58bc0
--- /dev/null
+++ b/cosmos/tests/e2e/docker/docker_test.go
@@ -0,0 +1,174 @@
+package e2e
+
+import (
+	"context"
+	"flag"
+	"testing"
+
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/provider/docker"
+	"github.com/skip-mev/petri/core/v3/types"
+	cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain"
+	"github.com/skip-mev/petri/cosmos/v3/node"
+	"github.com/skip-mev/petri/cosmos/v3/tests/e2e"
+
+	"github.com/cosmos/cosmos-sdk/crypto/hd"
+	"github.com/docker/docker/api/types/filters"
+	"github.com/docker/docker/client"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
+)
+
+var (
+	defaultChainConfig = types.ChainConfig{
+		Denom:         "stake",
+		Decimals:      6,
+		NumValidators: 1,
+		NumNodes:      1,
+		BinaryName:    "/usr/bin/simd",
+		Image: provider.ImageDefinition{
+			Image: "interchainio/simapp:latest",
+			UID:   "1000",
+			GID:   "1000",
+		},
+		GasPrices:            "0.0005stake",
+		Bech32Prefix:         "cosmos",
+		HomeDir:              "/gaia",
+		CoinType:             "118",
+		ChainId:              "stake-1",
+		UseGenesisSubCommand: false,
+	}
+
+	defaultChainOptions = types.ChainOptions{
+		NodeCreator: node.CreateNode,
+		WalletConfig: types.WalletConfig{
+			SigningAlgorithm: string(hd.Secp256k1.Name()),
+			Bech32Prefix:     "cosmos",
+			HDPath:           hd.CreateHDPath(118, 0, 0),
+			DerivationFn:     hd.Secp256k1.Derive(),
+			GenerationFn:     hd.Secp256k1.Generate(),
+		},
+	}
+
+	numTestChains = flag.Int("num-chains", 3, "number of chains to create for concurrent testing")
+	numNodes      = flag.Int("num-nodes", 1, "number of nodes per chain")
+	numValidators = flag.Int("num-validators", 1, "number of validators per chain")
+)
+
+func TestDockerE2E(t *testing.T) {
+	if !flag.Parsed() {
+		flag.Parse()
+	}
+
+	ctx := context.Background()
+	logger, _ := zap.NewDevelopment()
+
+	defer func() {
+		dockerClient, err := client.NewClientWithOpts()
+		if err != nil {
+			t.Logf("Failed to create Docker client for volume cleanup: %v", err)
+			return
+		}
+		_, err = dockerClient.VolumesPrune(ctx, filters.Args{})
+		if err != nil {
+			t.Logf("Failed to prune volumes: %v", err)
+		}
+	}()
+
+	p, err := docker.CreateProvider(ctx, logger, "docker_provider")
+	require.NoError(t, err)
+
+	chains := make([]*cosmoschain.Chain, *numTestChains)
+
+	// Create first half of chains
+	defaultChainConfig.NumNodes = *numNodes
+	defaultChainConfig.NumValidators = *numValidators
+	e2e.CreateChainsConcurrently(ctx, t, logger, p, 0, *numTestChains/2, chains, defaultChainConfig, defaultChainOptions)
+
+	// Restore provider before creating second half of chains
+	serializedProvider, err := p.SerializeProvider(ctx)
+	require.NoError(t, err)
+	restoredProvider, err := docker.RestoreProvider(ctx, logger, serializedProvider)
+	require.NoError(t, err)
+
+	// Restore the existing chains with the restored provider
+	restoredChains := make([]*cosmoschain.Chain, *numTestChains)
+	for i := 0; i < *numTestChains/2; i++ {
+		chainState, err := chains[i].Serialize(ctx, restoredProvider)
+		require.NoError(t, err)
+
+		restoredChain, err := cosmoschain.RestoreChain(ctx, logger, restoredProvider, chainState, node.RestoreNode)
+		require.NoError(t, err)
+
+		require.Equal(t, chains[i].GetConfig(), restoredChain.GetConfig())
+		require.Equal(t, len(chains[i].GetValidators()), len(restoredChain.GetValidators()))
+
+		restoredChains[i] = restoredChain
+	}
+
+	// Create second half of chains with restored provider
+	e2e.CreateChainsConcurrently(ctx, t, logger, restoredProvider, *numTestChains/2, *numTestChains, restoredChains, defaultChainConfig, defaultChainOptions)
+
+	// Test and teardown half the chains individually
+	for i := 0; i < *numTestChains/2; i++ {
+		originalChain := restoredChains[i]
+		validators := originalChain.GetValidators()
+		nodes := originalChain.GetNodes()
+
+		for _, validator := range validators {
+			e2e.AssertNodeRunning(t, ctx, validator)
+		}
+
+		for _, node := range nodes {
+			e2e.AssertNodeRunning(t, ctx, node)
+		}
+
+		err = originalChain.WaitForBlocks(ctx, 2)
+		require.NoError(t, err)
+
+		// Test individual chain teardown
+		err = originalChain.Teardown(ctx)
+		require.NoError(t, err)
+
+		for _, validator := range validators {
+			e2e.AssertNodeShutdown(t, ctx, validator)
+		}
+
+		for _, node := range nodes {
+			e2e.AssertNodeShutdown(t, ctx, node)
+		}
+	}
+
+	// Test the remaining chains but let the provider teardown handle their cleanup
+	remainingChains := make([]*cosmoschain.Chain, 0)
+	for i := *numTestChains / 2; i < *numTestChains; i++ {
+		originalChain := restoredChains[i]
+		remainingChains = append(remainingChains, originalChain)
+		validators := originalChain.GetValidators()
+		nodes := originalChain.GetNodes()
+		for _, validator := range validators {
+			e2e.AssertNodeRunning(t, ctx, validator)
+		}
+		for _, node := range nodes {
+			e2e.AssertNodeRunning(t, ctx, node)
+		}
+
+		err = originalChain.WaitForBlocks(ctx, 2)
+		require.NoError(t, err)
+	}
+
+	require.NoError(t, restoredProvider.Teardown(ctx))
+	// Verify all remaining chains are properly torn down
+	for _, chain := range remainingChains {
+		validators := chain.GetValidators()
+		nodes := chain.GetNodes()
+
+		for _, validator := range validators {
+			e2e.AssertNodeShutdown(t, ctx, validator)
+		}
+
+		for _, node := range nodes {
+			e2e.AssertNodeShutdown(t, ctx, node)
+		}
+	}
+}
diff --git a/cosmos/tests/e2e/utils.go b/cosmos/tests/e2e/utils.go
new file mode 100644
index 0000000..9f392c9
--- /dev/null
+++ b/cosmos/tests/e2e/utils.go
@@ -0,0 +1,102 @@
+package e2e
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"sync"
+	"testing"
+
+	"github.com/skip-mev/petri/core/v3/provider"
+	"github.com/skip-mev/petri/core/v3/types"
+	cosmoschain "github.com/skip-mev/petri/cosmos/v3/chain"
+	"github.com/stretchr/testify/require"
+	"go.uber.org/zap"
+)
+
+func AssertNodeRunning(t *testing.T, ctx context.Context, node types.NodeI) {
+	status, err := node.GetStatus(ctx)
+	require.NoError(t, err)
+	require.Equal(t, provider.TASK_RUNNING, status)
+
+	ip, err := node.GetIP(ctx)
+	require.NoError(t, err)
+	require.NotEmpty(t, ip)
+
+	testFile := "test.txt"
+	testContent := []byte("test content")
+	err = node.WriteFile(ctx, testFile, testContent)
+	require.NoError(t, err)
+
+	readContent, err := node.ReadFile(ctx, testFile)
+	require.NoError(t, err)
+	require.Equal(t, testContent, readContent)
+}
+
+func AssertNodeShutdown(t *testing.T, ctx context.Context, node types.NodeI) {
+	status, err := node.GetStatus(ctx)
+	require.Error(t, err)
+	require.Equal(t, provider.TASK_STATUS_UNDEFINED, status, "node status should report as undefined after shutdown")
+
+	_, err = node.GetIP(ctx)
+	require.Error(t, err, "node IP should not be accessible after teardown")
+}
+
+func GetExternalIP() (string, error) {
+	resp, err := http.Get("https://ifconfig.me")
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	ip, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return "", err
+	}
+	return strings.TrimSpace(string(ip)), nil
+}
+
+// CreateChainsConcurrently creates multiple chains concurrently using the provided configuration
+func CreateChainsConcurrently(
+	ctx context.Context,
+	t *testing.T,
+	logger *zap.Logger,
+	p provider.ProviderI,
+	startIndex, endIndex int,
+	chains []*cosmoschain.Chain,
+	chainConfig types.ChainConfig,
+	chainOptions types.ChainOptions,
+) {
+	var wg sync.WaitGroup
+	chainErrors := make(chan error, endIndex-startIndex)
+
+	for i := startIndex; i < endIndex; i++ {
+		wg.Add(1)
+		go func(index int) {
+			defer wg.Done()
+			config := chainConfig
+			config.ChainId = fmt.Sprintf("chain-%d", index)
+			c, err := cosmoschain.CreateChain(ctx, logger, p, config, chainOptions)
+			if err != nil {
+				t.Logf("Chain creation error: %v", err)
+				chainErrors <- fmt.Errorf("failed to create chain %d: %w", index, err)
+				return
+			}
+			if err := c.Init(ctx, chainOptions); err != nil {
+				t.Logf("Chain init error: %v", err)
+				chainErrors <- fmt.Errorf("failed to init chain %d: %w", index, err)
+				return
+			}
+			chains[index] = c
+		}(i)
+	}
+	wg.Wait()
+	close(chainErrors)
+
+	// Surface any chain creation/init errors with their details
+	for err := range chainErrors {
+		require.NoError(t, err)
+	}
+}
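
Editorial note, not part of the patch: a minimal sketch of how a unit test might drive the generated DockerClient mock whose constructor appears above. The package path "core/v3/provider/mocks" and the mocked Ping expectation are assumptions (adjust to the repository's actual mock location and whichever interface methods the test needs); the testify On/Return mechanics shown here are standard for mockery-generated mocks.

package mocks_test

import (
	"context"
	"testing"

	"github.com/docker/docker/api/types"
	"github.com/stretchr/testify/mock"
	"github.com/stretchr/testify/require"

	// Assumed import path for the generated mock package.
	"github.com/skip-mev/petri/core/v3/provider/mocks"
)

func TestDockerClientMockPing(t *testing.T) {
	// NewDockerClient wires the mock to t and registers AssertExpectations
	// via t.Cleanup, so any unmet expectation fails the test automatically.
	client := mocks.NewDockerClient(t)

	// Expect exactly one Ping call with any context.
	client.On("Ping", mock.Anything).Return(types.Ping{APIVersion: "1.44"}, nil).Once()

	ping, err := client.Ping(context.Background())
	require.NoError(t, err)
	require.Equal(t, "1.44", ping.APIVersion)
}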