Skip to content

Commit

Permalink
feat: Supports callbacks when reading a message fails
Browse files Browse the repository at this point in the history
  • Loading branch information
todli authored and todli committed Dec 27, 2024
1 parent b7beae5 commit cbc48e1
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 34 deletions.
14 changes: 11 additions & 3 deletions server/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ func (c CallbacksStruct) OnConnecting(request *http.Request) types.ConnectionRes
// ConnectionCallbacksStruct is a struct that implements ConnectionCallbacks interface and allows
// to override only the methods that are needed.
type ConnectionCallbacksStruct struct {
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
OnConnectedFunc func(ctx context.Context, conn types.Connection)
OnMessageFunc func(ctx context.Context, conn types.Connection, message *protobufs.AgentToServer) *protobufs.ServerToAgent
OnConnectionCloseFunc func(conn types.Connection)
OnReadMessageErrorFunc func(conn types.Connection, mt int, msgByte []byte, err error)
}

var _ types.ConnectionCallbacks = (*ConnectionCallbacksStruct)(nil)
Expand Down Expand Up @@ -61,3 +62,10 @@ func (c ConnectionCallbacksStruct) OnConnectionClose(conn types.Connection) {
c.OnConnectionCloseFunc(conn)
}
}

// OnReadMessageError implements types.ConnectionCallbacks.
func (c ConnectionCallbacksStruct) OnReadMessageError(conn types.Connection, mt int, msgByte []byte, err error) {
if c.OnReadMessageErrorFunc != nil {
c.OnReadMessageErrorFunc(conn, mt, msgByte, err)
}
}
65 changes: 39 additions & 26 deletions server/serverimpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
Expand All @@ -19,16 +20,16 @@ import (
serverTypes "github.com/open-telemetry/opamp-go/server/types"
)

var (
errAlreadyStarted = errors.New("already started")
)
var errAlreadyStarted = errors.New("already started")

const defaultOpAMPPath = "/v1/opamp"
const headerContentType = "Content-Type"
const headerContentEncoding = "Content-Encoding"
const headerAcceptEncoding = "Accept-Encoding"
const contentEncodingGzip = "gzip"
const contentTypeProtobuf = "application/x-protobuf"
const (
defaultOpAMPPath = "/v1/opamp"
headerContentType = "Content-Type"
headerContentEncoding = "Content-Encoding"
headerAcceptEncoding = "Accept-Encoding"
contentEncodingGzip = "gzip"
contentTypeProtobuf = "application/x-protobuf"
)

type server struct {
logger types.Logger
Expand Down Expand Up @@ -231,26 +232,39 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co
for {
msgContext := context.Background()
// Block until the next message can be read.
request := protobufs.AgentToServer{}

mt, msgBytes, err := wsConn.ReadMessage()
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)
break
isBreak, err := func() (bool, error) {
if err != nil {
if !websocket.IsUnexpectedCloseError(err) {
s.logger.Errorf(msgContext, "Cannot read a message from WebSocket: %v", err)
return true, err
}
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
return true, err
}
if mt != websocket.BinaryMessage {
err = fmt.Errorf("Received unexpected message type from WebSocket: %v", mt)
s.logger.Errorf(msgContext, err.Error())
return false, err
}
// This is a normal closing of the WebSocket connection.
s.logger.Debugf(msgContext, "Agent disconnected: %v", err)
break
}
if mt != websocket.BinaryMessage {
s.logger.Errorf(msgContext, "Received unexpected message type from WebSocket: %v", mt)
continue
}

// Decode WebSocket message as a Protobuf message.
var request protobufs.AgentToServer
err = internal.DecodeWSMessage(msgBytes, &request)
// Decode WebSocket message as a Protobuf message.
err = internal.DecodeWSMessage(msgBytes, &request)
if err != nil {
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)
return false, err
}
return false, nil
}()

if err != nil {
s.logger.Errorf(msgContext, "Cannot decode message from WebSocket: %v", err)
connectionCallbacks.OnReadMessageError(agentConn, mt, msgBytes, err)
if isBreak {
break
}
continue
}

