Skip to content

Commit

Permalink
Merge pull request #183 from cloudstruct/feat/protocol-resource-cleanup
Browse files Browse the repository at this point in the history
feat: cleanup protocol resources on shutdown
  • Loading branch information
agaffney authored Jan 24, 2023
2 parents 1ec6209 + 4432835 commit a98d6dd
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 39 deletions.
25 changes: 14 additions & 11 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
BINARY=go-ouroboros-network

# Determine root directory
ROOT_DIR=$(shell dirname $(realpath $(firstword $(MAKEFILE_LIST))))

# Gather all .go files for use in dependencies below
GO_FILES=$(shell find $(ROOT_DIR) -name '*.go')

# Build our program binary
# Depends on GO_FILES to determine when rebuild is needed
$(BINARY): $(GO_FILES)
# Needed to fetch new dependencies and add them to go.mod
go mod tidy
go build -o $(BINARY) ./cmd/$(BINARY)
# Gather list of expected binaries
BINARIES=$(shell cd $(ROOT_DIR)/cmd && ls -1)

.PHONY: build clean test
.PHONY: build mod-tidy clean test

# Alias for building program binary
build: $(BINARY)
build: $(BINARIES)

mod-tidy:
# Needed to fetch new dependencies and add them to go.mod
go mod tidy

clean:
rm -f $(BINARY)
rm -f $(BINARIES)

test:
go test -v ./...

# Build our program binaries
# Depends on GO_FILES to determine when rebuild is needed
$(BINARIES): mod-tidy $(GO_FILES)
go build -o $(@) ./cmd/$(@)
2 changes: 2 additions & 0 deletions cmd/go-ouroboros-network/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ func main() {
testServer(f)
case "query":
testQuery(f)
case "mem-usage":
testMemUsage(f)
default:
fmt.Printf("Unknown subcommand: %s\n", f.flagset.Arg(0))
os.Exit(1)
Expand Down
110 changes: 110 additions & 0 deletions cmd/go-ouroboros-network/mem_usage.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package main

import (
"flag"
"fmt"
"log"
"net/http"
_ "net/http/pprof"
"os"
"runtime"
"runtime/pprof"
"time"

ouroboros "github.com/cloudstruct/go-ouroboros-network"
)

type memUsageFlags struct {
flagset *flag.FlagSet
startEra string
tip bool
debugPort int
}

func newMemUsageFlags() *memUsageFlags {
f := &memUsageFlags{
flagset: flag.NewFlagSet("mem-usage", flag.ExitOnError),
}
f.flagset.StringVar(&f.startEra, "start-era", "genesis", "era which to start chain-sync at")
f.flagset.BoolVar(&f.tip, "tip", false, "start chain-sync at current chain tip")
f.flagset.IntVar(&f.debugPort, "debug-port", 8080, "pprof port")
return f
}

func testMemUsage(f *globalFlags) {
memUsageFlags := newMemUsageFlags()
err := memUsageFlags.flagset.Parse(f.flagset.Args()[1:])
if err != nil {
fmt.Printf("failed to parse subcommand args: %s\n", err)
os.Exit(1)
}

// Start pprof listener
log.Printf("Starting pprof listener on http://0.0.0.0:%d/debug/pprof\n", memUsageFlags.debugPort)
go func() {
log.Println(http.ListenAndServe(fmt.Sprintf(":%d", memUsageFlags.debugPort), nil))
}()

for i := 0; i < 10; i++ {
showMemoryStats("open")

conn := createClientConnection(f)
errorChan := make(chan error)
go func() {
for {
err, ok := <-errorChan
if !ok {
return
}
fmt.Printf("ERROR: %s\n", err)
os.Exit(1)
}
}()
o, err := ouroboros.New(
ouroboros.WithConnection(conn),
ouroboros.WithNetworkMagic(uint32(f.networkMagic)),
ouroboros.WithErrorChan(errorChan),
ouroboros.WithNodeToNode(f.ntnProto),
ouroboros.WithKeepAlive(true),
)
if err != nil {
fmt.Printf("ERROR: %s\n", err)
os.Exit(1)
}
o.ChainSync.Client.Start()

tip, err := o.ChainSync.Client.GetCurrentTip()
if err != nil {
fmt.Printf("ERROR: %s\n", err)
os.Exit(1)
}

log.Printf("tip: slot = %d, hash = %x\n", tip.Point.Slot, tip.Point.Hash)

if err := o.Close(); err != nil {
fmt.Printf("ERROR: %s\n", err)
}

showMemoryStats("close")

time.Sleep(5 * time.Second)

runtime.GC()

showMemoryStats("after GC")
}

if err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1); err != nil {
fmt.Printf("ERROR: %s\n", err)
os.Exit(1)
}

fmt.Printf("waiting forever")
select {}
}

