Skip to content

Commit

Permalink
Make sure that every rtr version has a different session_id
Browse files Browse the repository at this point in the history
The version-session_id-serial touple defines the cache state.
When a client connects with a different version its cache is no
longer in sync and this is the simplest way to enforce this.
  • Loading branch information
cjeker committed Aug 6, 2024
1 parent 6276927 commit aaf389b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 29 deletions.
11 changes: 3 additions & 8 deletions cmd/stayrtr/stayrtr.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,6 @@ var errRPKIJsonFileTooOld = errors.New("RPKI JSON file is older than 24 hours")

// Update the state based on the current slurm file and data.
func (s *state) updateFromNewState() error {
sessid := s.server.GetSessionId()

vrpsjson := s.lastdata.ROA
if vrpsjson == nil {
return nil
Expand Down Expand Up @@ -391,13 +389,11 @@ func (s *state) updateFromNewState() error {
count := len(vrps) + len(brks) + len(vaps)

log.Infof("New update (%v uniques, %v total prefixes, %v vaps, %v router keys).", len(vrps), count, len(vaps), len(brks))
return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
}

// Update the state based on the currently loaded files
func (s *state) reloadFromCurrentState() error {
sessid := s.server.GetSessionId()

vrpsjson := s.lastdata.ROA
if vrpsjson == nil {
return nil
Expand Down Expand Up @@ -434,13 +430,12 @@ func (s *state) reloadFromCurrentState() error {
count := len(vrps) + len(brks) + len(vaps)
if s.server.CountSDs() != count {
log.Infof("New update to old state (%v uniques, %v total prefixes). (old %v - new %v)", len(vrps), count, s.server.CountSDs(), count)
return s.applyUpdateFromNewState(vrps, brks, vaps, sessid, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
return s.applyUpdateFromNewState(vrps, brks, vaps, vrpsjson, bgpsecjson, aspajson, countv4, countv6)
}
return nil
}

func (s *state) applyUpdateFromNewState(vrps []rtr.VRP, brks []rtr.BgpsecKey, vaps []rtr.VAP,
sessid uint16,
vrpsjson []prefixfile.VRPJson, brksjson []prefixfile.BgpSecKeyJson, aspajson []prefixfile.VAPJson,
countv4 int, countv6 int) error {

Expand Down Expand Up @@ -852,7 +847,7 @@ func run() error {

if *Bind != "" {
go func() {
sessid := server.GetSessionId()
sessid := server.GetSessionId(protoverToLib[*RTRVersion])
log.Infof("StayRTR Server started (sessionID:%d, refresh:%d, retry:%d, expire:%d)", sessid, sc.RefreshInterval, sc.RetryInterval, sc.ExpireInterval)
err := server.Start(*Bind)
if err != nil {
Expand Down
39 changes: 18 additions & 21 deletions lib/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ type SendableData interface {

// This handles things like ROAs, BGPsec Router keys, ASPA info etc
type SendableDataManager interface {
GetCurrentSerial(uint16) (uint32, bool)
GetSessionId() uint16
GetCurrentSerial() (uint32, bool)
GetSessionId(uint8) uint16
GetCurrentSDs() ([]SendableData, bool)
GetSDsSerialDiff(uint32) ([]SendableData, bool)
}
Expand All @@ -63,8 +63,8 @@ func (e *DefaultRTREventHandler) RequestCache(c *Client) {
if e.Log != nil {
e.Log.Debugf("%v > Request Cache", c)
}
sessionId := e.sdManager.GetSessionId()
serial, valid := e.sdManager.GetCurrentSerial(sessionId)
sessionId := e.sdManager.GetSessionId(c.GetVersion())
serial, valid := e.sdManager.GetCurrentSerial()
if !valid {
c.SendNoDataError()
if e.Log != nil {
Expand All @@ -90,7 +90,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16,
if e.Log != nil {
e.Log.Debugf("%v > Request New Version", c)
}
serverSessionId := e.sdManager.GetSessionId()
serverSessionId := e.sdManager.GetSessionId(c.GetVersion())
if sessionId != serverSessionId {
c.SendCorruptData()
if e.Log != nil {
Expand All @@ -99,7 +99,7 @@ func (e *DefaultRTREventHandler) RequestNewVersion(c *Client, sessionId uint16,
c.Disconnect()
return
}
serial, valid := e.sdManager.GetCurrentSerial(sessionId)
serial, valid := e.sdManager.GetCurrentSerial()
if !valid {
c.SendNoDataError()
if e.Log != nil {
Expand All @@ -125,7 +125,7 @@ type Server struct {
baseVersion uint8
clientlock *sync.RWMutex
clients []*Client
sessId uint16
sessId []uint16
connected int
maxconn int

Expand Down Expand Up @@ -166,7 +166,11 @@ type ServerConfiguration struct {
}

func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler, simpleHandler RTREventHandler) *Server {
sessid := GenerateSessionId()
sessids := make([]uint16, 0, int(configuration.ProtocolVersion) + 1)
s := GenerateSessionId()
for i := 0; i <= int(configuration.ProtocolVersion); i++ {
sessids = append(sessids, s + uint16(100 * i))
}

refreshInterval := uint32(3600)
if configuration.RefreshInterval != 0 {
Expand All @@ -189,7 +193,7 @@ func NewServer(configuration ServerConfiguration, handler RTRServerEventHandler,

clientlock: &sync.RWMutex{},
clients: make([]*Client, 0),
sessId: sessid,
sessId: sessids,
maxconn: configuration.MaxConn,
baseVersion: configuration.ProtocolVersion,
enforceVersion: configuration.EnforceVersion,
Expand Down Expand Up @@ -277,8 +281,8 @@ func ApplyDiff(diff, prevSDs []SendableData) []SendableData {
return newSDs
}

func (s *Server) GetSessionId() uint16 {
return s.sessId
func (s *Server) GetSessionId(version uint8) uint16 {
return s.sessId[version]
}

func (s *Server) GetCurrentSDs() ([]SendableData, bool) {
Expand Down Expand Up @@ -311,7 +315,7 @@ func (s *Server) getSDsSerialDiff(serial uint32) ([]SendableData, bool) {
return sd, true
}

func (s *Server) GetCurrentSerial(sessId uint16) (uint32, bool) {
func (s *Server) GetCurrentSerial() (uint32, bool) {
s.sdlock.RLock()
serial, valid := s.getCurrentSerial()
s.sdlock.RUnlock()
Expand Down Expand Up @@ -408,10 +412,6 @@ func (s *Server) GetMaxConnections() int {
return s.maxconn
}

func (s *Server) SetSessionId(sessId uint16) {
s.sessId = sessId
}

func (s *Server) ClientConnected(c *Client) {
s.clientlock.Lock()
s.clients = append(s.clients, c)
Expand Down Expand Up @@ -629,14 +629,11 @@ func (s *Server) GetClientList() []*Client {
}

func (s *Server) NotifyClientsLatest() {
serial, _ := s.GetCurrentSerial(s.sessId)
s.NotifyClients(serial)
}
serial, _ := s.GetCurrentSerial()

func (s *Server) NotifyClients(serialNumber uint32) {
clients := s.GetClientList()
for _, c := range clients {
c.Notify(s.sessId, serialNumber)
c.Notify(s.GetSessionId(c.GetVersion()), serial)
}
}

Expand Down

0 comments on commit aaf389b

Please sign in to comment.