From fd4047af6324bcf92ee7a85025fe2a202e5b4a1b Mon Sep 17 00:00:00 2001 From: Daniel Adam Date: Tue, 6 Aug 2024 17:06:12 +0200 Subject: [PATCH] Add TestParallelHandshake test case Handshake are synchronized by a mutex, but there is a panicing scenario when the first handshake fails. --- conn_test.go | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) diff --git a/conn_test.go b/conn_test.go index 283021ff..4d69a874 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3614,3 +3614,100 @@ func TestConnectionState(t *testing.T) { t.Fatal("ConnectionState should not be nil") } } + +func TestParallelHandshake(t *testing.T) { + report := test.CheckRoutines(t) + defer report() + + // Limit runtime in case of deadlocks + lim := test.TimeOut(time.Second * 10) + defer lim.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + errChan := make(chan error, 2) + ca, cb := dpipe.Pipe() + serverCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{serverCert}, + InsecureSkipVerify: true, + }) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = server.Close() + }() + go func() { + // Read from the server and always answer with "world" + for { + data := make([]byte, 8192) + _, errServer := server.Read(data) //nolint:contextcheck + if errServer != nil { + errChan <- errServer + return + } + if _, errServer := server.Write([]byte("world")); errServer != nil { //nolint:contextcheck + errChan <- errServer + return + } + } + }() + + clientCert, err := selfsign.GenerateSelfSigned() + if err != nil { + t.Fatal(err) + } + client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ + Certificates: []tls.Certificate{clientCert}, + // InsecureSkipVerify: true, + }) + if err != nil { + t.Fatal(err) + } + + clientReadChan := make(chan []byte, 1000) + go func() { + for { + data := make([]byte, 8192) + n, err := client.Read(data) + if err != nil { + errChan <- err + return + } + clientReadChan <- data[:n] + } + }() + + if _, err := client.Write([]byte("hello")); err != nil { + t.Fatal(err) + } + + select { + case data := <-clientReadChan: + if string(data) != "world" { + t.Fatalf("expected 'world', got '%s'", string(data)) + } + _ = client.Close() + case err := <-errChan: + t.Fatal(err) + case <-ctx.Done(): + t.Fatal("timeout") + } + + // wait for both server and client to close + for i := 0; i < 2; i++ { + select { + case err := <-errChan: + if !errors.Is(err, io.EOF) { + t.Fatal(err) + } + case <-ctx.Done(): + t.Fatal("timeout") + } + } +}