func showMemoryStats(tag string) {
var m runtime.MemStats
runtime.ReadMemStats(&m)
log.Printf("[%s] HeapAlloc: %dKiB\n", tag, m.HeapAlloc/1024)
}
4 changes: 2 additions & 2 deletions muxer/muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (m *Muxer) sendError(err error) {
m.Stop()
}

func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) {
func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment, chan bool) {
// Generate channels
senderChan := make(chan *Segment, 10)
receiverChan := make(chan *Segment, 10)
Expand All @@ -118,7 +118,7 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segmen
}
}
}()
return senderChan, receiverChan
return senderChan, receiverChan, m.doneChan
}

func (m *Muxer) Send(msg *Segment) error {
Expand Down
28 changes: 16 additions & 12 deletions ouroboros.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,20 +132,24 @@ func (o *Ouroboros) setupConnection() error {
o.muxer = muxer.New(o.conn)
// Start Goroutine to pass along errors from the muxer
go func() {
err, ok := <-o.muxer.ErrorChan
// Break out of goroutine if muxer's error channel is closed
if !ok {
select {
case <-o.doneChan:
return
case err, ok := <-o.muxer.ErrorChan:
// Break out of goroutine if muxer's error channel is closed
if !ok {
return
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
// Return a bare io.EOF error if error is EOF/ErrUnexpectedEOF
o.ErrorChan <- io.EOF
} else {
// Wrap error message to denote it comes from the muxer
o.ErrorChan <- fmt.Errorf("muxer error: %s", err)
}
// Close connection on muxer errors
o.Close()
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
// Return a bare io.EOF error if error is EOF/ErrUnexpectedEOF
o.ErrorChan <- io.EOF
} else {
// Wrap error message to denote it comes from the muxer
o.ErrorChan <- fmt.Errorf("muxer error: %s", err)
}
// Close connection on muxer errors
o.Close()
}()
protoOptions := protocol.ProtocolOptions{
Muxer: o.muxer,
Expand Down
7 changes: 7 additions & 0 deletions protocol/chainsync/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
InitialState: STATE_IDLE,
}
c.Protocol = protocol.New(protoConfig)
// Start goroutine to cleanup resources on protocol shutdown
go func() {
<-c.Protocol.DoneChan()
close(c.intersectResultChan)
close(c.readyForNextBlockChan)
close(c.currentTipChan)
}()
return c
}

Expand Down
10 changes: 10 additions & 0 deletions protocol/keepalive/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
InitialState: STATE_CLIENT,
}
c.Protocol = protocol.New(protoConfig)
// Start goroutine to cleanup resources on protocol shutdown
go func() {
<-c.Protocol.DoneChan()
if c.timer != nil {
// Stop timer and drain channel
if ok := c.timer.Stop(); !ok {
<-c.timer.C
}
}
}()
return c
}

Expand Down
6 changes: 6 additions & 0 deletions protocol/localstatequery/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
c.enableGetRewardInfoPoolsBlock = true
}
c.Protocol = protocol.New(protoConfig)
// Start goroutine to cleanup resources on protocol shutdown
go func() {
<-c.Protocol.DoneChan()
close(c.queryResultChan)
close(c.acquireResultChan)
}()
return c
}

Expand Down
5 changes: 5 additions & 0 deletions protocol/localtxsubmission/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client {
InitialState: STATE_IDLE,
}
c.Protocol = protocol.New(protoConfig)
// Start goroutine to cleanup resources on protocol shutdown
go func() {
<-c.Protocol.DoneChan()
close(c.submitResultChan)
}()
return c
}

