diff --git a/vnet_test.go b/vnet_test.go index 58d88eb6..bdf79946 100644 --- a/vnet_test.go +++ b/vnet_test.go @@ -1,7 +1,6 @@ package sctp import ( - "bytes" "fmt" "math/rand" "net" @@ -387,97 +386,84 @@ func TestRwndFull(t *testing.T) { } func TestStreamClose(t *testing.T) { - loopBackTest := func(t *testing.T, dropReconfigChunk bool) { - lim := test.TimeOut(time.Second * 10) - defer lim.Stop() + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() - loggerFactory := logging.NewDefaultLoggerFactory() - log := loggerFactory.NewLogger("test") + loggerFactory := logging.NewDefaultLoggerFactory() + log := loggerFactory.NewLogger("test") - venv, err := buildVNetEnv(&vNetEnvConfig{ - loggerFactory: loggerFactory, - log: log, - }) + venv, err := buildVNetEnv(&vNetEnvConfig{ + loggerFactory: loggerFactory, + log: log, + }) + if !assert.NoError(t, err, "should succeed") { + return + } + if !assert.NotNil(t, venv, "should not be nil") { + return + } + defer venv.wan.Stop() // nolint:errcheck + + serverStreamReady := make(chan struct{}) + clientStreamReady := make(chan struct{}) + clientStartClose := make(chan struct{}) + serverStreamClosed := make(chan struct{}) + shutDownClient := make(chan struct{}) + clientShutDown := make(chan struct{}) + serverShutDown := make(chan struct{}) + + go func() { + defer close(serverShutDown) + // connected UDP conn for server + conn, err := venv.net0.DialUDP("udp4", + &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, + &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, + ) if !assert.NoError(t, err, "should succeed") { return } - if !assert.NotNil(t, venv, "should not be nil") { + defer conn.Close() // nolint:errcheck + + // server association + assoc, err := Server(Config{ + NetConn: conn, + LoggerFactory: loggerFactory, + }) + if !assert.NoError(t, err, "should succeed") { return } - defer venv.wan.Stop() // nolint:errcheck - - clientShutDown := make(chan struct{}) - serverShutDown := make(chan struct{}) + defer assoc.Close() // nolint:errcheck - const numMessages = 10 - const messageSize = 1024 - var messages [][]byte - var numServerReceived int - var numClientReceived int + log.Info("server handshake complete") - for i := 0; i < numMessages; i++ { - bytes := make([]byte, messageSize) - messages = append(messages, bytes) + stream, err := assoc.AcceptStream() + if !assert.NoError(t, err, "should succeed") { + return } + defer stream.Close() // nolint:errcheck - go func() { - defer close(serverShutDown) - // connected UDP conn for server - conn, innerErr := venv.net0.DialUDP("udp4", - &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, - &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, - ) - if !assert.NoError(t, innerErr, "should succeed") { - return - } - defer conn.Close() // nolint:errcheck - - // server association - assoc, innerErr := Server(Config{ - NetConn: conn, - LoggerFactory: loggerFactory, - }) - if !assert.NoError(t, innerErr, "should succeed") { - return + buf := make([]byte, 1500) + for { + n, err := stream.Read(buf) + if err != nil { + t.Logf("server: Read returned %v", err) + break } - defer assoc.Close() // nolint:errcheck - - log.Info("server handshake complete") - stream, innerErr := assoc.AcceptStream() - if !assert.NoError(t, innerErr, "should succeed") { - return + if !assert.Equal(t, "HELLO", string(buf[:n]), "should receive HELLO") { + continue } - log.Info("stream accepted") - assert.Equal(t, StreamStateOpen, stream.State()) - buf := make([]byte, 1500) - for { - n, errRead := stream.Read(buf) - if errRead != nil { - log.Infof("server: Read returned %v", errRead) - assert.Equal(t, StreamStateClosing, stream.State(), "should be closing") - _ = stream.Close() // nolint:errcheck - assert.Equal(t, StreamStateClosed, stream.State()) - break - } - - log.Infof("server: received %d bytes (%d)", n, numServerReceived) - assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numServerReceived]), "should receive HELLO") - - _, err2 := stream.Write(buf[:n]) - if err2 != nil { - assert.Equal(t, StreamStateClosing, stream.State(), "should be closing") - assert.Equal(t, err2, errStreamClosed, "should stop writing when closing ") - assert.Equal(t, StreamStateClosed, stream.State()) - } - numServerReceived++ - } - // don't close association until the client's stream routine is complete - <-clientShutDown + log.Info("server stream ready") + close(serverStreamReady) + } - }() + close(serverStreamClosed) + log.Info("server closing") + }() + go func() { + defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, @@ -486,7 +472,6 @@ func TestStreamClose(t *testing.T) { if !assert.NoError(t, err, "should succeed") { return } - defer conn.Close() // nolint:errcheck // client association assoc, err := Client(Config{ @@ -504,52 +489,45 @@ func TestStreamClose(t *testing.T) { if !assert.NoError(t, err, "should succeed") { return } - assert.Equal(t, StreamStateOpen, stream.State()) + stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0) - // begin client read-loop + // Send a message to let server side stream to open + _, err = stream.Write([]byte("HELLO")) + if !assert.NoError(t, err, "should succeed") { + return + } + buf := make([]byte, 1500) + done := make(chan struct{}) go func() { - defer close(clientShutDown) for { - n, err2 := stream.Read(buf) + log.Info("client read") + _, err2 := stream.Read(buf) if err2 != nil { - log.Infof("client: Read returned %v", err2) - assert.Equal(t, StreamStateClosed, stream.State()) + t.Logf("client: Read returned %v", err2) break } - - log.Infof("client: received %d bytes (%d)", n, numClientReceived) - assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numClientReceived]), "should receive HELLO") - numClientReceived++ } + close(done) }() - // Send messages to the server - for i := 0; i < numMessages; i++ { - _, err = stream.Write(messages[i]) - assert.NoError(t, err, "should succeed") - } + log.Info("client stream ready") + close(clientStreamReady) - if dropReconfigChunk { - venv.dropNextReconfigChunk(1) - } + <-clientStartClose - // Wait server accept stream. - time.Sleep(time.Millisecond * 100) + // drop next 1 RECONFIG chunk + venv.dropNextReconfigChunk(1) err = stream.Close() assert.NoError(t, err, "should succeed") - assert.Equal(t, StreamStateClosed, stream.State()) - log.Info("wait server closed..") - <-serverShutDown - assert.LessOrEqual(t, numServerReceived, numMessages, "messages could be lost") - assert.LessOrEqual(t, numClientReceived, numMessages, "messages could be lost") + log.Info("client wait for exit reading..") + <-done - _, err = stream.Write([]byte{1}) + <-shutDownClient - assert.Equal(t, err, errStreamClosed, "after closed should not allow write") // Check if RECONFIG was actually dropped assert.Equal(t, 0, venv.numToDropReconfig, "should be zero") @@ -561,15 +539,26 @@ func TestStreamClose(t *testing.T) { pendingReconfigs := len(assoc.reconfigs) assoc.lock.RUnlock() assert.Equal(t, 0, pendingReconfigs, "should be zero") - } - t.Run("without dropping Reconfig", func(t *testing.T) { - loopBackTest(t, false) - }) + log.Info("client closing") + }() - t.Run("with dropping Reconfig", func(t *testing.T) { - loopBackTest(t, true) - }) + // wait until both establish a stream + <-clientStreamReady + <-serverStreamReady + + log.Info("stream ready") + + // let client begin writing + log.Info("client start closing") + close(clientStartClose) + + <-serverStreamClosed + close(shutDownClient) + + <-clientShutDown + <-serverShutDown + log.Info("all done") } // this test case reproduces the issue mentioned in