Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(docker/network): harden docker networking #163

Open
wants to merge 7 commits into
base: release/v3.x.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions core/provider/docker/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ func (p *Provider) openListenerOnFreePort() (*net.TCPListener, error) {
return nil, err
}

p.networkMu.Lock()
defer p.networkMu.Unlock()
l, err := net.ListenTCP("tcp", addr)
if err != nil {
return nil, err
Expand All @@ -110,6 +108,10 @@ func (p *Provider) openListenerOnFreePort() (*net.TCPListener, error) {
// This allows multiple nextAvailablePort calls to find multiple available ports
// before closing them so they are available for the PortBinding.
func (p *Provider) nextAvailablePort() (nat.PortBinding, *net.TCPListener, error) {
// TODO: add listeners to state
p.networkMu.Lock()
defer p.networkMu.Unlock()

l, err := p.openListenerOnFreePort()
if err != nil {
if l != nil {
Expand Down
37 changes: 24 additions & 13 deletions core/provider/docker/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ type ProviderState struct {

Name string `json:"name"`

NetworkID string `json:"network_id"`
NetworkName string `json:"network_name"`
NetworkCIDR string `json:"network_cidr"`
AllocatedIPs []string `json:"allocated_ips"`
NetworkID string `json:"network_id"`
NetworkName string `json:"network_name"`
NetworkCIDR string `json:"network_cidr"`
NetworkGateway string `json:"network_gateway"`

BuilderImageName string `json:"builder_image_name"`
}
Expand Down Expand Up @@ -86,17 +86,22 @@ func CreateProvider(ctx context.Context, logger *zap.Logger, providerName string
}

dockerProvider.state.NetworkCIDR = cidrMask.String()
dockerProvider.state.NetworkGateway = network.IPAM.Config[0].Gateway

dockerProvider.dockerNetworkAllocator, err = ipallocator.NewCIDRRange(cidrMask)

if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(network.IPAM.Config[0].Gateway)); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have a guarantee that network.IPAM.Config[0] exists, i.e. that the array is not nil?

return nil, fmt.Errorf("failed to allocate gateway ip: %w", err)
}

if err != nil {
return nil, err
}

return dockerProvider, nil
}

func RestoreProvider(ctx context.Context, state []byte) (*Provider, error) {
func RestoreProvider(ctx context.Context, logger *zap.Logger, state []byte) (*Provider, error) {
var providerState ProviderState

err := json.Unmarshal(state, &providerState)
Expand All @@ -106,7 +111,8 @@ func RestoreProvider(ctx context.Context, state []byte) (*Provider, error) {
}

dockerProvider := &Provider{
state: &providerState,
state: &providerState,
logger: logger,
}

dockerClient, err := client.NewClientWithOpts()
Expand All @@ -133,8 +139,12 @@ func RestoreProvider(ctx context.Context, state []byte) (*Provider, error) {
return nil, fmt.Errorf("failed to create ip allocator from state: %w", err)
}

for _, ip := range providerState.AllocatedIPs {
if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(ip)); err != nil {
if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(providerState.NetworkGateway)); err != nil {
return nil, fmt.Errorf("failed to allocate gateway ip: %w", err)
}

for _, task := range providerState.TaskStates {
if err := dockerProvider.dockerNetworkAllocator.Allocate(net.ParseIP(task.IpAddress)); err != nil {
return nil, fmt.Errorf("failed to restore ip allocator state: %w", err)
}
}
Expand Down Expand Up @@ -223,12 +233,12 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
Labels: map[string]string{
providerLabelName: p.state.Name,
},
Env: convertEnvMapToList(definition.Environment),
Env: convertEnvMapToList(definition.Environment),
ExposedPorts: portSet,
}, &container.HostConfig{
Mounts: mounts,
PortBindings: portBindings,
PublishAllPorts: true,
NetworkMode: container.NetworkMode(p.state.NetworkName),
Mounts: mounts,
PortBindings: portBindings,
NetworkMode: container.NetworkMode(p.state.NetworkName),
}, &network.NetworkingConfig{
EndpointsConfig: map[string]*network.EndpointSettings{
p.state.NetworkName: {
Expand All @@ -245,6 +255,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin

taskState.Id = createdContainer.ID
taskState.Status = provider.TASK_STOPPED
taskState.IpAddress = ip

p.stateMu.Lock()
defer p.stateMu.Unlock()
Expand Down
4 changes: 2 additions & 2 deletions core/provider/docker/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func TestRestoreProvider(t *testing.T) {
serialized, err := p1.SerializeProvider(ctx)
require.NoError(t, err)

p2, err := docker.RestoreProvider(ctx, serialized)
p2, err := docker.RestoreProvider(ctx, logger, serialized)
require.NoError(t, err)

state2 := p2.GetState()
Expand Down Expand Up @@ -265,7 +265,7 @@ func TestProviderSerialization(t *testing.T) {
serialized, err := p1.SerializeProvider(ctx)
require.NoError(t, err)

p2, err := docker.RestoreProvider(ctx, serialized)
p2, err := docker.RestoreProvider(ctx, logger, serialized)
require.NoError(t, err)

state2 := p2.GetState()
Expand Down
10 changes: 2 additions & 8 deletions core/provider/docker/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type TaskState struct {
Volume *VolumeState `json:"volumes"`
Definition provider.TaskDefinition `json:"definition"`
Status provider.TaskStatus `json:"status"`
IpAddress string `json:"ip_address"`
IpAddress string `json:"ip_address"`
}

type VolumeState struct {
Expand Down Expand Up @@ -106,19 +106,13 @@ func (t *Task) GetExternalAddress(ctx context.Context, port string) (string, err
return "", fmt.Errorf("failed to inspect container: %w", err)
}

ip, err := t.GetIP(ctx)

if err != nil {
return "", fmt.Errorf("failed to get IP: %w", err)
}

portBindings, ok := dockerContainer.NetworkSettings.Ports[nat.Port(fmt.Sprintf("%s/tcp", port))]

if !ok || len(portBindings) == 0 {
return "", fmt.Errorf("port %s not found", port)
}

return fmt.Sprintf("%s:%s", ip, portBindings[0].HostPort), nil
return fmt.Sprintf("0.0.0.0:%s", portBindings[0].HostPort), nil
}

func (t *Task) GetIP(ctx context.Context) (string, error) {
Expand Down
3 changes: 2 additions & 1 deletion core/provider/docker/util.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package docker

import (
"fmt"
"github.com/docker/go-connections/nat"

"github.com/skip-mev/petri/core/v2/provider"
Expand All @@ -10,7 +11,7 @@ func convertTaskDefinitionPortsToPortSet(definition provider.TaskDefinition) nat
bindings := nat.PortSet{}

for _, port := range definition.Ports {
bindings[nat.Port(port)] = struct{}{}
bindings[nat.Port(fmt.Sprintf("%s/tcp", port))] = struct{}{}
}

return bindings
Expand Down
34 changes: 23 additions & 11 deletions core/types/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type GenesisModifier func([]byte) ([]byte, error)

// ChainI is an interface for a logical chain
type ChainI interface {
Init(context.Context) error
Init(context.Context, ChainOptions) error
Teardown(context.Context) error

GetConfig() ChainConfig
Expand All @@ -32,6 +32,28 @@ type ChainI interface {
Height(context.Context) (uint64, error)
WaitForBlocks(ctx context.Context, delta uint64) error
WaitForHeight(ctx context.Context, desiredHeight uint64) error

Serialize(ctx context.Context, p provider.ProviderI) ([]byte, error)
}

type ChainOptions struct {
ModifyGenesis GenesisModifier // ModifyGenesis is a function that modifies the genesis bytes of the chain
NodeOptions NodeOptions // NodeOptions is the options for creating a node
NodeCreator NodeCreator // NodeCreator is a function that creates a node

WalletConfig WalletConfig // WalletConfig is the default configuration of a chain's wallet
}

func (o ChainOptions) ValidateBasic() error {
if err := o.WalletConfig.ValidateBasic(); err != nil {
return fmt.Errorf("wallet config is invalid: %w", err)
}

if o.NodeCreator == nil {
return fmt.Errorf("node creator cannot be nil")
}

return nil
}

// ChainConfig is the configuration structure for a logical chain.
Expand All @@ -55,14 +77,8 @@ type ChainConfig struct {
CoinType string // CoinType is the coin type of the chain (e.g. 118)
ChainId string // ChainId is the chain ID of the chain

ModifyGenesis GenesisModifier // ModifyGenesis is a function that modifies the genesis bytes of the chain

WalletConfig WalletConfig // WalletConfig is the default configuration of a chain's wallet

UseGenesisSubCommand bool // UseGenesisSubCommand is a flag that indicates whether to use the 'genesis' subcommand to initialize the chain. Set to true if Cosmos SDK >v0.50

NodeCreator NodeCreator // NodeCreator is a function that creates a node
NodeDefinitionModifier NodeDefinitionModifier // NodeDefinitionModifier is a function that modifies a node's definition
// number of tokens to allocate per account in the genesis state (unscaled). This value defaults to 10_000_000 if not set.
// if not set.
GenesisDelegation *big.Int
Expand Down Expand Up @@ -121,9 +137,5 @@ func (c *ChainConfig) ValidateBasic() error {
return fmt.Errorf("chain ID cannot be empty")
}

if c.NodeCreator == nil {
return fmt.Errorf("node creator cannot be nil")
}

return nil
}
13 changes: 12 additions & 1 deletion core/types/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
"github.com/skip-mev/petri/core/v2/provider"
)

// NodeOptions is a struct that contains the options for creating a node
type NodeOptions struct {
NodeDefinitionModifier NodeDefinitionModifier // NodeDefinitionModifier is a function that modifies a node's definition
}

// NodeConfig is the configuration structure for a logical node.
type NodeConfig struct {
Name string // Name is the name of the node
Expand Down Expand Up @@ -40,7 +45,10 @@ func (c NodeConfig) ValidateBasic() error {
type NodeDefinitionModifier func(provider.TaskDefinition, NodeConfig) provider.TaskDefinition

// NodeCreator is a type of function that given a NodeConfig creates a new logical node
type NodeCreator func(context.Context, *zap.Logger, provider.ProviderI, NodeConfig) (NodeI, error)
type NodeCreator func(context.Context, *zap.Logger, provider.ProviderI, NodeConfig, NodeOptions) (NodeI, error)

// NodeRestorer is a type of function that given a NodeState restores a logical node
type NodeRestorer func(context.Context, *zap.Logger, []byte, provider.ProviderI) (NodeI, error)

// NodeI represents an interface for a logical node that is running on a chain
type NodeI interface {
Expand Down Expand Up @@ -95,4 +103,7 @@ type NodeI interface {

// GetIP returns the IP address of the node
GetIP(context.Context) (string, error)

// Serialize serializes the node
Serialize(context.Context, provider.ProviderI) ([]byte, error)
}
29 changes: 28 additions & 1 deletion core/types/wallet.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package types

import "github.com/cosmos/cosmos-sdk/crypto/hd"
import (
"fmt"
"github.com/cosmos/cosmos-sdk/crypto/hd"
)

// WalletConfig is a configuration for a Cosmos SDK type wallet
type WalletConfig struct {
Expand All @@ -10,3 +13,27 @@ type WalletConfig struct {
HDPath *hd.BIP44Params // HDPath is the default HD path to use for deriving keys
SigningAlgorithm string // SigningAlgorithm is the default signing algorithm to use
}

func (c WalletConfig) ValidateBasic() error {
if c.DerivationFn == nil {
return fmt.Errorf("derivation function cannot be nil")
}

if c.GenerationFn == nil {
return fmt.Errorf("generation function cannot be nil")
}

if c.Bech32Prefix == "" {
return fmt.Errorf("bech32 prefix cannot be empty")
}

if c.HDPath == nil {
return fmt.Errorf("HD path cannot be nil")
}

if c.SigningAlgorithm == "" {
return fmt.Errorf("signing algorithm cannot be empty")
}

return nil
}
Loading
Loading