Skip to content

Commit

Permalink
fix: Fix Flaky Oracle Server Test (#767)
Browse files Browse the repository at this point in the history
Co-authored-by: Tyler <[email protected]>
  • Loading branch information
zrbecker and technicallyty authored Sep 22, 2024
1 parent 676a090 commit 490be7a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 13 deletions.
27 changes: 21 additions & 6 deletions service/servers/oracle/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package oracle
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -80,13 +81,11 @@ func (os *OracleServer) routeRequest(w http.ResponseWriter, r *http.Request) {
}
}

// StartServer starts the oracle gRPC server on the given host and port. The server is killed on any errors from the listener, or if ctx is cancelled.
// StartServerWithListener starts the oracle gRPC server with a given listener. The server is killed on any errors from the listener, or if ctx is cancelled.
// This method returns an error via any failure from the listener. This is a blocking call, i.e. until the server is closed or the server errors,
// this method will block.
func (os *OracleServer) StartServer(ctx context.Context, host, port string) error {
serverEndpoint := fmt.Sprintf("%s:%s", host, port)
func (os *OracleServer) StartServerWithListener(ctx context.Context, ln net.Listener) error {
os.httpSrv = &http.Server{
Addr: serverEndpoint,
ReadHeaderTimeout: DefaultServerShutdownTimeout,
}
// create grpc server
Expand All @@ -104,7 +103,7 @@ func (os *OracleServer) StartServer(ctx context.Context, host, port string) erro
}),
)
opts := []grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithNoProxy()}
err := types.RegisterOracleHandlerFromEndpoint(ctx, os.gatewayMux, serverEndpoint, opts)
err := types.RegisterOracleHandlerFromEndpoint(ctx, os.gatewayMux, ln.Addr().String(), opts)
if err != nil {
return err
}
Expand All @@ -128,13 +127,17 @@ func (os *OracleServer) StartServer(ctx context.Context, host, port string) erro
// start the server
eg.Go(func() error {
// serve, and return any errors
host, port, err := net.SplitHostPort(ln.Addr().String())
if err != nil {
return fmt.Errorf("[grpc server]: invalid listener address")
}
os.logger.Info(
"starting grpc server",
zap.String("host", host),
zap.String("port", port),
)

err = os.httpSrv.ListenAndServe()
err = os.httpSrv.Serve(ln)
if err != nil {
return fmt.Errorf("[grpc server]: error serving: %w", err)
}
Expand All @@ -146,6 +149,18 @@ func (os *OracleServer) StartServer(ctx context.Context, host, port string) erro
return eg.Wait()
}

// StartServer starts the oracle gRPC server on the given host and port. The server is killed on any errors from the listener, or if ctx is cancelled.
// This method returns an error via any failure from the listener. This is a blocking call, i.e. until the server is closed or the server errors,
// this method will block.
func (os *OracleServer) StartServer(ctx context.Context, host, port string) error {
addr := fmt.Sprintf("%s:%s", host, port)
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
return os.StartServerWithListener(ctx, ln)
}

// Prices calls the underlying oracle's implementation of GetPrices. It defers to the ctx in the request, and errors if the context is cancelled
// for any reason, or if the oracle errors.
func (os *OracleServer) Prices(ctx context.Context, req *types.QueryPricesRequest) (*types.QueryPricesResponse, error) {
Expand Down
31 changes: 24 additions & 7 deletions service/servers/oracle/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"io"
"math/big"
"net"
"net/http"
"testing"
"time"
Expand All @@ -13,7 +14,6 @@ import (
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"google.golang.org/grpc/status"

"github.com/skip-mev/connect/v2/oracle/mocks"
"github.com/skip-mev/connect/v2/oracle/types"
Expand All @@ -27,7 +27,6 @@ import (

const (
localhost = "localhost"
port = "8080"
timeout = 1 * time.Second
delay = 20 * time.Second
grpcErrPrefix = "rpc error: code = Unknown desc = "
Expand All @@ -42,6 +41,7 @@ type ServerTestSuite struct {
httpClient *http.Client
ctx context.Context
cancel context.CancelFunc
port string
}

func TestServerTestSuite(t *testing.T) {
Expand All @@ -55,10 +55,15 @@ func (s *ServerTestSuite) SetupTest() {
s.mockOracle = mocks.NewOracle(s.T())
s.srv = server.NewOracleServer(s.mockOracle, logger)

var err error
// listen on a random port and extract that port number
ln, err := net.Listen("tcp", localhost+":0")
s.Require().NoError(err)
_, s.port, err = net.SplitHostPort(ln.Addr().String())
s.Require().NoError(err)

s.client, err = client.NewClient(
log.NewTestLogger(s.T()),
localhost+":"+port,
localhost+":"+s.port,
timeout,
metrics.NewNopMetrics(),
client.WithBlockingDial(), // block on dialing the server
Expand All @@ -72,11 +77,23 @@ func (s *ServerTestSuite) SetupTest() {
s.ctx, s.cancel = context.WithCancel(context.Background())

// start server + client w/ context
go s.srv.StartServer(s.ctx, localhost, port)
go s.srv.StartServerWithListener(s.ctx, ln)

dialCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
s.Require().NoError(s.client.Start(dialCtx))

// Health check
for i := 0; ; i++ {
_, err := s.httpClient.Get(fmt.Sprintf("http://%s:%s/slinky/oracle/v1/version", localhost, s.port))
if err == nil {
break
}
if i == 10 {
s.T().Fatal("failed to connect to server")
}
time.Sleep(1 * time.Second)
}
}

// teardown test suite.
Expand Down Expand Up @@ -116,7 +133,7 @@ func (s *ServerTestSuite) TestOracleServerTimeout() {
_, err := s.client.Prices(context.Background(), &stypes.QueryPricesRequest{})

// expect deadline exceeded error
s.Require().Equal(err.Error(), status.FromContextError(context.DeadlineExceeded).Err().Error())
s.Require().Error(err)
}

func (s *ServerTestSuite) TestOracleServerPrices() {
Expand Down Expand Up @@ -157,7 +174,7 @@ func (s *ServerTestSuite) TestOracleServerPrices() {
s.Require().Equal(resp.Timestamp, ts.UTC())

// call from http client
httpResp, err := s.httpClient.Get(fmt.Sprintf("http://%s:%s/connect/oracle/v2/prices", localhost, port))
httpResp, err := s.httpClient.Get(fmt.Sprintf("http://%s:%s/connect/oracle/v2/prices", localhost, s.port))
s.Require().NoError(err)

// check response
Expand Down

0 comments on commit 490be7a

Please sign in to comment.