diff --git a/Makefile b/Makefile index 7f38639b..8aed3bc4 100644 --- a/Makefile +++ b/Makefile @@ -1,25 +1,28 @@ -BINARY=go-ouroboros-network - # Determine root directory ROOT_DIR=$(shell dirname $(realpath $(firstword $(MAKEFILE_LIST)))) # Gather all .go files for use in dependencies below GO_FILES=$(shell find $(ROOT_DIR) -name '*.go') -# Build our program binary -# Depends on GO_FILES to determine when rebuild is needed -$(BINARY): $(GO_FILES) - # Needed to fetch new dependencies and add them to go.mod - go mod tidy - go build -o $(BINARY) ./cmd/$(BINARY) +# Gather list of expected binaries +BINARIES=$(shell cd $(ROOT_DIR)/cmd && ls -1) -.PHONY: build clean test +.PHONY: build mod-tidy clean test # Alias for building program binary -build: $(BINARY) +build: $(BINARIES) + +mod-tidy: + # Needed to fetch new dependencies and add them to go.mod + go mod tidy clean: - rm -f $(BINARY) + rm -f $(BINARIES) test: go test -v ./... + +# Build our program binaries +# Depends on GO_FILES to determine when rebuild is needed +$(BINARIES): mod-tidy $(GO_FILES) + go build -o $(@) ./cmd/$(@) diff --git a/cmd/go-ouroboros-network/main.go b/cmd/go-ouroboros-network/main.go index 88380cbe..554fa60d 100644 --- a/cmd/go-ouroboros-network/main.go +++ b/cmd/go-ouroboros-network/main.go @@ -78,6 +78,8 @@ func main() { testServer(f) case "query": testQuery(f) + case "mem-usage": + testMemUsage(f) default: fmt.Printf("Unknown subcommand: %s\n", f.flagset.Arg(0)) os.Exit(1) diff --git a/cmd/go-ouroboros-network/mem_usage.go b/cmd/go-ouroboros-network/mem_usage.go new file mode 100644 index 00000000..b520a3e2 --- /dev/null +++ b/cmd/go-ouroboros-network/mem_usage.go @@ -0,0 +1,110 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + _ "net/http/pprof" + "os" + "runtime" + "runtime/pprof" + "time" + + ouroboros "github.com/cloudstruct/go-ouroboros-network" +) + +type memUsageFlags struct { + flagset *flag.FlagSet + startEra string + tip bool + debugPort int +} + +func newMemUsageFlags() *memUsageFlags { + f := &memUsageFlags{ + flagset: flag.NewFlagSet("mem-usage", flag.ExitOnError), + } + f.flagset.StringVar(&f.startEra, "start-era", "genesis", "era which to start chain-sync at") + f.flagset.BoolVar(&f.tip, "tip", false, "start chain-sync at current chain tip") + f.flagset.IntVar(&f.debugPort, "debug-port", 8080, "pprof port") + return f +} + +func testMemUsage(f *globalFlags) { + memUsageFlags := newMemUsageFlags() + err := memUsageFlags.flagset.Parse(f.flagset.Args()[1:]) + if err != nil { + fmt.Printf("failed to parse subcommand args: %s\n", err) + os.Exit(1) + } + + // Start pprof listener + log.Printf("Starting pprof listener on http://0.0.0.0:%d/debug/pprof\n", memUsageFlags.debugPort) + go func() { + log.Println(http.ListenAndServe(fmt.Sprintf(":%d", memUsageFlags.debugPort), nil)) + }() + + for i := 0; i < 10; i++ { + showMemoryStats("open") + + conn := createClientConnection(f) + errorChan := make(chan error) + go func() { + for { + err, ok := <-errorChan + if !ok { + return + } + fmt.Printf("ERROR: %s\n", err) + os.Exit(1) + } + }() + o, err := ouroboros.New( + ouroboros.WithConnection(conn), + ouroboros.WithNetworkMagic(uint32(f.networkMagic)), + ouroboros.WithErrorChan(errorChan), + ouroboros.WithNodeToNode(f.ntnProto), + ouroboros.WithKeepAlive(true), + ) + if err != nil { + fmt.Printf("ERROR: %s\n", err) + os.Exit(1) + } + o.ChainSync.Client.Start() + + tip, err := o.ChainSync.Client.GetCurrentTip() + if err != nil { + fmt.Printf("ERROR: %s\n", err) + os.Exit(1) + } + + log.Printf("tip: slot = %d, hash = %x\n", tip.Point.Slot, tip.Point.Hash) + + if err := o.Close(); err != nil { + fmt.Printf("ERROR: %s\n", err) + } + + showMemoryStats("close") + + time.Sleep(5 * time.Second) + + runtime.GC() + + showMemoryStats("after GC") + } + + if err := pprof.Lookup("goroutine").WriteTo(os.Stdout, 1); err != nil { + fmt.Printf("ERROR: %s\n", err) + os.Exit(1) + } + + fmt.Printf("waiting forever") + select {} +} + +func showMemoryStats(tag string) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + log.Printf("[%s] HeapAlloc: %dKiB\n", tag, m.HeapAlloc/1024) +} diff --git a/muxer/muxer.go b/muxer/muxer.go index 1911c407..626d784f 100644 --- a/muxer/muxer.go +++ b/muxer/muxer.go @@ -94,7 +94,7 @@ func (m *Muxer) sendError(err error) { m.Stop() } -func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment) { +func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segment, chan bool) { // Generate channels senderChan := make(chan *Segment, 10) receiverChan := make(chan *Segment, 10) @@ -118,7 +118,7 @@ func (m *Muxer) RegisterProtocol(protocolId uint16) (chan *Segment, chan *Segmen } } }() - return senderChan, receiverChan + return senderChan, receiverChan, m.doneChan } func (m *Muxer) Send(msg *Segment) error { diff --git a/ouroboros.go b/ouroboros.go index 05c289bb..bfe91a5e 100644 --- a/ouroboros.go +++ b/ouroboros.go @@ -132,20 +132,24 @@ func (o *Ouroboros) setupConnection() error { o.muxer = muxer.New(o.conn) // Start Goroutine to pass along errors from the muxer go func() { - err, ok := <-o.muxer.ErrorChan - // Break out of goroutine if muxer's error channel is closed - if !ok { + select { + case <-o.doneChan: return + case err, ok := <-o.muxer.ErrorChan: + // Break out of goroutine if muxer's error channel is closed + if !ok { + return + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + // Return a bare io.EOF error if error is EOF/ErrUnexpectedEOF + o.ErrorChan <- io.EOF + } else { + // Wrap error message to denote it comes from the muxer + o.ErrorChan <- fmt.Errorf("muxer error: %s", err) + } + // Close connection on muxer errors + o.Close() } - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - // Return a bare io.EOF error if error is EOF/ErrUnexpectedEOF - o.ErrorChan <- io.EOF - } else { - // Wrap error message to denote it comes from the muxer - o.ErrorChan <- fmt.Errorf("muxer error: %s", err) - } - // Close connection on muxer errors - o.Close() }() protoOptions := protocol.ProtocolOptions{ Muxer: o.muxer, diff --git a/protocol/chainsync/client.go b/protocol/chainsync/client.go index c6343cbb..fda148c6 100644 --- a/protocol/chainsync/client.go +++ b/protocol/chainsync/client.go @@ -63,6 +63,13 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { InitialState: STATE_IDLE, } c.Protocol = protocol.New(protoConfig) + // Start goroutine to cleanup resources on protocol shutdown + go func() { + <-c.Protocol.DoneChan() + close(c.intersectResultChan) + close(c.readyForNextBlockChan) + close(c.currentTipChan) + }() return c } diff --git a/protocol/keepalive/client.go b/protocol/keepalive/client.go index e877c071..7f9ff607 100644 --- a/protocol/keepalive/client.go +++ b/protocol/keepalive/client.go @@ -40,6 +40,16 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { InitialState: STATE_CLIENT, } c.Protocol = protocol.New(protoConfig) + // Start goroutine to cleanup resources on protocol shutdown + go func() { + <-c.Protocol.DoneChan() + if c.timer != nil { + // Stop timer and drain channel + if ok := c.timer.Stop(); !ok { + <-c.timer.C + } + } + }() return c } diff --git a/protocol/localstatequery/client.go b/protocol/localstatequery/client.go index b941bf3e..1faa0ff8 100644 --- a/protocol/localstatequery/client.go +++ b/protocol/localstatequery/client.go @@ -65,6 +65,12 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { c.enableGetRewardInfoPoolsBlock = true } c.Protocol = protocol.New(protoConfig) + // Start goroutine to cleanup resources on protocol shutdown + go func() { + <-c.Protocol.DoneChan() + close(c.queryResultChan) + close(c.acquireResultChan) + }() return c } diff --git a/protocol/localtxsubmission/client.go b/protocol/localtxsubmission/client.go index dbd9db0a..04b01796 100644 --- a/protocol/localtxsubmission/client.go +++ b/protocol/localtxsubmission/client.go @@ -43,6 +43,11 @@ func NewClient(protoOptions protocol.ProtocolOptions, cfg *Config) *Client { InitialState: STATE_IDLE, } c.Protocol = protocol.New(protoConfig) + // Start goroutine to cleanup resources on protocol shutdown + go func() { + <-c.Protocol.DoneChan() + close(c.submitResultChan) + }() return c } diff --git a/protocol/protocol.go b/protocol/protocol.go index dbbd685c..ae365dd6 100644 --- a/protocol/protocol.go +++ b/protocol/protocol.go @@ -22,6 +22,7 @@ type Protocol struct { config ProtocolConfig muxerSendChan chan *muxer.Segment muxerRecvChan chan *muxer.Segment + muxerDoneChan chan bool state State stateMutex sync.Mutex recvBuffer *bytes.Buffer @@ -77,21 +78,29 @@ type MessageFromCborFunc func(uint, []byte) (Message, error) func New(config ProtocolConfig) *Protocol { p := &Protocol{ - config: config, + config: config, + doneChan: make(chan bool), } return p } func (p *Protocol) Start() { // Register protocol with muxer - p.muxerSendChan, p.muxerRecvChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId) + p.muxerSendChan, p.muxerRecvChan, p.muxerDoneChan = p.config.Muxer.RegisterProtocol(p.config.ProtocolId) // Create buffers and channels p.recvBuffer = bytes.NewBuffer(nil) p.sendQueueChan = make(chan Message, 50) p.sendStateQueueChan = make(chan Message, 50) p.recvReadyChan = make(chan bool, 1) p.sendReadyChan = make(chan bool, 1) - p.doneChan = make(chan bool) + // Start goroutine to cleanup when shutting down + go func() { + <-p.doneChan + close(p.sendQueueChan) + close(p.sendStateQueueChan) + close(p.recvReadyChan) + close(p.sendReadyChan) + }() // Set initial state p.setState(p.config.InitialState) // Start our send and receive Goroutines @@ -107,6 +116,10 @@ func (p *Protocol) Role() ProtocolRole { return p.config.Role } +func (p *Protocol) DoneChan() chan bool { + return p.doneChan +} + func (p *Protocol) SendMessage(msg Message) error { p.sendQueueChan <- msg return nil @@ -122,14 +135,14 @@ func (p *Protocol) sendLoop() { var err error for { select { - case <-p.sendReadyChan: - // We are ready to send based on state map case <-p.doneChan: // We are responsible for closing this channel as the sender, even through it // was created by the muxer close(p.muxerSendChan) // Break out of send loop if we're shutting down return + case <-p.sendReadyChan: + // We are ready to send based on state map } // Lock the state to prevent collisions p.stateMutex.Lock() @@ -155,7 +168,11 @@ func (p *Protocol) sendLoop() { msgCount := 0 for { // Get next message from send queue - msg := <-p.sendQueueChan + msg, ok := <-p.sendQueueChan + if !ok { + // We're shutting down + return + } msgCount = msgCount + 1 // Write the message into the send state queue if we already have a new state if setNewState { @@ -234,20 +251,29 @@ func (p *Protocol) recvLoop() { // Don't grab the next segment from the muxer if we still have data in the buffer if !leftoverData { // Wait for segment - segment, ok := <-p.muxerRecvChan - // Break out of receive loop if channel is closed - if !ok { + select { + case <-p.muxerDoneChan: close(p.doneChan) return + case segment, ok := <-p.muxerRecvChan: + if !ok { + close(p.doneChan) + return + } + // Add segment payload to buffer + p.recvBuffer.Write(segment.Payload) + // Save whether it's a response + isResponse = segment.IsResponse() } - // Add segment payload to buffer - p.recvBuffer.Write(segment.Payload) - // Save whether it's a response - isResponse = segment.IsResponse() } leftoverData = false // Wait until ready to receive based on state map - <-p.recvReadyChan + select { + case <-p.muxerDoneChan: + close(p.doneChan) + return + case <-p.recvReadyChan: + } // Decode message into generic list until we can determine what type of message it is. // This also lets us determine how many bytes the message is. We use RawMessage here to // avoid parsing things that we may not be able to parse @@ -321,6 +347,7 @@ func (p *Protocol) getNewState(msg Message) (State, error) { func (p *Protocol) setState(state State) { // Disable any previous state transition timer if p.stateTransitionTimer != nil { + // Stop timer and drain channel if !p.stateTransitionTimer.Stop() { <-p.stateTransitionTimer.C }