Skip to content

Commit

Permalink
move serializable fields to state
Browse files Browse the repository at this point in the history
  • Loading branch information
nadim-az committed Jan 16, 2025
1 parent e34e2a2 commit cb1fc71
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 36 deletions.
6 changes: 3 additions & 3 deletions core/provider/digitalocean/droplet.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,18 @@ func (p *Provider) CreateDroplet(ctx context.Context, definition provider.TaskDe
}

req := &godo.DropletCreateRequest{
Name: fmt.Sprintf("%s-%s", p.petriTag, definition.Name),
Name: fmt.Sprintf("%s-%s", p.state.petriTag, definition.Name),
Region: doConfig["region"],
Size: doConfig["size"],
Image: godo.DropletCreateImage{
ID: int(imageId),
},
SSHKeys: []godo.DropletCreateSSHKey{
{
Fingerprint: p.sshKeyPair.Fingerprint,
Fingerprint: p.state.sshKeyPair.Fingerprint,
},
},
Tags: []string{p.petriTag},
Tags: []string{p.state.petriTag},
}

droplet, res, err := p.doClient.CreateDroplet(ctx, req)
Expand Down
8 changes: 4 additions & 4 deletions core/provider/digitalocean/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (

func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*godo.Firewall, error) {
req := &godo.FirewallRequest{
Name: fmt.Sprintf("%s-firewall", p.petriTag),
Tags: []string{p.petriTag},
Name: fmt.Sprintf("%s-firewall", p.state.petriTag),
Tags: []string{p.state.petriTag},
OutboundRules: []godo.OutboundRule{
{
Protocol: "tcp",
Expand Down Expand Up @@ -39,15 +39,15 @@ func (p *Provider) createFirewall(ctx context.Context, allowedIPs []string) (*go
Protocol: "tcp",
PortRange: "0",
Sources: &godo.Sources{
Tags: []string{p.petriTag},
Tags: []string{p.state.petriTag},
Addresses: allowedIPs,
},
},
{
Protocol: "udp",
PortRange: "0",
Sources: &godo.Sources{
Tags: []string{p.petriTag},
Tags: []string{p.state.petriTag},
Addresses: allowedIPs,
},
},
Expand Down
39 changes: 20 additions & 19 deletions core/provider/digitalocean/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,18 @@ const (
type ProviderState struct {
TaskStates map[int]*TaskState `json:"task_states"` // map of task ids to the corresponding task state
Name string `json:"name"`
petriTag string
userIPs []string
sshKeyPair *SSHKeyPair
firewallID string
}

type Provider struct {
state *ProviderState
stateMu sync.Mutex

logger *zap.Logger
name string
doClient DoClient
petriTag string
userIPs []string
sshKeyPair *SSHKeyPair
firewallID string
dockerClients map[string]DockerClient // map of droplet ip address to docker clients
}

Expand Down Expand Up @@ -73,16 +72,18 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName

digitalOceanProvider := &Provider{
logger: logger.Named("digitalocean_provider"),
name: providerName,
doClient: doClient,
petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)),
userIPs: userIPs,
sshKeyPair: sshKeyPair,
dockerClients: dockerClients,
state: &ProviderState{TaskStates: make(map[int]*TaskState)},
state: &ProviderState{
TaskStates: make(map[int]*TaskState),
userIPs: userIPs,
Name: providerName,
sshKeyPair: sshKeyPair,
petriTag: fmt.Sprintf("petri-droplet-%s", util.RandomString(5)),
},
}

_, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.petriTag)
_, err = digitalOceanProvider.createTag(ctx, digitalOceanProvider.state.petriTag)
if err != nil {
return nil, err
}
Expand All @@ -92,7 +93,7 @@ func NewProviderWithClient(ctx context.Context, logger *zap.Logger, providerName
return nil, fmt.Errorf("failed to create firewall: %w", err)
}

digitalOceanProvider.firewallID = firewall.ID
digitalOceanProvider.state.firewallID = firewall.ID

//TODO(Zygimantass): TOCTOU issue
if key, _, err := doClient.GetKeyByFingerprint(ctx, sshKeyPair.Fingerprint); err != nil || key == nil {
Expand Down Expand Up @@ -163,7 +164,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
Tty: false,
Hostname: definition.Name,
Labels: map[string]string{
providerLabelName: p.name,
providerLabelName: p.state.Name,
},
Env: convertEnvMapToList(definition.Environment),
}, &container.HostConfig{
Expand All @@ -185,7 +186,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
Name: definition.Name,
Definition: definition,
Status: provider.TASK_STOPPED,
ProviderName: p.name,
ProviderName: p.state.Name,
}

p.stateMu.Lock()
Expand All @@ -196,7 +197,7 @@ func (p *Provider) CreateTask(ctx context.Context, definition provider.TaskDefin
return &Task{
state: taskState,
provider: p,
sshKeyPair: p.sshKeyPair,
sshKeyPair: p.state.sshKeyPair,
logger: p.logger.With(zap.String("task", definition.Name)),
doClient: p.doClient,
dockerClient: p.dockerClients[ip],
Expand Down Expand Up @@ -271,7 +272,7 @@ func (p *Provider) Teardown(ctx context.Context) error {
}

func (p *Provider) teardownTasks(ctx context.Context) error {
res, err := p.doClient.DeleteDropletByTag(ctx, p.petriTag)
res, err := p.doClient.DeleteDropletByTag(ctx, p.state.petriTag)
if err != nil {
return err
}
Expand All @@ -284,7 +285,7 @@ func (p *Provider) teardownTasks(ctx context.Context) error {
}

func (p *Provider) teardownFirewall(ctx context.Context) error {
res, err := p.doClient.DeleteFirewall(ctx, p.firewallID)
res, err := p.doClient.DeleteFirewall(ctx, p.state.firewallID)
if err != nil {
return err
}
Expand All @@ -297,7 +298,7 @@ func (p *Provider) teardownFirewall(ctx context.Context) error {
}

func (p *Provider) teardownSSHKey(ctx context.Context) error {
res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.sshKeyPair.Fingerprint)
res, err := p.doClient.DeleteKeyByFingerprint(ctx, p.state.sshKeyPair.Fingerprint)
if err != nil {
return err
}
Expand All @@ -310,7 +311,7 @@ func (p *Provider) teardownSSHKey(ctx context.Context) error {
}

func (p *Provider) teardownTag(ctx context.Context) error {
res, err := p.doClient.DeleteTag(ctx, p.petriTag)
res, err := p.doClient.DeleteTag(ctx, p.state.petriTag)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion core/provider/digitalocean/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func getUserIPs(ctx context.Context) (ips []string, err error) {
}

func (p *Provider) createSSHKey(ctx context.Context, pubKey string) (*godo.Key, error) {
req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.petriTag)}
req := &godo.KeyCreateRequest{PublicKey: pubKey, Name: fmt.Sprintf("%s-key", p.state.petriTag)}

key, res, err := p.doClient.CreateKey(ctx, req)
if err != nil {
Expand Down
9 changes: 0 additions & 9 deletions core/types/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,5 @@ func (c ChainConfig) ValidateBasic() error {
return fmt.Errorf("node creator cannot be nil")
}

if !isValidWalletConfig(c.WalletConfig) {
return fmt.Errorf("invalid wallet config")
}

return nil
}

func isValidWalletConfig(cfg WalletConfig) bool {
return cfg.Bech32Prefix != "" && cfg.SigningAlgorithm != "" &&
cfg.HDPath != nil && cfg.DerivationFn != nil && cfg.GenerationFn != nil
}

0 comments on commit cb1fc71

Please sign in to comment.