Skip to content

Commit

Permalink
Enhance err handling in websockets (#966)
Browse files Browse the repository at this point in the history
* Enhance err handling in websockets

* lint

* remove hijack marking
  • Loading branch information
otherview authored Jan 29, 2025
1 parent 76e095b commit 918b4ab
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 20 deletions.
43 changes: 31 additions & 12 deletions api/subscriptions/subscriptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package subscriptions

import (
"fmt"
"net/http"
"sync"
"time"
Expand Down Expand Up @@ -183,6 +184,7 @@ func (s *Subscriptions) handlePendingTransactions(w http.ResponseWriter, req *ht
// since the conn is hijacked here, no error should be returned in lines below
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
// websocket connection do not return errors to the wrapHandler
return nil
}
defer s.closeConn(conn, err)
Expand All @@ -200,16 +202,19 @@ func (s *Subscriptions) handlePendingTransactions(w http.ResponseWriter, req *ht
for {
select {
case tx := <-txCh:
err = conn.WriteJSON(&PendingTxIDMessage{ID: tx.ID()})
if err != nil {
if err = conn.WriteJSON(&PendingTxIDMessage{ID: tx.ID()}); err != nil {
// likely conn has failed
return nil
}
case <-s.done:
return nil
case <-closed:
return nil
case <-pingTicker.C:
conn.WriteMessage(websocket.PingMessage, nil)
if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
// likely conn has failed
return nil
}
}
}
}
Expand All @@ -226,15 +231,22 @@ func (s *Subscriptions) setupConn(w http.ResponseWriter, req *http.Request) (*we
s.wg.Add(1)
go func() {
defer s.wg.Done()
conn.SetReadDeadline(time.Now().Add(pongWait))
// close connections if not closed already
defer close(closed)

if err = conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
logger.Debug("failed to set initial read deadline", "err", err)
return
}
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(pongWait))
if err = conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
logger.Debug("failed to set pong read deadline", "err", err)
}
return nil
})
for {
if _, _, err := conn.ReadMessage(); err != nil {
logger.Debug("websocket read err", "err", err)
close(closed)
break
}
}
Expand Down Expand Up @@ -267,11 +279,11 @@ func (s *Subscriptions) pipe(conn *websocket.Conn, reader msgReader, closed chan
for {
msgs, hasMore, err := reader.Read()
if err != nil {
return err
return fmt.Errorf("unable to read subscription message: %w", err)
}
for _, msg := range msgs {
if err := conn.WriteJSON(msg); err != nil {
return err
return fmt.Errorf("unable to write subscription json: %w", err)
}
}
if hasMore {
Expand All @@ -281,7 +293,9 @@ func (s *Subscriptions) pipe(conn *websocket.Conn, reader msgReader, closed chan
case <-closed:
return nil
case <-pingTicker.C:
conn.WriteMessage(websocket.PingMessage, nil)
if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return fmt.Errorf("failed to write ping message: %w", err)
}
default:
}
} else {
Expand All @@ -292,7 +306,9 @@ func (s *Subscriptions) pipe(conn *websocket.Conn, reader msgReader, closed chan
return nil
case <-ticker.C():
case <-pingTicker.C:
conn.WriteMessage(websocket.PingMessage, nil)
if err = conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return fmt.Errorf("failed to write ping message: %w", err)
}
}
}
}
Expand Down Expand Up @@ -348,23 +364,26 @@ func (s *Subscriptions) websocket(readerFunc func(http.ResponseWriter, *http.Req
// Call the provided reader function
reader, err := readerFunc(w, req)
if err != nil {
// it's not yet a websocket connection, this is likely a setup error in the original http request
return err
}

// Setup WebSocket connection
conn, closed, err := s.setupConn(w, req)
if err != nil {
logger.Debug("upgrade to websocket", "err", err)
return err
// websocket connection do not return errors to the wrapHandler
return nil
}
defer s.closeConn(conn, err)

// Stream messages
err = s.pipe(conn, reader, closed)
if err != nil {
logger.Debug("error in websocket pipe", "err", err)
// websocket connection do not return errors to the wrapHandler
}
return err
return nil
}
}

Expand Down
23 changes: 15 additions & 8 deletions api/utils/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ import (
"net/http"

"github.com/pkg/errors"
"github.com/vechain/thor/v2/log"
)

var logger = log.WithContext("pkg", "http-utils")

type httpError struct {
cause error
status int
Expand Down Expand Up @@ -66,16 +69,20 @@ type HandlerFunc func(http.ResponseWriter, *http.Request) error
func WrapHandlerFunc(f HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := f(w, r)
if err != nil {
if he, ok := err.(*httpError); ok {
if he.cause != nil {
http.Error(w, he.cause.Error(), he.status)
} else {
w.WriteHeader(he.status)
}
if err == nil {
return // No error, nothing to do
}

// Otherwise, proceed with normal HTTP error handling
if he, ok := err.(*httpError); ok {
if he.cause != nil {
http.Error(w, he.cause.Error(), he.status)
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
w.WriteHeader(he.status)
}
} else {
logger.Debug("all errors should be wrapped in httpError", "err", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
}
Expand Down

0 comments on commit 918b4ab

Please sign in to comment.