Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(docker/network): harden docker networking #163

Merged
merged 8 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions core/provider/docker/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ func (p *Provider) openListenerOnFreePort() (*net.TCPListener, error) {
return nil, err
}

p.networkMu.Lock()
defer p.networkMu.Unlock()
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return nil, err
Expand All @@ -110,6 +108,10 @@ func (p *Provider) openListenerOnFreePort() (*net.TCPListener, error) {
// This allows multiple nextAvailablePort calls to find multiple available ports
// before closing them so they are available for the PortBinding.
func (p *Provider) nextAvailablePort() (nat.PortBinding, *net.TCPListener, error) {
// TODO: add listeners to state
p.networkMu.Lock()
defer p.networkMu.Unlock()

l, err := p.openListenerOnFreePort()
if err != nil {
if l != nil {
Expand Down
37 changes: 24 additions & 13 deletions core/provider/docker/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ type ProviderState struct {

Name string `json:"name"`

NetworkID string `json:"network_id"`
NetworkName string `json:"network_name"`
NetworkCIDR string `json:"network_cidr"`
AllocatedIPs []string `json:"allocated_ips"`
NetworkID string `json:"network_id"`
NetworkName string `json:"network_name"`
NetworkCIDR string `json:"network_cidr"`
NetworkGateway string `json:"network_gateway"`

BuilderImageName string `json:"builder_image_name"`
}
Expand Down Expand Up @@ -86,17 +86,22 @@ func CreateProvider(ctx context.Context, logger *zap.Logger, providerName string
}

dockerProvider.state.NetworkCIDR = cidrMask.String()
dockerProvider.state.NetworkGateway = network.IPAM.Config[0].Gateway

dockerProvider.dockerNetworkAllocator, err = ipallocator.NewCIDRRange(cidrMask)

if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(network.IPAM.Config[0].Gateway)); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have a guarantee that network.IPAM.Config[0] exists, i.e. that the array is not nil?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

return nil, fmt.Errorf("failed to allocate gateway ip: %w", err)
}

if err != nil {
return nil, err
}

return dockerProvider, nil
}

func RestoreProvider(ctx context.Context, state []byte) (*Provider, error) {
func RestoreProvider(ctx context.Context, logger *zap.Logger, state []byte) (*Provider, error) {
var providerState ProviderState

err := json.Unmarshal(state, &providerState)
Expand All @@ -106,7 +111,8 @@ func RestoreProvider(ctx context.Context, state []byte) (*Provider, error) {
}

dockerProvider := &Provider{
state: &providerState,
state: &providerState,
logger: logger,
}

dockerClient, err := client.NewClientWithOpts()
Expand All @@ -133,8 +139,12 @@ func RestoreProvider(ctx context.Context, state []byte) (*Provider, error) {
return nil, fmt.Errorf("failed to create ip allocator from state: %w", err)
}

for _, ip := range providerState.AllocatedIPs {
if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(ip)); err != nil {
if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(providerState.NetworkGateway)); err != nil {
return nil, fmt.Errorf("failed to allocate gateway ip: %w", err)
}

for _, task := range providerState.TaskStates {
if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(task.IpAddress)); err != nil {
return nil, fmt.Errorf("failed to restore ip allocator state: %w", err)
}
}
Expand Down Expand Up @@ -223,12 +233,12 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
Labels: map[string]string{
providerLabelName: p.state.Name,
},
Env: convertEnvMapToList(definition.Environment),
Env: convertEnvMapToList(definition.Environment),
ExposedPorts: portSet,
}, &container.HostConfig{
Mounts: mounts,
PortBindings: portBindings,
PublishAllPorts: true,
NetworkMode: container.NetworkMode(p.state.NetworkName),
Mounts: mounts,
PortBindings: portBindings,
NetworkMode: container.NetworkMode(p.state.NetworkName),
}, &network.NetworkingConfig{
EndpointsConfig: map[string]*network.EndpointSettings{
p.state.NetworkName: {
Expand All @@ -245,6 +255,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin

taskState.Id = createdContainer.ID
taskState.Status = provider.TASK_STOPPED
taskState.IpAddress = ip

p.stateMu.Lock()
defer p.stateMu.Unlock()
Expand Down
4 changes: 2 additions & 2 deletions core/provider/docker/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestRestoreProvider(t *testing.T) {
serialized, err := p1.SerializeProvider(ctx)
require.NoError(t, err)

p2, err := docker.RestoreProvider(ctx, serialized)
p2, err := docker.RestoreProvider(ctx, logger, serialized)
require.NoError(t, err)

state2 := p2.GetState()
Expand Down Expand Up @@ -265,7 +265,7 @@ func TestProviderSerialization(t *testing.T) {
serialized, err := p1.SerializeProvider(ctx)
require.NoError(t, err)

p2, err := docker.RestoreProvider(ctx, serialized)
p2, err := docker.RestoreProvider(ctx, logger, serialized)
require.NoError(t, err)

state2 := p2.GetState()
Expand Down
8 changes: 1 addition & 7 deletions core/provider/docker/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,13 @@ func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, err
return "", fmt.Errorf("failed to inspect container: %w", err)
}

ip, err := t.GetIP(ctx)

if err != nil {
return "", fmt.Errorf("failed to get IP: %w", err)
}

portBindings, ok := dockerContainer.NetworkSettings.Ports[nat.Port(fmt.Sprintf("%s/tcp", port))]

if !ok || len(portBindings) == 0 {
return "", fmt.Errorf("port %s not found", port)
}

return fmt.Sprintf("%s:%s", ip, portBindings[0].HostPort), nil
return fmt.Sprintf("0.0.0.0:%s", portBindings[0].HostPort), nil
}

func (t *Task) GetIP(ctx context.Context) (string, error) {
Expand Down
3 changes: 2 additions & 1 deletion core/provider/docker/util.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package docker

import (
"fmt"
"github.com/docker/go-connections/nat"

"github.com/skip-mev/petri/core/v2/provider"
Expand All @@ -10,7 +11,7 @@ func convertTaskDefinitionPortsToPortSet(definition provider.TaskDefinition) nat
bindings := nat.PortSet{}

for _, port := range definition.Ports {
bindings[nat.Port(port)] = struct{}{}
bindings[nat.Port(fmt.Sprintf("%s/tcp", port))] = struct{}{}
}

return bindings
Expand Down
Loading