diff --git a/python/test/test_tcp.py b/python/test/test_tcp.py index 45045685..fced2cae 100644 --- a/python/test/test_tcp.py +++ b/python/test/test_tcp.py @@ -4,6 +4,7 @@ import tll.channel as C from tll.error import TLLError from tll.test_util import Accum, ports +from tll.processor import Loop import os import pytest @@ -326,3 +327,40 @@ def test_client_scheme(): c.open() assert [m.name for m in s.scheme.messages] == ['Test'] + +@pytest.mark.skipif(not hasattr(socket, 'IPPROTO_MPTCP'), reason="MPTCP not available") +@pytest.mark.parametrize("client", ["yes", "no"]) +@pytest.mark.parametrize("server", ["yes", "no"]) +@pytest.mark.parametrize("host", [f"127.0.0.1:{ports.TCP4}", f"::1:{ports.TCP6}", "./tcp.sock"]) +def test_mptcp(host, client, server): + s = Accum(f'tcp://{host};mode=server;dump=frame', mptcp=server) + c = Accum(f'tcp://{host};mode=client;dump=frame', mptcp=client) + + loop = Loop() + + loop.add(s) + loop.add(c) + + s.open() + c.open() + + try: + for _ in range(10): + if s.result: + break + loop.step(0.001) + addr = s.result[-1].addr + s.result = [] + + c.post(b'xxx', seq=10) + s.post(b'zzz', seq=20, addr=addr) + + for _ in range(10): + if s.result and c.result: + break + loop.step(0.001) + assert [(m.data.tobytes(), m.seq) for m in s.result] == [(b'xxx', 10)] + assert [(m.data.tobytes(), m.seq) for m in c.result] == [(b'zzz', 20)] + finally: + s.close() + c.close() diff --git a/src/tll/channel/tcp.h b/src/tll/channel/tcp.h index a5c33d55..6a70c9e8 100644 --- a/src/tll/channel/tcp.h +++ b/src/tll/channel/tcp.h @@ -45,6 +45,7 @@ struct tcp_settings_t { bool timestamping = false; bool keepalive = true; bool nodelay = false; + bool mptcp = false; }; struct tcp_connect_t { diff --git a/src/tll/channel/tcp.hpp b/src/tll/channel/tcp.hpp index a3855f36..a52cf9a0 100644 --- a/src/tll/channel/tcp.hpp +++ b/src/tll/channel/tcp.hpp @@ -25,6 +25,10 @@ #include #include #include + +#ifndef IPPROTO_MPTCP +#define IPPROTO_MPTCP 262 +#endif #endif #if defined(__APPLE__) || defined(__FreeBSD__) @@ -56,6 +60,23 @@ size_t _fill_iovec(size_t full, struct iovec * iov, const Arg &arg, const Args & return _fill_iovec(full + tll::memoryview_api::size(arg), iov + 1, std::forward(args)...); } +inline int check_mptcp_proto(tll::Logger &log, const tcp_settings_t &settings, int af) +{ + if (!settings.mptcp) + return 0; + if (af == AF_UNIX) { + log.info("MPTCP not supported for Unix sockets, fall back to TCP"); + return 0; + } + +#ifdef IPPROTO_MPTCP + return IPPROTO_MPTCP; +#else + log.warning("MPTCP not supported on this platform, fall back to TCP"); + return 0; +#endif +} + } // namespace _ template @@ -359,6 +380,7 @@ int TcpClient::_init(const tll::Channel::Url &url, tll::Channel *master) _settings.sndbuf = reader.getT("sndbuf", util::Size { 0 }); _settings.rcvbuf = reader.getT("rcvbuf", util::Size { 0 }); _settings.buffer_size = reader.getT("buffer-size", util::Size { 64 * 1024 }); + _settings.mptcp = reader.getT("mptcp", false); if (!reader) return this->_log.fail(EINVAL, "Invalid url: {}", reader.error()); @@ -406,7 +428,7 @@ int TcpClient::_open(const ConstConfig &url) std::swap(_addr_list, *addr); _addr = _addr_list.begin(); - auto fd = socket((*_addr)->sa_family, SOCK_STREAM, 0); + auto fd = socket((*_addr)->sa_family, SOCK_STREAM, _::check_mptcp_proto(this->_log, _settings, (*_addr)->sa_family)); if (fd == -1) return this->_log.fail(errno, "Failed to create socket: {}", strerror(errno)); this->_update_fd(fd); @@ -619,7 +641,7 @@ int TcpServer::_bind(const tll::network::sockaddr_any &addr) static constexpr int sflags = SOCK_STREAM; #endif - tll::network::scoped_socket fd(socket(addr->sa_family, sflags, 0)); + tll::network::scoped_socket fd(socket(addr->sa_family, sflags, _::check_mptcp_proto(this->_log, _settings, addr->sa_family))); if (fd == -1) return this->_log.fail(errno, "Failed to create socket: {}", strerror(errno));