From 278b6f11a0d2154baba563f78717d50e54fc3e67 Mon Sep 17 00:00:00 2001 From: Tim Windelschmidt Date: Tue, 13 Aug 2024 18:20:59 +0200 Subject: [PATCH] replace r.Read with io.ReadFull, fix empty stringKeys --- peers/e2e_test.go | 66 ++++++++++++++++++++++++++++++++------ peers/protocol.go | 4 +-- peers/sticktable/values.go | 4 ++- 3 files changed, 62 insertions(+), 12 deletions(-) diff --git a/peers/e2e_test.go b/peers/e2e_test.go index f8f2f02..464eb98 100644 --- a/peers/e2e_test.go +++ b/peers/e2e_test.go @@ -4,9 +4,12 @@ package peers import ( "context" + "errors" "fmt" "log" + "net" "net/http" + "sync" "testing" "time" @@ -15,25 +18,28 @@ import ( ) func TestE2E(t *testing.T) { - success := make(chan bool) - a := Peer{Handler: HandlerFunc(func(_ context.Context, u *sticktable.EntryUpdate) { - log.Println(u) - success <- true - })} + a := Peer{ + Handler: HandlerFunc(func(_ context.Context, _ *sticktable.EntryUpdate) {}), + } // create the listener synchronously to prevent a race l := testutil.TCPListener(t) // ignore errors as the listener will be closed by t.Cleanup - go a.Serve(l) + go func() { + err := a.Serve(l) + if err != nil && !errors.Is(err, net.ErrClosed) { + t.Error(err) + } + }() cfg := testutil.HAProxyConfig{ FrontendPort: fmt.Sprintf("%d", testutil.TCPPort(t)), CustomFrontendConfig: ` http-request track-sc0 src table st_src_global - http-request track-sc2 req.hdr(Host) table st_be_name + http-request track-sc1 req.hdr(Host) table st_host_name `, CustomConfig: ` -backend st_be_name +backend st_host_name stick-table type string size 1m expire 10m store http_req_rate(10s) peers mypeers backend st_src_global @@ -42,7 +48,12 @@ backend st_src_global PeerAddr: l.Addr().String(), } - t.Run("foo", testutil.WithHAProxy(cfg, func(t *testing.T) { + success := make(chan bool) + a.Handler = HandlerFunc(func(_ context.Context, u *sticktable.EntryUpdate) { + log.Println(u) + success <- true + }) + t.Run("initial connect", testutil.WithHAProxy(cfg, func(t *testing.T) { for i := 0; i < 10; i++ { _, err := http.Get("http://127.0.0.1:" + cfg.FrontendPort) if err != nil { @@ -61,4 +72,41 @@ backend st_src_global t.Error("timeout") } })) + + var rw sync.RWMutex + a.Handler = HandlerFunc(func(_ context.Context, u *sticktable.EntryUpdate) { + rw.RLock() + }) + t.Run("big table", testutil.WithHAProxy(cfg, func(t *testing.T) { + // By using a RWMutex we can use the read lock, to prevent any execution of the handler. + // This isn't pretty but gets the job done. + rw.Lock() + + for i := 0; i < 1_000; i++ { + // Fill table with 1k entries + req, _ := http.NewRequest("GET", "http://127.0.0.1:"+cfg.FrontendPort, http.NoBody) + req.Host = fmt.Sprintf("host-%d", i) + _, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + } + + a.Handler = HandlerFunc(func(_ context.Context, u *sticktable.EntryUpdate) { + success <- true + }) + rw.Unlock() + + tm := time.NewTimer(3 * time.Second) + defer tm.Stop() + select { + case v := <-success: + if !v { + t.Fail() + } + case <-tm.C: + t.Error("timeout") + } + })) + } diff --git a/peers/protocol.go b/peers/protocol.go index 22e69bf..53edfc7 100644 --- a/peers/protocol.go +++ b/peers/protocol.go @@ -161,7 +161,7 @@ type rawMessage struct { func (m *rawMessage) ReadFrom(r byteReader) (int64, error) { // All the messages are made at least of a two bytes length header. header := make([]byte, 2) - n, err := r.Read(header) + n, err := io.ReadFull(r, header) if err != nil { return int64(n), err } @@ -178,7 +178,7 @@ func (m *rawMessage) ReadFrom(r byteReader) (int64, error) { } m.Data = make([]byte, dataLength) - readData, err = r.Read(m.Data) + readData, err = io.ReadFull(r, m.Data) if err != nil { return int64(n + readData), fmt.Errorf("failed reading message data: %v", err) } diff --git a/peers/sticktable/values.go b/peers/sticktable/values.go index 7379562..22a7ac9 100644 --- a/peers/sticktable/values.go +++ b/peers/sticktable/values.go @@ -62,7 +62,9 @@ func (v *StringKey) Unmarshal(b []byte, keySize int64) (int, error) { if err != nil { return n, err } - + if valueLength == 0 { + return n, nil + } *v = StringKey(b[n:valueLength]) return n + int(valueLength), nil }