Skip to content

Commit

Permalink
refactor(storage): add store.Store interface
Browse files Browse the repository at this point in the history
There is a first implementation with ValKey that will allow to use redis APIs as a backend for Sablier with Hight Availability
  • Loading branch information
acouvreur committed Feb 2, 2025
1 parent c768224 commit da7e332
Show file tree
Hide file tree
Showing 17 changed files with 589 additions and 168 deletions.
1 change: 1 addition & 0 deletions .testcontainers.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ryuk.disabled=true
13 changes: 8 additions & 5 deletions app/discovery/autostop.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package discovery

import (
"context"
"errors"
"github.com/sablierapp/sablier/app/providers"
"github.com/sablierapp/sablier/pkg/arrays"
"github.com/sablierapp/sablier/pkg/store"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
)
Expand All @@ -12,7 +13,7 @@ import (
// as running instances by Sablier.
// By default, Sablier does not stop all already running instances. Meaning that you need to make an
// initial request in order to trigger the scaling to zero.
func StopAllUnregisteredInstances(ctx context.Context, provider providers.Provider, registered []string) error {
func StopAllUnregisteredInstances(ctx context.Context, provider providers.Provider, s store.Store) error {
log.Info("Stopping all unregistered running instances")

log.Tracef("Retrieving all instances with label [%v=true]", LabelEnable)
Expand All @@ -25,12 +26,14 @@ func StopAllUnregisteredInstances(ctx context.Context, provider providers.Provid
}

log.Tracef("Found %v instances with label [%v=true]", len(instances), LabelEnable)
names := make([]string, 0, len(instances))
unregistered := make([]string, 0)
for _, instance := range instances {
names = append(names, instance.Name)
_, err = s.Get(ctx, instance.Name)
if errors.Is(err, store.ErrKeyNotFound) {
unregistered = append(unregistered, instance.Name)
}
}

unregistered := arrays.RemoveElements(names, registered)
log.Tracef("Found %v unregistered instances ", len(instances))

waitGroup := errgroup.Group{}
Expand Down
24 changes: 14 additions & 10 deletions app/discovery/autostop_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@ import (
"context"
"errors"
"github.com/sablierapp/sablier/app/discovery"
"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/providers"
"github.com/sablierapp/sablier/app/providers/mock"
"github.com/sablierapp/sablier/app/types"
"github.com/sablierapp/sablier/pkg/store/inmemory"
"gotest.tools/v3/assert"
"testing"
"time"
)

func TestStopAllUnregisteredInstances(t *testing.T) {
Expand All @@ -20,7 +24,9 @@ func TestStopAllUnregisteredInstances(t *testing.T) {
{Name: "instance2"},
{Name: "instance3"},
}
registered := []string{"instance1"}
store := inmemory.NewInMemory()
err := store.Put(ctx, instance.State{Name: "instance1"}, time.Minute)
assert.NilError(t, err)

// Set up expectations for InstanceList
mockProvider.On("InstanceList", ctx, providers.InstanceListOptions{
Expand All @@ -33,10 +39,8 @@ func TestStopAllUnregisteredInstances(t *testing.T) {
mockProvider.On("Stop", ctx, "instance3").Return(nil)

// Call the function under test
err := discovery.StopAllUnregisteredInstances(ctx, mockProvider, registered)
if err != nil {
t.Fatalf("Expected no error, but got %v", err)
}
err = discovery.StopAllUnregisteredInstances(ctx, mockProvider, store)
assert.NilError(t, err)

// Check expectations
mockProvider.AssertExpectations(t)
Expand All @@ -52,7 +56,9 @@ func TestStopAllUnregisteredInstances_WithError(t *testing.T) {
{Name: "instance2"},
{Name: "instance3"},
}
registered := []string{"instance1"}
store := inmemory.NewInMemory()
err := store.Put(ctx, instance.State{Name: "instance1"}, time.Minute)
assert.NilError(t, err)

// Set up expectations for InstanceList
mockProvider.On("InstanceList", ctx, providers.InstanceListOptions{
Expand All @@ -65,10 +71,8 @@ func TestStopAllUnregisteredInstances_WithError(t *testing.T) {
mockProvider.On("Stop", ctx, "instance3").Return(nil)

// Call the function under test
err := discovery.StopAllUnregisteredInstances(ctx, mockProvider, registered)
if err == nil {
t.Fatalf("Expected error, but got nil")
}
err = discovery.StopAllUnregisteredInstances(ctx, mockProvider, store)
assert.Error(t, err, "stop error")

// Check expectations
mockProvider.AssertExpectations(t)
Expand Down
54 changes: 46 additions & 8 deletions app/sablier.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ import (
"github.com/sablierapp/sablier/app/providers/docker"
"github.com/sablierapp/sablier/app/providers/dockerswarm"
"github.com/sablierapp/sablier/app/providers/kubernetes"
"github.com/sablierapp/sablier/pkg/store/inmemory"
"log/slog"
"os"
"time"

"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/providers"
"github.com/sablierapp/sablier/app/sessions"
"github.com/sablierapp/sablier/app/storage"
"github.com/sablierapp/sablier/app/theme"
"github.com/sablierapp/sablier/config"
"github.com/sablierapp/sablier/internal/server"
"github.com/sablierapp/sablier/pkg/tinykv"
"github.com/sablierapp/sablier/version"
log "github.com/sirupsen/logrus"
)
Expand All @@ -45,7 +45,11 @@ func Start(ctx context.Context, conf config.Config) error {

log.Infof("using provider \"%s\"", conf.Provider.Name)

store := tinykv.New(conf.Sessions.ExpirationInterval, onSessionExpires(provider))
store := inmemory.NewInMemory()
err = store.OnExpire(ctx, onSessionExpires(provider))
if err != nil {
return err
}

storage, err := storage.NewFileStorage(conf.Storage)
if err != nil {
Expand All @@ -55,13 +59,30 @@ func Start(ctx context.Context, conf config.Config) error {
sessionsManager := sessions.NewSessionsManager(store, provider)
defer sessionsManager.Stop()

updateGroups := make(chan map[string][]string)
go WatchGroups(ctx, provider, 2*time.Second, updateGroups)
go func() {
for groups := range updateGroups {
sessionsManager.SetGroups(groups)
}
}()

instanceStopped := make(chan string)
go provider.NotifyInstanceStopped(ctx, instanceStopped)
go func() {
for stopped := range instanceStopped {
err := sessionsManager.RemoveInstance(stopped)
logger.Warn("could not remove instance", slog.Any("error", err))
}
}()

if storage.Enabled() {
defer saveSessions(storage, sessionsManager)
loadSessions(storage, sessionsManager)
}

if conf.Provider.AutoStopOnStartup {
err := discovery.StopAllUnregisteredInstances(context.Background(), provider, store.Keys())
err := discovery.StopAllUnregisteredInstances(context.Background(), provider, store)
if err != nil {
log.Warnf("Stopping unregistered instances had an error: %v", err)
}
Expand Down Expand Up @@ -96,9 +117,9 @@ func Start(ctx context.Context, conf config.Config) error {
return nil
}

func onSessionExpires(provider providers.Provider) func(key string, instance instance.State) {
return func(_key string, _instance instance.State) {
go func(key string, instance instance.State) {
func onSessionExpires(provider providers.Provider) func(key string) {
return func(_key string) {
go func(key string) {
log.Debugf("stopping %s...", key)
err := provider.Stop(context.Background(), key)

Expand All @@ -107,7 +128,7 @@ func onSessionExpires(provider providers.Provider) func(key string, instance ins
} else {
log.Debugf("stopped %s", key)
}
}(_key, _instance)
}(_key)
}
}

Expand Down Expand Up @@ -149,3 +170,20 @@ func NewProvider(config config.Provider) (providers.Provider, error) {
}
return nil, fmt.Errorf("unimplemented provider %s", config.Name)
}

func WatchGroups(ctx context.Context, provider providers.Provider, frequency time.Duration, send chan<- map[string][]string) {
ticker := time.NewTicker(frequency)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
groups, err := provider.GetGroups(ctx)
if err != nil {
log.Warn("could not get groups", err)
} else {
send <- groups
}
}
}
}
27 changes: 0 additions & 27 deletions app/sessions/groups_watcher.go

This file was deleted.

71 changes: 30 additions & 41 deletions app/sessions/sessions_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/sablierapp/sablier/pkg/store"
"io"
"log/slog"
"maps"
"slices"
"sync"
"time"

"github.com/sablierapp/sablier/app/instance"
"github.com/sablierapp/sablier/app/providers"
"github.com/sablierapp/sablier/pkg/tinykv"
log "github.com/sirupsen/logrus"
)

const defaultRefreshFrequency = 2 * time.Second

//go:generate mockgen -package sessionstest -source=sessions_manager.go -destination=sessionstest/mocks_sessions_manager.go *

type Manager interface {
Expand All @@ -30,19 +29,22 @@ type Manager interface {
LoadSessions(io.ReadCloser) error
SaveSessions(io.WriteCloser) error

RemoveInstance(name string) error
SetGroups(groups map[string][]string)

Stop()
}

type SessionsManager struct {
ctx context.Context
cancel context.CancelFunc

store tinykv.KV[instance.State]
store store.Store
provider providers.Provider
groups map[string][]string
}

func NewSessionsManager(store tinykv.KV[instance.State], provider providers.Provider) Manager {
func NewSessionsManager(store store.Store, provider providers.Provider) Manager {
ctx, cancel := context.WithCancel(context.Background())

groups, err := provider.GetGroups(ctx)
Expand All @@ -59,49 +61,37 @@ func NewSessionsManager(store tinykv.KV[instance.State], provider providers.Prov
groups: groups,
}

sm.initWatchers()

return sm
}

func (sm *SessionsManager) initWatchers() {
updateGroups := make(chan map[string][]string)
go watchGroups(sm.ctx, sm.provider, defaultRefreshFrequency, updateGroups)
go sm.consumeGroups(updateGroups)

instanceStopped := make(chan string)
go sm.provider.NotifyInstanceStopped(sm.ctx, instanceStopped)
go sm.consumeInstanceStopped(instanceStopped)
func (sm *SessionsManager) SetGroups(groups map[string][]string) {
sm.groups = groups
}

func (sm *SessionsManager) consumeGroups(receive chan map[string][]string) {
for groups := range receive {
sm.groups = groups
}
}

func (sm *SessionsManager) consumeInstanceStopped(instanceStopped chan string) {
for instance := range instanceStopped {
// Will delete from the store containers that have been stop either by external sources
// or by the internal expiration loop, if the deleted entry does not exist, it doesn't matter
log.Debugf("received event instance %s is stopped, removing from store", instance)
sm.store.Delete(instance)
}
func (sm *SessionsManager) RemoveInstance(name string) error {
return sm.store.Delete(context.Background(), name)
}

func (sm *SessionsManager) LoadSessions(reader io.ReadCloser) error {
unmarshaler, ok := sm.store.(json.Unmarshaler)
defer reader.Close()
return json.NewDecoder(reader).Decode(sm.store)
if ok {
return json.NewDecoder(reader).Decode(unmarshaler)
}
return nil
}

func (sm *SessionsManager) SaveSessions(writer io.WriteCloser) error {
marshaler, ok := sm.store.(json.Marshaler)
defer writer.Close()
if ok {
encoder := json.NewEncoder(writer)
encoder.SetEscapeHTML(false)
encoder.SetIndent("", " ")

encoder := json.NewEncoder(writer)
encoder.SetEscapeHTML(false)
encoder.SetIndent("", " ")

return encoder.Encode(sm.store)
return encoder.Encode(marshaler)
}
return nil
}

type InstanceState struct {
Expand Down Expand Up @@ -190,9 +180,8 @@ func (s *SessionsManager) requestSessionInstance(name string, duration time.Dura
return nil, errors.New("instance name cannot be empty")
}

requestState, exists := s.store.Get(name)

if !exists {
requestState, err := s.store.Get(context.TODO(), name)
if errors.Is(err, store.ErrKeyNotFound) {
log.Debugf("starting [%s]...", name)

err := s.provider.Start(s.ctx, name)
Expand All @@ -212,6 +201,8 @@ func (s *SessionsManager) requestSessionInstance(name string, duration time.Dura
requestState.Message = state.Message

log.Debugf("status for [%s]=[%s]", name, requestState.Status)
} else if err != nil {
return nil, fmt.Errorf("cannot retrieve instance from store: %w", err)
} else if requestState.Status != instance.Ready {
log.Debugf("checking [%s]...", name)
state, err := s.provider.GetState(s.ctx, name)
Expand Down Expand Up @@ -306,15 +297,13 @@ func (s *SessionsManager) RequestReadySessionGroup(ctx context.Context, group st
}

func (s *SessionsManager) ExpiresAfter(instance *instance.State, duration time.Duration) {
s.store.Put(instance.Name, *instance, duration)
err := s.store.Put(context.TODO(), *instance, duration)
slog.Default().Warn("could not put instance to store, will not expire", slog.Any("error", err), slog.String("instance", instance.Name))
}

func (s *SessionsManager) Stop() {
// Stop event listeners
s.cancel()

// Stop the store
s.store.Stop()
}

func (s *SessionState) MarshalJSON() ([]byte, error) {
Expand Down
Loading

0 comments on commit da7e332

Please sign in to comment.