Expand Down
55 changes: 41 additions & 14 deletions protocol/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Protocol struct {
config ProtocolConfig
muxerSendChan chan *muxer.Segment
muxerRecvChan chan *muxer.Segment
muxerDoneChan chan bool
state State
stateMutex sync.Mutex
recvBuffer *bytes.Buffer
Expand Down Expand Up @@ -77,21 +78,29 @@ type MessageFromCborFunc func(uint, []byte) (Message, error)

func New(config ProtocolConfig) *Protocol {
p := &Protocol{
config: config,
config: config,
doneChan: make(chan bool),
}
return p
}

func (p *Protocol) Start() {
// Register protocol with muxer
p.muxerSendChan, p.muxerRecvChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId)
p.muxerSendChan, p.muxerRecvChan, p.muxerDoneChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId)
// Create buffers and channels
p.recvBuffer = bytes.NewBuffer(nil)
p.sendQueueChan = make(chan Message, 50)
p.sendStateQueueChan = make(chan Message, 50)
p.recvReadyChan = make(chan bool, 1)
p.sendReadyChan = make(chan bool, 1)
p.doneChan = make(chan bool)
// Start goroutine to cleanup when shutting down
go func() {
<-p.doneChan
close(p.sendQueueChan)
close(p.sendStateQueueChan)
close(p.recvReadyChan)
close(p.sendReadyChan)
}()
// Set initial state
p.setState(p.config.InitialState)
// Start our send and receive Goroutines
Expand All @@ -107,6 +116,10 @@ func (p *Protocol) Role() ProtocolRole {
return p.config.Role
}

func (p *Protocol) DoneChan() chan bool {
return p.doneChan
}

func (p *Protocol) SendMessage(msg Message) error {
p.sendQueueChan <- msg
return nil
Expand All @@ -122,14 +135,14 @@ func (p *Protocol) sendLoop() {
var err error
for {
select {
case <-p.sendReadyChan:
// We are ready to send based on state map
case <-p.doneChan:
// We are responsible for closing this channel as the sender, even through it
// was created by the muxer
close(p.muxerSendChan)
// Break out of send loop if we're shutting down
return
case <-p.sendReadyChan:
// We are ready to send based on state map
}
// Lock the state to prevent collisions
p.stateMutex.Lock()
Expand All @@ -155,7 +168,11 @@ func (p *Protocol) sendLoop() {
msgCount := 0
for {
// Get next message from send queue
msg := <-p.sendQueueChan
msg, ok := <-p.sendQueueChan
if !ok {
// We're shutting down
return
}
msgCount = msgCount + 1
// Write the message into the send state queue if we already have a new state
if setNewState {
Expand Down Expand Up @@ -234,20 +251,29 @@ func (p *Protocol) recvLoop() {
// Don't grab the next segment from the muxer if we still have data in the buffer
if !leftoverData {
// Wait for segment
segment, ok := <-p.muxerRecvChan
// Break out of receive loop if channel is closed
if !ok {
select {
case <-p.muxerDoneChan:
close(p.doneChan)
return
case segment, ok := <-p.muxerRecvChan:
if !ok {
close(p.doneChan)
return
}
// Add segment payload to buffer
p.recvBuffer.Write(segment.Payload)
// Save whether it's a response
isResponse = segment.IsResponse()
}
// Add segment payload to buffer
p.recvBuffer.Write(segment.Payload)
// Save whether it's a response
isResponse = segment.IsResponse()
}
leftoverData = false
// Wait until ready to receive based on state map
<-p.recvReadyChan
select {
case <-p.muxerDoneChan:
close(p.doneChan)
return
case <-p.recvReadyChan:
}
// Decode message into generic list until we can determine what type of message it is.
// This also lets us determine how many bytes the message is. We use RawMessage here to
// avoid parsing things that we may not be able to parse
Expand Down Expand Up @@ -321,6 +347,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) {
func (p *Protocol) setState(state State) {
// Disable any previous state transition timer
if p.stateTransitionTimer != nil {
// Stop timer and drain channel
if !p.stateTransitionTimer.Stop() {
<-p.stateTransitionTimer.C
}
Expand Down

0 comments on commit a98d6dd

Please sign in to comment.