Expand Down Expand Up @@ -377,7 +391,6 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter
w.Header().Set(headerContentEncoding, contentEncodingGzip)
}
_, err = w.Write(bodyBytes)

if err != nil {
s.logger.Debugf(req.Context(), "Cannot send HTTP response: %v", err)
}
Expand Down
51 changes: 46 additions & 5 deletions server/serverimpl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,52 @@ func TestServerReceiveSendMessage(t *testing.T) {
assert.EqualValues(t, settings.CustomCapabilities, response.CustomCapabilities.Capabilities)
}

func TestServerReceiveSendErrorMessage(t *testing.T) {
var rcvMsg atomic.Value
type ErrorInfo struct {
mt int
msgByte []byte
err error
}
callbacks := CallbacksStruct{
OnConnectingFunc: func(request *http.Request) types.ConnectionResponse {
return types.ConnectionResponse{Accept: true, ConnectionCallbacks: ConnectionCallbacksStruct{
OnReadMessageErrorFunc: func(conn types.Connection, mt int, msgByte []byte, err error) {
rcvMsg.Store(ErrorInfo{
mt: mt,
msgByte: msgByte,
err: err,
})
},
}}
},
}

// Start a Server.
settings := &StartSettings{Settings: Settings{
Callbacks: callbacks,
CustomCapabilities: []string{"local.test.capability"},
}}
srv := startServer(t, settings)
defer srv.Stop(context.Background())

// Connect using a WebSocket client.
conn, _, _ := dialClient(settings)
require.NotNil(t, conn)
defer conn.Close()

// Send a message to the Server.
err := conn.WriteMessage(websocket.TextMessage, []byte(""))
require.NoError(t, err)

// Wait until Server receives the message.
eventually(t, func() bool { return rcvMsg.Load() != nil })
errInfo := rcvMsg.Load().(ErrorInfo)
assert.EqualValues(t, websocket.TextMessage, errInfo.mt)
assert.EqualValues(t, []byte(""), errInfo.msgByte)
assert.NotNil(t, errInfo.err)
}

func TestServerReceiveSendMessageWithCompression(t *testing.T) {
// Use highly compressible config body.
uncompressedCfg := []byte(strings.Repeat("test", 10000))
Expand Down Expand Up @@ -620,7 +666,6 @@ func TestServerAttachSendMessagePlainHTTP(t *testing.T) {
}

func TestServerHonoursClientRequestContentEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -698,7 +743,6 @@ func TestServerHonoursClientRequestContentEncoding(t *testing.T) {
}

func TestServerHonoursAcceptEncoding(t *testing.T) {

hc := http.Client{}
var rcvMsg atomic.Value
var onConnectedCalled, onCloseCalled int32
Expand Down Expand Up @@ -985,7 +1029,6 @@ func BenchmarkSendToClient(b *testing.B) {
}
srv := New(&sharedinternal.NopLogger{})
err := srv.Start(*settings)

if err != nil {
b.Error(err)
}
Expand Down Expand Up @@ -1017,7 +1060,6 @@ func BenchmarkSendToClient(b *testing.B) {

for _, conn := range serverConnections {
err := conn.Send(context.Background(), &protobufs.ServerToAgent{})

if err != nil {
b.Error(err)
}
Expand All @@ -1026,5 +1068,4 @@ func BenchmarkSendToClient(b *testing.B) {
for _, conn := range clientConnections {
conn.Close()
}

}
3 changes: 3 additions & 0 deletions server/types/callbacks.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,7 @@ type ConnectionCallbacks interface {

// OnConnectionClose is called when the OpAMP connection is closed.
OnConnectionClose(conn Connection)

// OnConnectionError is called when an error occurs while reading or serializing a message.
OnReadMessageError(conn Connection, mt int, msgByte []byte, err error)
}

0 comments on commit cbc48e1

Please sign in to comment.