diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 8f912c4138..66704389b0 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -9,6 +9,7 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/pem" "errors" "fmt" "io" @@ -307,6 +308,84 @@ func TestWebsocketTransport(t *testing.T) { ttransport.SubtestTransport(t, ta, tb, "/ip4/127.0.0.1/tcp/0/ws", peerA) } +func TestWSSTransport(t *testing.T) { + peerA, ua := newUpgrader(t) + + const dnsName = "example.com" + // Generate the self-signed certificate and private key + certPEM, privPEM, err := generateSelfSignedCert(dnsName) + if err != nil { + t.Fatalf("Failed to generate self-signed certificate: %v", err) + } + + // Load the certificate and key into tls.Certificate + cert, err := tls.X509KeyPair(certPEM, privPEM) + if err != nil { + t.Fatalf("Failed to load key pair: %v", err) + } + + // Create a TLS configuration with the certificate + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + } + + ta, err := New(ua, nil, WithTLSConfig(tlsConfig)) + if err != nil { + t.Fatal(err) + } + + _, ub := newUpgrader(t) + cas := x509.NewCertPool() + cas.AppendCertsFromPEM(certPEM) + + tb, err := New(ub, nil, WithTLSClientConfig(&tls.Config{RootCAs: cas})) + if err != nil { + t.Fatal(err) + } + + // Note: the /wss form is not tested as it would require setting up custom DNS resolution + ttransport.SubtestTransport(t, ta, tb, fmt.Sprintf("/ip4/127.0.0.1/tcp/0/tls/sni/%s/ws", dnsName), peerA) +} + +func generateSelfSignedCert(dnsName string) ([]byte, []byte, error) { + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, nil, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"My Organization"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), // Valid for 1 year + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{dnsName}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, nil, err + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + privDER, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, nil, err + } + privPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: privDER}) + + return certPEM, privPEM, nil +} + func isWSS(addr ma.Multiaddr) bool { if _, err := addr.ValueForProtocol(ma.P_WSS); err == nil { return true