diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 4984419dce..92e5eb3d21 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -31,8 +31,9 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" - "github.com/libp2p/go-libp2p/p2p/transport/tcp" + libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" + "github.com/libp2p/go-libp2p/p2p/transport/tcp" "go.uber.org/mock/gomock" ma "github.com/multiformats/go-multiaddr" @@ -156,7 +157,6 @@ var transportsToTest = []TransportTestCase{ Name: "WebSocket-Shared", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.ShareTCPListener()) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { @@ -168,13 +168,13 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "WebSocket", + Name: "QUIC", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/tcp/0/ws")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) @@ -182,13 +182,13 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "QUIC", + Name: "WebTransport", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) @@ -196,13 +196,14 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "WebTransport", + Name: "WebRTC", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pwebrtc.New)) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/quic-v1/webtransport")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/webrtc-direct")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) @@ -210,14 +211,14 @@ var transportsToTest = []TransportTestCase{ }, }, { - Name: "WebRTC", + Name: "Memory", HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { libp2pOpts := transformOpts(opts) - libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pwebrtc.New)) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pmemory.NewTransport)) if opts.NoListen { libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) } else { - libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/ip4/127.0.0.1/udp/0/webrtc-direct")) + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/memory/1234")) } h, err := libp2p.New(libp2pOpts...) require.NoError(t, err) diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go new file mode 100644 index 0000000000..2f06fd6d6c --- /dev/null +++ b/p2p/transport/memory/conn.go @@ -0,0 +1,146 @@ +package memory + +import ( + "context" + "log" + "sync" + "sync/atomic" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +type conn struct { + id int64 + rconn *conn + + scope network.ConnManagementScope + listener *listener + transport *transport + + localPeer peer.ID + localMultiaddr ma.Multiaddr + + remotePeerID peer.ID + remotePubKey ic.PubKey + remoteMultiaddr ma.Multiaddr + + mu sync.Mutex + + closed atomic.Bool + closeOnce sync.Once + + streamC chan *stream + streams map[int64]network.MuxedStream +} + +var _ tpt.CapableConn = &conn{} + +func newConnection( + t *transport, + s *stream, + localPeer peer.ID, + localMultiaddr ma.Multiaddr, + remotePubKey ic.PubKey, + remotePeer peer.ID, + remoteMultiaddr ma.Multiaddr, +) *conn { + c := &conn{ + id: connCounter.Add(1), + transport: t, + localPeer: localPeer, + localMultiaddr: localMultiaddr, + remotePubKey: remotePubKey, + remotePeerID: remotePeer, + remoteMultiaddr: remoteMultiaddr, + streamC: make(chan *stream, 1), + streams: make(map[int64]network.MuxedStream), + } + + c.addStream(s.id, s) + return c +} + +func (c *conn) Close() error { + c.closeOnce.Do(func() { + c.closed.Store(true) + go c.rconn.Close() + c.teardown() + }) + + return nil +} + +func (c *conn) IsClosed() bool { + return c.closed.Load() +} + +func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { + sl, sr := newStreamPair() + sl.conn = c + c.addStream(sl.id, sl) + log.Println("opening stream", sl.id, sr.id) + + c.rconn.streamC <- sr + return sl, nil +} + +func (c *conn) AcceptStream() (network.MuxedStream, error) { + in := <-c.streamC + in.conn = c + c.addStream(in.id, in) + return in, nil +} + +func (c *conn) LocalPeer() peer.ID { return c.localPeer } + +// RemotePeer returns the peer ID of the remote peer. +func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } + +// RemotePublicKey returns the public pkey of the remote peer. +func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } + +// LocalMultiaddr returns the local Multiaddr associated +func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } + +// RemoteMultiaddr returns the remote Multiaddr associated +func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } + +func (c *conn) Transport() tpt.Transport { + return c.transport +} + +func (c *conn) Scope() network.ConnScope { + return c.scope +} + +// ConnState is the state of security connection. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{Transport: "memory"} +} + +func (c *conn) addStream(id int64, stream network.MuxedStream) { + c.mu.Lock() + defer c.mu.Unlock() + + c.streams[id] = stream +} + +func (c *conn) removeStream(id int64) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.streams, id) +} + +func (c *conn) teardown() { + for id, s := range c.streams { + log.Println("tearing down stream", id) + s.Reset() + } + + // TODO: remove self from listener +} diff --git a/p2p/transport/memory/conn_test.go b/p2p/transport/memory/conn_test.go new file mode 100644 index 0000000000..05af74b9e7 --- /dev/null +++ b/p2p/transport/memory/conn_test.go @@ -0,0 +1 @@ +package memory diff --git a/p2p/transport/memory/hub.go b/p2p/transport/memory/hub.go new file mode 100644 index 0000000000..55b85ccbad --- /dev/null +++ b/p2p/transport/memory/hub.go @@ -0,0 +1,80 @@ +package memory + +import ( + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" + "sync" + "sync/atomic" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) + memhub = newHub() +) + +type hub struct { + mu sync.RWMutex + closeOnce sync.Once + pubKeys map[peer.ID]ic.PubKey + listeners map[string]*listener +} + +func newHub() *hub { + return &hub{ + pubKeys: make(map[peer.ID]ic.PubKey), + listeners: make(map[string]*listener), + } +} + +func (h *hub) addListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + h.listeners[addr] = l +} + +func (h *hub) removeListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.listeners, addr) +} + +func (h *hub) getListener(addr string) (*listener, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + l, ok := h.listeners[addr] + return l, ok +} + +func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { + h.mu.Lock() + defer h.mu.Unlock() + + h.pubKeys[p] = pk +} + +func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + pk, ok := h.pubKeys[p] + return pk, ok +} + +func (h *hub) close() { + h.closeOnce.Do(func() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, l := range h.listeners { + l.Close() + } + }) +} diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go new file mode 100644 index 0000000000..81d5e484d0 --- /dev/null +++ b/p2p/transport/memory/listener.go @@ -0,0 +1,75 @@ +package memory + +import ( + "context" + "net" + "sync" + + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +const ( + listenerQueueSize = 16 +) + +type listener struct { + id int64 + + t *transport + ctx context.Context + cancel context.CancelFunc + laddr ma.Multiaddr + + mu sync.Mutex + connCh chan *conn + connections map[int64]*conn +} + +func (l *listener) Multiaddr() ma.Multiaddr { + return l.laddr +} + +func newListener(t *transport, laddr ma.Multiaddr) *listener { + ctx, cancel := context.WithCancel(context.Background()) + return &listener{ + id: listenerCounter.Add(1), + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), + connections: make(map[int64]*conn), + } +} + +// Accept accepts new connections. +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case <-l.ctx.Done(): + return nil, tpt.ErrListenerClosed + case c, ok := <-l.connCh: + if !ok { + return nil, tpt.ErrListenerClosed + } + + l.mu.Lock() + defer l.mu.Unlock() + + c.listener = l + c.transport = l.t + l.connections[c.id] = c + return c, nil + } +} + +// Close closes the listener. +func (l *listener) Close() error { + l.cancel() + return nil +} + +// Addr returns the address of this listener. +func (l *listener) Addr() net.Addr { + return nil +} diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go new file mode 100644 index 0000000000..3ff5cf9c63 --- /dev/null +++ b/p2p/transport/memory/stream.go @@ -0,0 +1,278 @@ +package memory + +import ( + "bytes" + "errors" + "io" + "os" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/network" +) + +// onceError is an object that will only store an error once. +type onceError struct { + sync.Mutex // guards following + err error +} + +func (a *onceError) Store(err error) { + a.Lock() + defer a.Unlock() + if a.err != nil { + return + } + a.err = err +} +func (a *onceError) Load() error { + a.Lock() + defer a.Unlock() + return a.err +} + +// stream implements network.Stream +type stream struct { + id int64 + conn *conn + + wrMu sync.Mutex // Serialize Write operations + buf *bytes.Buffer // Buffer for partial reads + + // Used by local Read to interact with remote Write. + rdRx <-chan []byte + + // Used by local Write to interact with remote Read. + wrTx chan<- []byte + + closeOnce sync.Once // Protects closing localDone + localDone chan struct{} + remoteDone <-chan struct{} + + resetOnce sync.Once // Protects closing localReset + localReset chan struct{} + remoteReset <-chan struct{} + + mu sync.RWMutex + readDeadline time.Time + writeDeadline time.Time + + rerr onceError + werr onceError +} + +var ErrClosed = errors.New("stream closed") + +func newStreamPair() (*stream, *stream) { + cb1 := make(chan []byte, 1) + cb2 := make(chan []byte, 1) + + done1 := make(chan struct{}) + done2 := make(chan struct{}) + + reset1 := make(chan struct{}) + reset2 := make(chan struct{}) + + sa := newStream(cb1, cb2, done1, done2, reset1, reset2) + sb := newStream(cb2, cb1, done2, done1, reset2, reset1) + + return sa, sb +} + +func newStream(rdRx <-chan []byte, wrTx chan<- []byte, localDone chan struct{}, remoteDone <-chan struct{}, localReset chan struct{}, remoteReset <-chan struct{}) *stream { + s := &stream{ + id: streamCounter.Add(1), + rdRx: rdRx, + wrTx: wrTx, + buf: new(bytes.Buffer), + localDone: localDone, + remoteDone: remoteDone, + localReset: localReset, + remoteReset: remoteReset, + } + + return s +} + +func (p *stream) Write(b []byte) (int, error) { + if err := p.werr.Load(); err != nil { + return 0, err + } + + return p.write(b) + //if err != nil && err != io.ErrClosedPipe && err != network.ErrReset { + // err = &net.OpError{Op: "write", Net: "pipe", Err: err} + //} + // + //return n, err +} + +func (p *stream) write(b []byte) (n int, err error) { + switch { + case isClosedChan(p.remoteReset): + return 0, network.ErrReset + case isClosedChan(p.remoteDone): + return 0, io.ErrClosedPipe + } + + p.mu.RLock() + writeDeadline := p.writeDeadline + p.mu.RUnlock() + + if !writeDeadline.IsZero() && time.Now().After(writeDeadline) { + return 0, os.ErrDeadlineExceeded + } + var ( + writeDeadlineTimer *time.Timer + writeDeadlineChan <-chan time.Time + ) + defer func() { + if writeDeadlineTimer != nil { + writeDeadlineTimer.Stop() + } + }() + + if !writeDeadline.IsZero() { + writeDeadlineTimer = time.NewTimer(time.Until(writeDeadline)) + writeDeadlineChan = writeDeadlineTimer.C + } + + p.wrMu.Lock() // Ensure entirety of b is written together + defer p.wrMu.Unlock() + + select { + case <-writeDeadlineChan: + err = os.ErrDeadlineExceeded + case p.wrTx <- b: + n += len(b) + } + + return n, err +} + +func (p *stream) Read(b []byte) (int, error) { + if err := p.rerr.Load(); err != nil { + return 0, err + } + + return p.read(b) + //if err != nil && err != io.EOF && err != io.ErrClosedPipe && err != network.ErrReset { + // err = &net.OpError{Op: "read", Net: "pipe", Err: err} + //} + // + //return n, err +} + +func (p *stream) read(b []byte) (n int, err error) { + var readErr error + + switch { + case isClosedChan(p.remoteReset): + err = network.ErrReset + case isClosedChan(p.remoteDone): + err = io.EOF + } + + p.mu.RLock() + readDeadline := p.readDeadline + p.mu.RUnlock() + + if !readDeadline.IsZero() && time.Now().After(readDeadline) { + return 0, os.ErrDeadlineExceeded + } + + var ( + readDeadlineTimer *time.Timer + readDeadlineChan <-chan time.Time + ) + defer func() { + if readDeadlineTimer != nil { + readDeadlineTimer.Stop() + } + }() + + if !readDeadline.IsZero() { + readDeadlineTimer = time.NewTimer(time.Until(readDeadline)) + readDeadlineChan = readDeadlineTimer.C + } + + select { + case <-readDeadlineChan: + err = os.ErrDeadlineExceeded + case bw, ok := <-p.rdRx: + if !ok { + err = io.EOF + p.rerr.Store(err) + return + } + + p.buf.Write(bw) + default: + } + + n, readErr = p.buf.Read(b) + if err == nil { + err = readErr + } + + return n, err +} + +func (s *stream) CloseWrite() error { + s.werr.Store(ErrClosed) + return nil +} + +func (s *stream) CloseRead() error { + s.rerr.Store(ErrClosed) + return nil +} + +func (s *stream) Close() error { + s.closeOnce.Do(func() { + close(s.localDone) + }) + + _ = s.CloseRead() + return s.CloseWrite() +} + +func (s *stream) Reset() error { + s.rerr.Store(network.ErrReset) + s.werr.Store(network.ErrReset) + + s.resetOnce.Do(func() { + close(s.localReset) + }) + + // No meaningful error case here. + return nil +} + +func (s *stream) SetDeadline(t time.Time) error { + _ = s.SetReadDeadline(t) + return s.SetWriteDeadline(t) +} + +func (s *stream) SetReadDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.readDeadline = t + return nil +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + s.writeDeadline = t + return nil +} + +func isClosedChan(c <-chan struct{}) bool { + select { + case <-c: + return true + default: + return false + } +} diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go new file mode 100644 index 0000000000..09e8425993 --- /dev/null +++ b/p2p/transport/memory/stream_test.go @@ -0,0 +1,118 @@ +package memory + +import ( + "errors" + "github.com/libp2p/go-libp2p/core/network" + "github.com/stretchr/testify/require" + "io" + "testing" + "time" +) + +func TestStreamSimpleReadWriteClose(t *testing.T) { + t.Parallel() + streamLocal, streamRemote := newStreamPair() + + // send a foobar from the client + n, err := streamLocal.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) + require.NoError(t, streamLocal.CloseWrite()) + + // writing after closing should error + _, err = streamLocal.Write([]byte("foobar")) + require.Error(t, err) + + // now read all the data on the server side + b, err := io.ReadAll(streamRemote) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b) + + // reading again should give another io.EOF + n, err = streamRemote.Read(make([]byte, 10)) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + + // send something back + _, err = streamRemote.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, streamRemote.CloseWrite()) + + // and read it at the client + b, err = io.ReadAll(streamLocal) + require.NoError(t, err) + require.Equal(t, []byte("lorem ipsum"), b) + + // stream is only cleaned up on calling Close or Reset + require.NoError(t, streamLocal.Close()) + require.NoError(t, streamRemote.Close()) +} + +func TestStreamPartialReads(t *testing.T) { + t.Parallel() + streamLocal, streamRemote := newStreamPair() + + _, err := streamRemote.Write([]byte("foobar")) + require.NoError(t, err) + require.NoError(t, streamRemote.CloseWrite()) + + n, err := streamLocal.Read([]byte{}) // empty read + require.NoError(t, err) + require.Zero(t, n) + b := make([]byte, 3) + n, err = streamLocal.Read(b) + require.Equal(t, 3, n) + require.NoError(t, err) + require.Equal(t, []byte("foo"), b) + b, err = io.ReadAll(streamLocal) + require.NoError(t, err) + require.Equal(t, []byte("bar"), b) +} + +func TestStreamResets(t *testing.T) { + clientStr, serverStr := newStreamPair() + + // send a foobar from the client + _, err := clientStr.Write([]byte("foobar")) + require.NoError(t, err) + _, err = serverStr.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, clientStr.Reset()) // resetting resets both directions + // attempting to write more data should result in a reset error + _, err = clientStr.Write([]byte("foobar")) + require.ErrorIs(t, err, network.ErrReset) + // read what the server sent + b, err := io.ReadAll(clientStr) + require.Empty(t, b) + require.ErrorIs(t, err, network.ErrReset) + + // read the data on the server side + b, err = io.ReadAll(serverStr) + require.Equal(t, []byte("foobar"), b) + require.ErrorIs(t, err, network.ErrReset) + require.Eventually(t, func() bool { + _, err := serverStr.Write([]byte("foobar")) + return errors.Is(err, network.ErrReset) + }, time.Second, 50*time.Millisecond) + serverStr.Close() +} + +func TestStreamReadAfterClose(t *testing.T) { + clientStr, serverStr := newStreamPair() + + serverStr.Close() + b := make([]byte, 1) + _, err := clientStr.Read(b) + require.Equal(t, io.EOF, err) + _, err = clientStr.Read(nil) + require.Equal(t, io.EOF, err) + + clientStr, serverStr = newStreamPair() + + serverStr.Reset() + b = make([]byte, 1) + _, err = clientStr.Read(b) + require.ErrorIs(t, err, network.ErrReset) + _, err = clientStr.Read(nil) + require.ErrorIs(t, err, network.ErrReset) +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go new file mode 100644 index 0000000000..39d152e7d3 --- /dev/null +++ b/p2p/transport/memory/transport.go @@ -0,0 +1,144 @@ +package memory + +import ( + "context" + "errors" + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/pnet" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + "sync" +) + +type transport struct { + psk pnet.PSK + rcmgr network.ResourceManager + localPeerID peer.ID + localPrivKey ic.PrivKey + localPubKey ic.PubKey + + mu sync.RWMutex + + connections map[int64]*conn +} + +func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + + id, err := peer.IDFromPrivateKey(privKey) + if err != nil { + return nil, err + } + + memhub.addPubKey(id, privKey.GetPublic()) + return &transport{ + psk: psk, + rcmgr: rcmgr, + localPeerID: id, + localPrivKey: privKey, + localPubKey: privKey.GetPublic(), + connections: make(map[int64]*conn), + }, nil +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + return nil, err + } + + c, err := t.dialWithScope(ctx, raddr, p, scope) + if err != nil { + return nil, err + } + + return c, nil +} + +func (t *transport) dialWithScope(_ context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + if err := scope.SetPeer(rpid); err != nil { + return nil, err + } + + rl, ok := memhub.getListener(raddr.String()) + if !ok { + return nil, errors.New("failed to get listener") + } + + remotePubKey, ok := memhub.getPubKey(rpid) + if !ok { + return nil, errors.New("failed to get remote public key") + } + + lc, rc := t.newConnPair(remotePubKey, rpid, raddr) + + rl.connCh <- rc + return lc, nil +} + +func (t *transport) CanDial(addr ma.Multiaddr) bool { + return dialMatcher.Matches(addr) +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + // TODO: Check if we need to add scope via conn mngr + l := newListener(t, laddr) + memhub.addListener(laddr.String(), l) + + return l, nil +} + +func (t *transport) Proxy() bool { + return false +} + +// Protocols returns the set of protocols handled by this transport. +func (t *transport) Protocols() []int { + return []int{ma.P_MEMORY} +} + +func (t *transport) String() string { + return "MemoryTransport" +} + +func (t *transport) Close() error { + // TODO: Go trough all listeners and close them + t.mu.Lock() + defer t.mu.Unlock() + + for _, c := range t.connections { + c.Close() + //delete(t.connections, c.id) + } + + return nil +} + +func (t *transport) addConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + t.connections[c.id] = c +} + +func (t *transport) removeConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + delete(t.connections, c.id) +} + +func (t *transport) newConnPair(remotePubKey ic.PubKey, rpid peer.ID, raddr ma.Multiaddr) (*conn, *conn) { + sl, sr := newStreamPair() + + lc := newConnection(t, sl, t.localPeerID, nil, remotePubKey, rpid, raddr) + rc := newConnection(nil, sr, rpid, raddr, t.localPubKey, t.localPeerID, nil) + + lc.rconn = rc + rc.rconn = lc + return lc, rc +} diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go new file mode 100644 index 0000000000..437576b30d --- /dev/null +++ b/p2p/transport/memory/transport_test.go @@ -0,0 +1,139 @@ +package memory + +import ( + "context" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "io" + "testing" + + tpt "github.com/libp2p/go-libp2p/core/transport" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func getTransport(t *testing.T) (tpt.Transport, peer.ID) { + t.Helper() + privKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519, -1) + require.NoError(t, err) + rcmgr := &network.NullResourceManager{} + require.NoError(t, err) + tr, err := NewTransport(privKey, nil, rcmgr) + require.NoError(t, err) + peerID, err := peer.IDFromPrivateKey(privKey) + require.NoError(t, err) + t.Cleanup(func() { rcmgr.Close() }) + return tr, peerID +} + +func TestMemoryProtocol(t *testing.T) { + t.Parallel() + tr, _ := getTransport(t) + defer tr.(io.Closer).Close() + + protocols := tr.Protocols() + if len(protocols) > 1 { + t.Fatalf("expected at most one protocol, got %v", protocols) + } + + if protocols[0] != ma.P_MEMORY { + t.Fatalf("expected the supported protocol to be memory, got %d", protocols[0]) + } +} + +func TestCanDial(t *testing.T) { + t.Parallel() + tr, _ := getTransport(t) + defer tr.(io.Closer).Close() + + invalid := []string{ + "/ip4/127.0.0.1/udp/1234", + "/ip4/5.5.5.5/tcp/1234", + "/dns/google.com/udp/443/quic-v1", + "/ip4/127.0.0.1/udp/1234/quic", + } + valid := []string{ + "/memory/1234", + "/memory/1337123", + } + for _, s := range invalid { + invalidAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if tr.CanDial(invalidAddr) { + t.Errorf("didn't expect to be able to dial a non-memory address (%s)", invalidAddr) + } + } + for _, s := range valid { + validAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if !tr.CanDial(validAddr) { + t.Errorf("expected to be able to dial memory address (%s)", validAddr) + } + } +} + +func TestTransport_Listen(t *testing.T) { + t.Parallel() + server, _ := getTransport(t) + defer server.(io.Closer).Close() + + addr, err := ma.NewMultiaddr("/memory/1234") + require.NoError(t, err) + serverListener, err := server.Listen(addr) + require.NoError(t, err) + defer serverListener.Close() + lma := serverListener.Multiaddr() + require.Equal(t, addr, lma) +} + +func TestTransport_Dial(t *testing.T) { + t.Parallel() + server, serverPeerID := getTransport(t) + client, clientPeerID := getTransport(t) + defer func() { + if server != nil { + err := server.(io.Closer).Close() + require.NoError(t, err) + } + }() + + defer func() { + if client != nil { + err := client.(io.Closer).Close() + require.NoError(t, err) + } + }() + + serverAddr, err := ma.NewMultiaddr("/memory/1234") + require.NoError(t, err) + serverListener, err := server.Listen(serverAddr) + require.NoError(t, err) + defer func() { + if serverListener != nil { + err = serverListener.Close() + require.NoError(t, err) + } + }() + + c, err := client.Dial(context.Background(), serverAddr, serverPeerID) + require.NoError(t, err) + defer func() { + if c != nil { + err = c.Close() + require.NoError(t, err) + } + }() + + require.Equal(t, serverAddr, c.RemoteMultiaddr()) + require.Equal(t, clientPeerID, c.LocalPeer()) + require.Equal(t, serverPeerID, c.RemotePeer()) + + // Try to dial address with no listener + otherAddr, err := ma.NewMultiaddr("/memory/4321") + require.NoError(t, err) + + _, err = client.Dial(context.Background(), otherAddr, serverPeerID) + require.Error(t, err) +}